sac

Soft Actor-Critic (SAC)

Classes

SACAgent

Soft Actor-Critic (SAC) agent.

SACConfig

Hyperparameter configuration for the SAC agent.

SACPolicy

Soft Actor-Critic (SAC) policy class.

class prt_rl.model_free.sac.SACAgent(policy: SACPolicy, config: SACConfig, *, device: str = 'cpu')

Soft Actor-Critic (SAC) agent.

Parameters:
  • policy (SACPolicy | None) – Policy to use. If None, a default SACPolicy will be created.

  • config (SACConfig) – Configuration for the SAC agent.

  • device (str) – Device to run the model on (e.g., ‘cpu’ or ‘cuda’).

References

[1] https://arxiv.org/pdf/1812.05905

act(obs: Tensor, deterministic: bool = False) → Tensor

Select an action for the given observation.

Parameters:
  • obs (torch.Tensor) – The current observation from the environment.

  • deterministic (bool) – If True, the agent will select actions deterministically.

Returns:

The action to be taken.

Return type:

torch.Tensor

classmethod load(path: str | Path, map_location: str | device = 'cpu') → SACAgent

Load the checkpoint at path and return a fully constructed SACAgent.

train(env: EnvironmentInterface, total_steps: int, schedulers: List[ParameterScheduler] | None = None, logger: Logger | None = None, evaluator: Evaluator | None = None, show_progress: bool = True) → None

Train the SAC agent.

Parameters:
  • env (EnvironmentInterface) – The environment to train on.

  • total_steps (int) – Total number of environment steps to train for.

  • schedulers (List[ParameterScheduler] | None) – List of parameter schedulers to update during training.

  • logger (Logger | None) – Logger for logging training metrics. If None, a default logger will be created.

  • evaluator (Evaluator | None) – Evaluator for periodic evaluation during training.

  • show_progress (bool) – If True, display a progress bar during training.
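How total_steps interacts with the SACConfig collection settings is not specified above. As a minimal sketch, assuming a common off-policy schedule (collect steps_per_batch transitions, then run gradient_steps updates once the replay buffer holds at least min_buffer_size transitions), the hypothetical helper count_updates counts how many gradient updates a run of total_steps would perform. It does not call prt_rl and may not match the library's exact loop.

```python
def count_updates(total_steps: int, steps_per_batch: int, min_buffer_size: int,
                  gradient_steps: int, buffer_size: int) -> int:
    """Count gradient updates under an ASSUMED collect-then-update schedule.

    Hypothetical helper for illustration only; not part of prt_rl.
    """
    buffer_len = 0
    steps = 0
    updates = 0
    while steps < total_steps:
        # Collect up to steps_per_batch transitions without exceeding total_steps.
        n = min(steps_per_batch, total_steps - steps)
        steps += n
        # The replay buffer never holds more than buffer_size transitions.
        buffer_len = min(buffer_len + n, buffer_size)
        # Learning starts only once enough transitions are stored.
        if buffer_len >= min_buffer_size:
            updates += gradient_steps
    return updates
```

Under this assumed schedule and the default settings (steps_per_batch=1, min_buffer_size=100, gradient_steps=1), count_updates(1000, 1, 100, 1, 1_000_000) returns 901, because no updates happen during the first 99 steps.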

class prt_rl.model_free.sac.SACConfig(target_entropy: float, buffer_size: int = 1000000, min_buffer_size: int = 100, steps_per_batch: int = 1, mini_batch_size: int = 256, gradient_steps: int = 1, learning_rate: float = 0.0003, tau: float = 0.005, gamma: float = 0.99, entropy_coeff: float | None = None, use_log_entropy: bool = True, reward_scale: float = 1.0)

Hyperparameter configuration for the SAC agent.

Parameters:
  • buffer_size (int) – Size of the replay buffer.

  • min_buffer_size (int) – Minimum number of transitions in the replay buffer before training starts.

  • steps_per_batch (int) – Number of steps to collect per training batch.

  • mini_batch_size (int) – Size of the mini-batch sampled from the replay buffer for training.

  • gradient_steps (int) – Number of gradient update steps to perform after each batch of experience is collected.

  • learning_rate (float) – Learning rate for the optimizers.

  • tau (float) – Soft update coefficient for the target networks.

  • gamma (float) – Discount factor for future rewards.

  • entropy_coeff (float | None) – Initial value for the entropy coefficient, alpha. If None, it will be learned.

  • target_entropy (float) – Target entropy for the policy. It has no default value; a common choice is -action_dim, the negative of the action dimensionality.

  • use_log_entropy (bool) – If True, optimize the log of the entropy coefficient, else optimize the coefficient directly.

  • reward_scale (float) – Scaling factor for rewards.
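The tau, entropy_coeff, target_entropy, and use_log_entropy fields correspond to two standard SAC updates. The sketch below assumes the formulation in [1]: Polyak averaging of target parameters, and a temperature objective J(α) = E[-α (log π(a|s) + target_entropy)] minimized by gradient descent. In the log case it follows the common practice of putting log α in place of α inside the loss, which keeps the same update direction. The functions polyak_update and temperature_step are hypothetical, operate on plain Python floats instead of tensors, and use a hand-written gradient step in place of an optimizer.

```python
import math
from typing import List, Tuple


def polyak_update(target: List[float], online: List[float], tau: float) -> List[float]:
    """Soft target update: target <- tau * online + (1 - tau) * target."""
    return [tau * o + (1.0 - tau) * t for t, o in zip(target, online)]


def temperature_step(param: float, log_probs: List[float], target_entropy: float,
                     lr: float, use_log_entropy: bool) -> Tuple[float, float]:
    """One gradient-descent step on the temperature parameter (hypothetical helper).

    param is log(alpha) if use_log_entropy is True, otherwise alpha itself.
    Returns (new_param, alpha).
    """
    # Loss = -param * mean(log_prob + target_entropy), with the mean held constant,
    # so d(loss)/d(param) = -mean_term.
    mean_term = sum(lp + target_entropy for lp in log_probs) / len(log_probs)
    new_param = param + lr * mean_term
    if use_log_entropy:
        # exp() keeps alpha positive without any clamping.
        return new_param, math.exp(new_param)
    # Direct parameterization: clamp so alpha stays positive (an assumption of this sketch).
    new_param = max(new_param, 1e-8)
    return new_param, new_param
```

With target_entropy = -2 (a 2-D action space) and sampled log probabilities of -1.0, the policy entropy estimate of 1.0 exceeds the target, so one step lowers α; when the entropy falls below the target, α rises.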

class prt_rl.model_free.sac.SACPolicy(network: Module, actor_head: DistributionHead, critic_head: QValueHead, *, action_min: Tensor, action_max: Tensor, num_critics: int = 2, critic_network: Module | None = None)

Soft Actor-Critic (SAC) policy class.

The default actor is a DistributionPolicy with a TanhGaussian distribution, and the default critic is a StateActionCritic with 2 critics.

Parameters:
  • env_params (EnvParams) – Environment parameters.

  • num_critics (int) – Number of critics to use in the SAC algorithm.

  • actor (DistributionPolicy | None) – Actor policy. If None, a default DistributionPolicy will be created.

  • critic (StateActionCritic | None) – Critic network. If None, a default StateActionCritic will be created.

  • device (str) – Device to run the model on (e.g., ‘cpu’ or ‘cuda’).


act(obs: Tensor, deterministic: bool = False) → Tuple[Tensor, Dict[str, Tensor]]

Select an action for the given observation.

Parameters:
  • obs (torch.Tensor) – Current observation tensor.

  • deterministic (bool) – If True, choose the action deterministically. Default is False.

Returns:

A tuple containing the chosen action and a dictionary of additional outputs.
  • action (torch.Tensor): Tensor with the chosen action. Shape (B, action_dim)

  • Dict[str, torch.Tensor]: Additional outputs, for example the action log probability.

Return type:

Tuple[torch.Tensor, Dict[str, torch.Tensor]]
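The class description mentions a TanhGaussian actor, and the constructor takes action_min and action_max. The sketch below shows, for a single observation in plain Python, how a tanh-squashed diagonal Gaussian is commonly sampled, rescaled to those bounds, and scored, including the change-of-variables correction to the log probability. The function name, the mean/log_std inputs, and the inclusion of the affine-rescaling term are assumptions for illustration, not prt_rl's implementation.

```python
import math
import random
from typing import List, Optional, Tuple


def tanh_gaussian_action(mean: List[float], log_std: List[float],
                         action_min: List[float], action_max: List[float],
                         deterministic: bool = False,
                         rng: Optional[random.Random] = None) -> Tuple[List[float], float]:
    """Return a bounded action and its log probability (hypothetical helper).

    Per dimension: u ~ N(mean, std), a = tanh(u), then rescale a from [-1, 1]
    to [action_min, action_max]. In deterministic mode u = mean.
    """
    if rng is None:
        rng = random.Random()
    action, log_prob = [], 0.0
    for mu, ls, lo, hi in zip(mean, log_std, action_min, action_max):
        std = math.exp(ls)
        u = mu if deterministic else rng.gauss(mu, std)
        # Gaussian log density of the pre-squash sample u.
        log_prob += -0.5 * ((u - mu) / std) ** 2 - ls - 0.5 * math.log(2 * math.pi)
        squashed = math.tanh(u)
        # tanh change of variables; the small epsilon avoids log(0) at saturation.
        log_prob -= math.log(1.0 - squashed ** 2 + 1e-6)
        scale = (hi - lo) / 2.0
        # Change of variables for the affine rescaling (a constant offset).
        log_prob -= math.log(scale)
        action.append(lo + (squashed + 1.0) * scale)
    return action, log_prob
```

In deterministic mode the pre-squash sample is the Gaussian mean, so the action is the rescaled tanh(mean). The -log(scale) term is constant per dimension; some implementations omit it because it does not affect gradients.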

get_q_values(obs: Tensor, action: Tensor, index: int | None = None) → Tensor

Get Q-values from all critics for the given state-action pairs.

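Parameters:
  • obs (torch.Tensor) – Batch of observations.

  • action (torch.Tensor) – Batch of actions. Shape (B, action_dim)

  • index (int | None) – Optional index selecting a single critic. If None, all critics are evaluated.
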
Returns:

Tensor containing Q-values from all critics. Shape (B, C, 1) where C is the number of critics.

Return type:

torch.Tensor
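get_q_values returns one estimate per critic, with shape (B, C, 1). The sketch below uses nested lists of that shape to show two common ways SAC consumes it: a critic loss that regresses every critic toward one shared target, and an actor loss that takes the minimum over critics. The helper names, the 0.5 factor, and the min reduction are assumptions, not confirmed details of prt_rl.

```python
from typing import List

QBatch = List[List[List[float]]]  # shape (B, C, 1)


def critic_loss(q_values: QBatch, targets: List[List[float]]) -> float:
    """0.5 * sum over critics of the mean squared error against a shared (B, 1) target."""
    batch = len(q_values)
    num_critics = len(q_values[0])
    total = 0.0
    for c in range(num_critics):
        mse = sum((q_values[b][c][0] - targets[b][0]) ** 2 for b in range(batch)) / batch
        total += 0.5 * mse
    return total


def actor_loss(q_values: QBatch, log_probs: List[float], alpha: float) -> float:
    """mean over the batch of (alpha * log_prob - min over critics of Q).

    Here q_values are evaluated at actions freshly sampled from the current policy.
    """
    batch = len(q_values)
    return sum(alpha * log_probs[b] - min(q[0] for q in q_values[b])
               for b in range(batch)) / batch
```

Taking the minimum over critics gives a pessimistic estimate that counteracts Q-value overestimation; this is why the default policy uses two critics.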

get_target_q_values(obs: Tensor, action: Tensor) → Tensor

Get target Q-values from all target critics for the given state-action pairs.

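Parameters:
  • obs (torch.Tensor) – Batch of observations.

  • action (torch.Tensor) – Batch of actions. Shape (B, action_dim)
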
Returns:

Tensor containing target Q-values from all critics. Shape (B, C, 1) where C is the number of critics.

Return type:

torch.Tensor
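get_target_q_values supplies the bootstrap term of the soft Bellman target. Assuming the standard SAC form y = reward_scale · r + γ (1 − done) (min_c Q_targ(s′, a′) − α log π(a′|s′)), with a′ sampled from the current policy at the next observation, a plain-Python sketch over the (B, C, 1) shape looks like this. soft_td_target is a hypothetical helper, and whether prt_rl applies reward_scale exactly this way is an assumption.

```python
from typing import List


def soft_td_target(rewards: List[float], dones: List[float],
                   next_target_q: List[List[List[float]]], next_log_probs: List[float],
                   gamma: float, alpha: float, reward_scale: float = 1.0) -> List[List[float]]:
    """Soft Bellman target with shape (B, 1) (hypothetical helper)."""
    targets = []
    for r, d, q_b, lp in zip(rewards, dones, next_target_q, next_log_probs):
        # Pessimistic estimate: minimum over the C target critics.
        min_q = min(q[0] for q in q_b)
        # Entropy bonus: subtracting alpha * log pi(a'|s') favors high-entropy next actions.
        soft_value = min_q - alpha * lp
        # Terminal transitions (done = 1) do not bootstrap.
        targets.append([reward_scale * r + gamma * (1.0 - d) * soft_value])
    return targets
```

In a tensor implementation this target is typically computed without gradient tracking (for example under torch.no_grad()), so only the online critics are updated by the critic loss.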

metadata()

Optionally save metadata alongside the policy. This is a no-op in the base class but can be overridden by subclasses.