sac
Soft Actor-Critic (SAC)
Classes
- class prt_rl.model_free.sac.SACAgent(policy: SACPolicy, config: SACConfig, *, device: str = 'cpu')
Soft Actor-Critic (SAC) agent.
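- Parameters:
policy (SACPolicy) – The SAC policy holding the actor and critic networks.
config (SACConfig) – Hyperparameter configuration for the agent.
device (str) – Device to run the agent on (e.g., ‘cpu’ or ‘cuda’).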
References
[1] https://arxiv.org/pdf/1812.05905
- act(obs: Tensor, deterministic: bool = False) → Tensor
Select an action based on the current observation.
- Parameters:
obs (torch.Tensor) – The current observation from the environment.
deterministic (bool) – If True, the agent will select actions deterministically.
- Returns:
The action to be taken.
- Return type:
torch.Tensor
- classmethod load(path: str | Path, map_location: str | device = 'cpu') → SACAgent
Load the checkpoint at path and return a fully constructed SACAgent.
- train(env: EnvironmentInterface, total_steps: int, schedulers: List[ParameterScheduler] | None = None, logger: Logger | None = None, evaluator: Evaluator | None = None, show_progress: bool = True) → None
Train the SAC agent.
- Parameters:
env (EnvironmentInterface) – The environment to train on.
total_steps (int) – Total number of environment steps to train for.
schedulers (List[ParameterScheduler] | None) – List of parameter schedulers to update during training.
logger (Logger | None) – Logger for logging training metrics. If None, a default logger will be created.
evaluator (Evaluator | None) – Evaluator for periodic evaluation during training.
show_progress (bool) – If True, display a progress bar during training.
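The number of gradient updates performed by train(env, total_steps) depends on the SACConfig fields steps_per_batch, min_buffer_size, and gradient_steps. The sketch below counts updates under an assumed collect-then-update loop: it assumes a single environment (one step adds one transition), ignores the replay buffer capacity, and uses a hypothetical helper count_gradient_updates that is not part of prt_rl. The actual loop inside train may differ.

```python
def count_gradient_updates(total_steps: int,
                           steps_per_batch: int = 1,
                           min_buffer_size: int = 100,
                           gradient_steps: int = 1) -> int:
    """Count gradient updates under an ASSUMED collect-then-update loop.

    Assumption (not taken from prt_rl): collect `steps_per_batch` steps,
    add them to the replay buffer, then run `gradient_steps` updates once
    the buffer holds at least `min_buffer_size` transitions.
    """
    steps = 0
    buffer_len = 0
    updates = 0
    while steps < total_steps:
        # Collect one batch of experience, truncated at total_steps.
        collected = min(steps_per_batch, total_steps - steps)
        steps += collected
        buffer_len += collected
        # Updates start only after the warm-up threshold is reached.
        if buffer_len >= min_buffer_size:
            updates += gradient_steps
    return updates


# With the SACConfig defaults, steps 1 to 99 are warm-up only.
print(count_gradient_updates(total_steps=1000))  # → 901
```

Under this assumption, raising steps_per_batch while keeping gradient_steps fixed lowers the update-to-data ratio, so the two are usually tuned together.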
- class prt_rl.model_free.sac.SACConfig(target_entropy: float, buffer_size: int = 1000000, min_buffer_size: int = 100, steps_per_batch: int = 1, mini_batch_size: int = 256, gradient_steps: int = 1, learning_rate: float = 0.0003, tau: float = 0.005, gamma: float = 0.99, entropy_coeff: float | None = None, use_log_entropy: bool = True, reward_scale: float = 1.0)
Hyperparameter configuration for the SAC agent.
- Parameters:
buffer_size (int) – Size of the replay buffer.
min_buffer_size (int) – Minimum number of transitions in the replay buffer before training starts.
steps_per_batch (int) – Number of steps to collect per training batch.
mini_batch_size (int) – Size of the mini-batch sampled from the replay buffer for training.
gradient_steps (int) – Number of gradient update steps to perform after each batch of experience is collected.
learning_rate (float) – Learning rate for the optimizers.
tau (float) – Soft update coefficient for the target networks.
gamma (float) – Discount factor for future rewards.
entropy_coeff (float | None) – Initial value for the entropy coefficient, alpha. If None, it will be learned.
target_entropy (float) – Target entropy for the policy. This argument is required; a common choice is -action_dim.
use_log_entropy (bool) – If True, optimize the log of the entropy coefficient, else optimize the coefficient directly.
reward_scale (float) – Scaling factor for rewards.
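Two of these hyperparameters drive small update rules from the SAC paper [1]: tau sets the Polyak (soft) update of the target critics, and target_entropy sets the goal for the learned temperature alpha. The sketch below writes both on plain Python floats. The helper names are hypothetical, and treating use_log_entropy=True as "take gradient steps on log(alpha)" is an assumption about how the flag is used; it is not copied from prt_rl.

```python
import math
from typing import List


def soft_update(target: List[float], online: List[float], tau: float = 0.005) -> List[float]:
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    return [tau * o + (1.0 - tau) * t for t, o in zip(target, online)]


def default_target_entropy(action_dim: int) -> float:
    """Common heuristic for target_entropy: the negative action dimension."""
    return -float(action_dim)


def alpha_step(log_alpha: float, log_prob: float, target_entropy: float,
               learning_rate: float = 3e-4) -> float:
    """One gradient-descent step on the temperature loss, taken in log space.

    Loss: J = -alpha * (log_prob + target_entropy), alpha = exp(log_alpha),
    with log_prob treated as a constant. By the chain rule,
    dJ/dlog_alpha = -exp(log_alpha) * (log_prob + target_entropy).
    Optimizing log_alpha keeps alpha = exp(log_alpha) strictly positive.
    """
    grad = -math.exp(log_alpha) * (log_prob + target_entropy)
    return log_alpha - learning_rate * grad
```

When the policy's entropy (-log_prob) falls below target_entropy, the gradient is negative and alpha grows, pushing the policy toward more exploration; when entropy is above target, alpha shrinks.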
- class prt_rl.model_free.sac.SACPolicy(network: Module, actor_head: DistributionHead, critic_head: QValueHead, *, action_min: Tensor, action_max: Tensor, num_critics: int = 2, critic_network: Module | None = None)
Soft Actor-Critic (SAC) policy class.
The default actor is a DistributionPolicy with a TanhGaussian distribution, and the default critic is a StateActionCritic with 2 critics.
- Parameters:
network (Module) – Base network used by the policy.
actor_head (DistributionHead) – Head that produces the action distribution.
critic_head (QValueHead) – Head that produces Q-value estimates.
action_min (torch.Tensor) – Lower bound of the action space.
action_max (torch.Tensor) – Upper bound of the action space.
num_critics (int) – Number of critics to use in the SAC algorithm.
critic_network (Module | None) – Optional separate network for the critics.
- act(obs: Tensor, deterministic: bool = False) → Tuple[Tensor, Dict[str, Tensor]]
Predict the action based on the current observation.
- Parameters:
obs (torch.Tensor) – Current observation tensor.
deterministic (bool) – If True, choose the action deterministically. Default is False.
- Returns:
- A tuple containing the chosen action and a dictionary of auxiliary outputs.
action (torch.Tensor): Tensor with the chosen action. Shape (B, action_dim)
info (Dict[str, torch.Tensor]): Dictionary of auxiliary tensors (for example, log_prob).
- Return type:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]
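The default actor uses a TanhGaussian distribution, and the policy receives action_min and action_max bounds. The sketch below shows, for a single action dimension on plain floats, how such a squashed Gaussian is commonly sampled, rescaled to the bounds, and assigned a log probability with the change-of-variables correction. The function name is hypothetical, the 1e-6 epsilon is an assumed numerical guard, and whether prt_rl includes the rescaling term in its log_prob is not stated in this page.

```python
import math
import random
from typing import Optional, Tuple


def tanh_gaussian_action(mu: float, std: float, action_min: float, action_max: float,
                         deterministic: bool = False,
                         rng: Optional[random.Random] = None) -> Tuple[float, float]:
    """Sample a tanh-squashed Gaussian action in [action_min, action_max].

    Returns (action, log_prob). Assumes std > 0 and action_max > action_min.
    """
    rng = rng or random.Random()
    # Pre-squash sample u ~ N(mu, std); the deterministic action uses the mean.
    u = mu if deterministic else rng.gauss(mu, std)
    # Log-density of u under N(mu, std).
    log_prob = -0.5 * ((u - mu) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)
    # Squash to (-1, 1); subtract log|d tanh(u)/du| = log(1 - tanh(u)^2).
    a = math.tanh(u)
    log_prob -= math.log(1.0 - a * a + 1e-6)
    # Affine rescale from (-1, 1) to (action_min, action_max); subtract log(scale).
    scale = 0.5 * (action_max - action_min)
    action = action_min + (a + 1.0) * scale
    log_prob -= math.log(scale)
    return action, log_prob


# Deterministic action with mu = 0 lands at the midpoint of the bounds.
print(tanh_gaussian_action(0.0, 1.0, -2.0, 2.0, deterministic=True)[0])  # → 0.0
```

Squashing with tanh keeps every sample inside the bounds without clipping, which is why the log-probability needs the extra correction terms.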
- get_q_values(obs: Tensor, action: Tensor, index: int | None = None) → Tensor
Get Q-values from all critics for the given state-action pairs.
- Parameters:
obs (torch.Tensor) – Current observation tensor.
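action (torch.Tensor) – Action tensor.
index (int | None) – Optional index of a single critic to evaluate. If None, Q-values from all critics are returned.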
- Returns:
Tensor containing Q-values from all critics. Shape (B, C, 1) where C is the number of critics.
- Return type:
torch.Tensor
- get_target_q_values(obs: Tensor, action: Tensor) → Tensor
Get target Q-values from all target critics for the given state-action pairs.
- Parameters:
obs (torch.Tensor) – Current observation tensor.
action (torch.Tensor) – Action tensor.
- Returns:
Tensor containing target Q-values from all critics. Shape (B, C, 1) where C is the number of critics.
- Return type:
torch.Tensor
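The target critics returned by get_target_q_values, with shape (B, C, 1), are what the soft Bellman target in the SAC paper [1] is built from. The sketch below computes that target for a single transition on plain floats, using the gamma and reward_scale fields from SACConfig. Taking the minimum over the C target critics (clipped double-Q) follows the paper; that prt_rl reduces the C axis the same way is an assumption, and soft_td_target is a hypothetical helper.

```python
from typing import Sequence


def soft_td_target(reward: float, done: bool, next_q_targets: Sequence[float],
                   next_log_prob: float, alpha: float,
                   gamma: float = 0.99, reward_scale: float = 1.0) -> float:
    """Soft Bellman target for one transition:

    y = reward_scale * r + gamma * (1 - done) * (min_i Q'_i(s', a') - alpha * log pi(a' | s'))

    next_q_targets holds one target-critic value per critic (the C axis),
    evaluated at the next observation and an action a' sampled from the policy.
    """
    # Clipped double-Q: the minimum over critics curbs overestimation.
    min_q = min(next_q_targets)
    # Entropy bonus: subtracting alpha * log_prob adds alpha * entropy.
    soft_value = min_q - alpha * next_log_prob
    # Terminal transitions do not bootstrap.
    not_done = 0.0 if done else 1.0
    return reward_scale * reward + gamma * not_done * soft_value
```

For example, reward 1.0, target critics [10.0, 12.0], next log-probability -1.0, alpha 0.2, and gamma 0.9 give 1.0 + 0.9 * (10.0 + 0.2), approximately 10.18.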