"""
Soft Actor-Critic (SAC)
"""
from dataclasses import dataclass, asdict
from pathlib import Path
import itertools
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Dict
from prt_rl.agent import Agent
from prt_rl.env.interface import EnvironmentInterface
from prt_rl.common.policies import NeuralPolicy, RandomPolicy
from prt_rl.common.loggers import Logger
from prt_rl.common.schedulers import ParameterScheduler
from prt_rl.common.progress_bar import ProgressBar
from prt_rl.common.evaluators import Evaluator
import copy
import numpy as np
from prt_rl.common.collectors import Collector
from prt_rl.common.buffers import ReplayBuffer
from prt_rl.common.components.heads import DistributionHead, QValueHead
from prt_rl.common.components.networks import QCritic
import prt_rl.common.utils as utils
[docs]
@dataclass
class SACConfig:
    """
    Hyperparameters consumed by the SAC agent.

    ``target_entropy`` is the only required field; every other field has a default.

    Args:
        target_entropy (float): Target entropy for the policy. A reasonable default is -action_dim.
        buffer_size (int): Capacity of the replay buffer.
        min_buffer_size (int): Number of transitions that must be in the replay buffer before training starts.
        steps_per_batch (int): Number of environment steps collected per training iteration.
        mini_batch_size (int): Number of transitions sampled from the replay buffer for each gradient step.
        gradient_steps (int): Number of gradient update steps performed after each batch of experience is collected.
        learning_rate (float): Learning rate shared by the actor, critic and entropy-coefficient optimizers.
        tau (float): Soft (Polyak) update coefficient for the target critics.
        gamma (float): Discount factor for future rewards.
        entropy_coeff (float | None): Entropy coefficient, alpha. If None, the coefficient is learned.
        use_log_entropy (bool): When the coefficient is learned, optimize log(alpha) if True, otherwise optimize alpha directly.
        reward_scale (float): Multiplier applied to rewards before they are stored in the replay buffer.
    """
    # Required: there is no universally sensible default, it depends on the action dimension.
    target_entropy: float

    # Replay buffer and data collection
    buffer_size: int = 1_000_000
    min_buffer_size: int = 100
    steps_per_batch: int = 1

    # Optimization
    mini_batch_size: int = 256
    gradient_steps: int = 1
    learning_rate: float = 3e-4
    tau: float = 0.005
    gamma: float = 0.99

    # Entropy regularization
    entropy_coeff: Optional[float] = None
    use_log_entropy: bool = True

    # Reward shaping
    reward_scale: float = 1.0
[docs]
class SACPolicy(NeuralPolicy):
    """
    Actor and critic container for Soft Actor-Critic (SAC).

    The policy owns an actor (backbone network plus distribution head), ``num_critics``
    Q-critics, and one target copy of every critic. Each critic receives its own deep copy
    of the critic backbone and of the critic head, so critics do not share parameters.

    Args:
        network (nn.Module): Backbone network used by the actor.
        actor_head (DistributionHead): Head that samples actions from the actor's latent features.
        critic_head (QValueHead): Head template that is deep-copied for every critic.
        action_min (Tensor): Lower bound of the environment's action space.
        action_max (Tensor): Upper bound of the environment's action space.
        num_critics (int): Number of critics (and target critics) to create.
        critic_network (nn.Module | None): Backbone template for the critics. If None, a deep copy of ``network`` is used.
    """
    def __init__(self,
                 network: nn.Module,
                 actor_head: DistributionHead,
                 critic_head: QValueHead,
                 *,
                 action_min: Tensor,
                 action_max: Tensor,
                 num_critics: int = 2,
                 critic_network: Optional[nn.Module] = None,
                 ) -> None:
        super().__init__()
        self.num_critics = num_critics

        # Affine map applied to sampled actions: action * scale + bias.
        # NOTE(review): presumably the head samples in [-1, 1] (e.g. tanh-squashed) -- confirm.
        self.action_min = action_min
        self.action_max = action_max
        self.action_scale = (action_max - action_min) / 2
        self.action_bias = (action_max + action_min) / 2

        self.actor_network = network
        self.actor_head = actor_head

        # The critics never share the actor's backbone instance: fall back to a copy of it.
        if critic_network is None:
            critic_network = copy.deepcopy(network)

        # Online critics first, then one frozen-in-time copy of each as its target.
        online_critics = [
            QCritic(copy.deepcopy(critic_network), copy.deepcopy(critic_head))
            for _ in range(self.num_critics)
        ]
        self.critics = nn.ModuleList(online_critics)
        self.target_critics = nn.ModuleList([copy.deepcopy(critic) for critic in online_critics])

    @torch.no_grad()
    def act(self,
            obs: torch.Tensor,
            deterministic: bool = False
            ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Select an action for the given observation without tracking gradients.

        Args:
            obs (torch.Tensor): Current observation tensor.
            deterministic (bool): Forwarded to the actor head's ``sample``. Default is False.

        Returns:
            Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
                - action (torch.Tensor): Action rescaled with ``action_scale`` and ``action_bias``.
                - info (dict): Dictionary with key ``"log_prob"`` holding the log probability returned by the actor head.
        """
        features = self.actor_network(obs)
        raw_action, log_prob, _ = self.actor_head.sample(features, deterministic=deterministic)

        # Rescale the action to the environment's action space
        target_device = obs.device
        scale = self.action_scale.to(target_device)
        bias = self.action_bias.to(target_device)
        return raw_action * scale + bias, {"log_prob": log_prob}

    def get_q_values(self,
                     obs: torch.Tensor,
                     action: torch.Tensor,
                     index: Optional[int] = None
                     ) -> torch.Tensor:
        """
        Evaluate the online critics on the given state-action pairs.

        Args:
            obs (torch.Tensor): Current observation tensor.
            action (torch.Tensor): Action tensor.
            index (int | None): If given, only the critic at this index is evaluated.

        Returns:
            torch.Tensor: If ``index`` is None, the outputs of all critics stacked along dim 1
            (the original code documents this as shape (B, C, 1) with C critics). Otherwise the
            raw output of the selected critic.
        """
        if index is not None:
            return self.critics[index](obs, action)
        per_critic = [critic(obs, action) for critic in self.critics]
        return torch.stack(per_critic, dim=1)

    def get_target_q_values(self,
                            obs: torch.Tensor,
                            action: torch.Tensor,
                            ) -> torch.Tensor:
        """
        Evaluate the target critics on the given state-action pairs.

        Args:
            obs (torch.Tensor): Current observation tensor.
            action (torch.Tensor): Action tensor.

        Returns:
            torch.Tensor: Outputs of all target critics stacked along dim 1
            (the original code documents this as shape (B, C, 1) with C critics).
        """
        per_critic = [target(obs, action) for target in self.target_critics]
        return torch.stack(per_critic, dim=1)
[docs]
class SACAgent(Agent):
    """
    Soft Actor-Critic (SAC) agent.

    Args:
        policy (SACPolicy): Policy holding the actor, the critics and the target critics.
        config (SACConfig): Configuration for the SAC agent.
        device (str): Device to run the model on (e.g., 'cpu' or 'cuda').

    References:
        [1] https://arxiv.org/pdf/1812.05905
    """
    def __init__(self,
                 policy: SACPolicy,
                 config: SACConfig,
                 *,
                 device: str = 'cpu'
                 ) -> None:
        super().__init__()
        self.config = config
        self.device = torch.device(device)

        self.policy = policy
        self.policy.to(self.device)

        # Initialize the entropy coefficient. ``self.entropy_coeff`` holds log(alpha) when the
        # coefficient is learned in log space, and alpha itself in every other case.
        if self.config.entropy_coeff is None:
            if self.config.use_log_entropy:
                # log(1) = 0, i.e. alpha starts at 1
                self.entropy_coeff = torch.log(torch.ones(1, device=self.device)).requires_grad_(True)
            else:
                # NOTE(review): alpha starts at 0 here but at 1 in the log variant -- confirm this is intended.
                self.entropy_coeff = torch.tensor(0.0, requires_grad=True, device=self.device)
            self.entropy_optimizer = torch.optim.Adam([self.entropy_coeff], lr=self.config.learning_rate)
        else:
            # Fixed coefficient: stored as a float tensor so an integer config value cannot
            # produce an integer tensor.
            self.entropy_coeff = torch.tensor(float(self.config.entropy_coeff), device=self.device)
            self.entropy_optimizer = None

        # Configure the optimizers
        actor_parameters = itertools.chain(
            self.policy.actor_network.parameters(),
            self.policy.actor_head.parameters(),
        )
        self.actor_optimizer = torch.optim.Adam(actor_parameters, lr=self.config.learning_rate)
        self.critic_optimizer = torch.optim.Adam(self.policy.critics.parameters(), lr=self.config.learning_rate)

    @torch.no_grad()
    def act(self, obs: Tensor, deterministic: bool = False) -> Tensor:
        """
        Return the policy's action for ``obs``, discarding the auxiliary info dictionary.

        Args:
            obs (Tensor): Current observation tensor.
            deterministic (bool): Forwarded to the policy. Default is False.

        Returns:
            Tensor: The action chosen by the policy.
        """
        action, _ = self.policy.act(obs, deterministic=deterministic)
        return action

    def _entropy_coefficient(self) -> Tensor:
        """
        Return a detached snapshot of the entropy coefficient, alpha.

        Only a coefficient that is *learned* in log space is exponentiated. A fixed coefficient
        from the config is already alpha and must be used as-is, even though
        ``use_log_entropy`` defaults to True.
        """
        raw = self.entropy_coeff.detach()
        if self.entropy_optimizer is not None and self.config.use_log_entropy:
            return torch.exp(raw)
        # Clone so that a later optimizer step on the coefficient cannot change the snapshot.
        return raw.clone()

    def _sample_action(self, obs: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Sample an action and its log probability from the actor with gradients enabled.

        ``SACPolicy.act`` runs under ``torch.no_grad()``, so its outputs cannot be used to train
        the actor. This helper repeats the same computation while keeping the autograd graph.

        Returns:
            Tuple[Tensor, Tensor]: The rescaled action and its log probability.
        """
        latent = self.policy.actor_network(obs)
        # NOTE(review): assumes DistributionHead.sample draws reparameterized samples so that
        # gradients reach the actor through the action -- confirm.
        action, log_prob, _ = self.policy.actor_head.sample(latent, deterministic=False)
        action = action * self.policy.action_scale.to(obs.device) + self.policy.action_bias.to(obs.device)
        return action, log_prob

    def _update_entropy_coeff(self, log_prob: Tensor) -> Optional[float]:
        """
        Take one optimizer step on the entropy coefficient.

        Returns:
            float | None: The entropy loss, or None when the coefficient is fixed.
        """
        if self.entropy_optimizer is None:
            return None
        entropy_loss = -(self.entropy_coeff * (log_prob + self.config.target_entropy).detach()).mean()
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()
        return entropy_loss.item()

    def _update_critics(self, mini_batch: Dict[str, Tensor], entropy_coeff: Tensor) -> float:
        """
        Take one optimizer step on all critics and return the summed critic loss.

        The target is built from the minimum target-critic value at the next state, the
        entropy bonus, the (already scaled) reward and the done flag.
        """
        with torch.no_grad():
            # Select next action based on current policy
            next_action, next_info = self.policy.act(mini_batch['next_state'])
            next_log_prob = next_info['log_prob']
            # Minimum over the critic dimension of the target Q-values
            next_q_values = self.policy.get_target_q_values(obs=mini_batch['next_state'], action=next_action).squeeze(-1)
            next_q_values = torch.min(next_q_values, dim=1, keepdim=True)[0]
            # In-place on purpose: a log-prob shape that would broadcast to a larger tensor
            # raises here instead of silently producing a wrongly shaped target.
            next_q_values += -entropy_coeff * next_log_prob
            # Compute the discounted target Q-values
            y = mini_batch['reward'] + (1 - mini_batch['done'].float()) * self.config.gamma * next_q_values

        # Sum the losses across all critics
        state = mini_batch['state'].detach()
        action = mini_batch['action'].detach()
        critic_loss = sum(
            F.mse_loss(y, self.policy.get_q_values(state, action, index=i))
            for i in range(self.policy.num_critics)
        )

        # zero_grad also clears the critic gradients left behind by the previous actor update
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        return critic_loss.item()

    def _update_actor(self, state: Tensor, action: Tensor, log_prob: Tensor, entropy_coeff: Tensor) -> float:
        """
        Take one optimizer step on the actor and return the actor loss.

        ``action`` and ``log_prob`` must come from ``_sample_action`` so they carry gradients.
        """
        # Drop the trailing dimension before the min, exactly as the critic target does, so the
        # result lines up with ``log_prob`` instead of broadcasting to a (B, B, 1) tensor.
        q_values_pi = self.policy.get_q_values(obs=state, action=action).squeeze(-1)
        q_values_pi = torch.min(q_values_pi, dim=1, keepdim=True)[0]
        actor_loss = (entropy_coeff * log_prob - q_values_pi).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        return actor_loss.item()

    def train(self,
              env: EnvironmentInterface,
              total_steps: int,
              schedulers: Optional[List[ParameterScheduler]] = None,
              logger: Optional[Logger] = None,
              evaluator: Optional[Evaluator] = None,
              show_progress: bool = True
              ) -> None:
        """
        Train the SAC agent.

        Args:
            env (EnvironmentInterface): The environment to train on.
            total_steps (int): Total number of environment steps to train for.
            schedulers (List[ParameterScheduler] | None): List of parameter schedulers to update during training.
            logger (Logger | None): Logger for logging training metrics. If None, a default logger will be created.
            evaluator (Evaluator | None): Evaluator for periodic evaluation during training. If None, a default evaluator will be created.
            show_progress (bool): If True, display a progress bar during training.
        """
        logger = logger or Logger()
        evaluator = evaluator or Evaluator()
        progress_bar = ProgressBar(total_steps=total_steps) if show_progress else None

        num_steps = 0
        collector = Collector(env=env, logger=logger, flatten=True)
        replay_buffer = ReplayBuffer(capacity=self.config.buffer_size, device=self.device)

        # Collect initial experience with the policy until the replay buffer has enough samples.
        # NOTE(review): unlike the main loop this call does not pass bootstrap=False -- confirm
        # that the difference is intended.
        while replay_buffer.get_size() < self.config.min_buffer_size:
            experience = collector.collect_experience(policy=self.policy, num_steps=self.config.steps_per_batch)
            # Apply reward scaling
            experience['reward'] = experience['reward'] * self.config.reward_scale
            replay_buffer.add(experience)
            num_steps += experience['state'].shape[0]
            if progress_bar is not None:
                progress_bar.update(current_step=num_steps, desc="Collecting initial experience...")

        while num_steps < total_steps:
            self._update_schedulers(schedulers=schedulers, step=num_steps)

            # Collect experience dictionary with shape (B, ...)
            experience = collector.collect_experience(policy=self.policy, num_steps=self.config.steps_per_batch, bootstrap=False)
            num_steps += experience['state'].shape[0]
            # Apply reward scaling
            experience['reward'] = experience['reward'] * self.config.reward_scale
            # Add experience to the replay buffer
            replay_buffer.add(experience)

            actor_losses = []
            critics_losses = []
            entropy_losses = []
            # Defined up front so reporting below still works when gradient_steps is 0
            entropy_coeff = self._entropy_coefficient()

            for _ in range(self.config.gradient_steps):
                # Sample a mini-batch from the replay buffer
                mini_batch = replay_buffer.sample(batch_size=self.config.mini_batch_size)

                # Alpha is read once per gradient step, before the temperature update, and the
                # same value is used for the critic target and the actor loss.
                entropy_coeff = self._entropy_coefficient()

                # Sample with gradients enabled; the actor update needs them.
                current_action, current_log_prob = self._sample_action(mini_batch['state'])

                entropy_loss = self._update_entropy_coeff(current_log_prob)
                if entropy_loss is not None:
                    entropy_losses.append(entropy_loss)

                critics_losses.append(self._update_critics(mini_batch, entropy_coeff))
                actor_losses.append(
                    self._update_actor(mini_batch['state'], current_action, current_log_prob, entropy_coeff)
                )

                # Update target critic networks
                for target_critic, critic in zip(self.policy.target_critics, self.policy.critics):
                    utils.polyak_update(target_critic, critic, tau=self.config.tau)

            mean_actor_loss = float(np.mean(actor_losses)) if actor_losses else float('nan')

            if progress_bar is not None:
                tracker = collector.get_metric_tracker()
                progress_bar.update(current_step=num_steps, desc=f"Episode Reward: {tracker.last_episode_reward:.2f}"
                                                                  f" Episode Length: {tracker.last_episode_length}"
                                                                  f" Episode number: {tracker.episode_count}"
                                                                  f" Actor Loss: {mean_actor_loss:.4f}"
                                                                  f" Entropy Coef: {entropy_coeff.item():.2f}"
                                    )

            if logger.should_log(num_steps):
                # Each metric is only logged when it was actually produced this iteration.
                if actor_losses:
                    logger.log_scalar('actor_loss', mean_actor_loss, num_steps)
                if entropy_losses:
                    logger.log_scalar('entropy_loss', np.mean(entropy_losses), num_steps)
                logger.log_scalar('entropy_coeff', entropy_coeff.item(), num_steps)
                if critics_losses:
                    logger.log_scalar('critic_loss', np.mean(critics_losses), num_steps)

            evaluator.evaluate(agent=self.policy, iteration=num_steps)

        evaluator.close()

    def _save_impl(self, path: Path) -> None:
        """
        Writes a self-contained checkpoint directory.

        Layout:
            path/
                agent.pt
                policy.pt

        ``agent.pt`` holds the config and the optimizer states. When the entropy coefficient is
        learned, it also holds the entropy optimizer state and the current coefficient tensor,
        so a reloaded agent continues from the learned value.
        """
        path.mkdir(parents=True, exist_ok=True)
        payload = {
            "algo": "SAC",
            "agent_format_version": 1,
            "config": asdict(self.config),
            "actor_optimizer_state_dict": self.actor_optimizer.state_dict(),
            "critic_optimizer_state_dict": self.critic_optimizer.state_dict(),
        }
        if self.entropy_optimizer is not None:
            payload["entropy_optimizer_state_dict"] = self.entropy_optimizer.state_dict()
            # Optional key: checkpoints written without it still load.
            payload["entropy_coeff"] = self.entropy_coeff.detach().clone()
        torch.save(payload, path / "agent.pt")
        self.policy.save(path / "policy.pt")

    @classmethod
    def load(cls, path: str | Path, map_location: str | torch.device = "cpu") -> "SACAgent":
        """
        Loads the checkpoint and returns a fully-constructed SACAgent.

        Args:
            path (str | Path): Checkpoint directory containing ``agent.pt`` and ``policy.pt``.
            map_location (str | torch.device): Device the checkpoint is mapped to and the agent is created on.

        Returns:
            SACAgent: Agent with config, policy, optimizer states and, when present in the
            checkpoint, the learned entropy coefficient restored.

        Raises:
            ValueError: If the checkpoint was not written by a SAC agent.
        """
        p = Path(path)
        # SECURITY: weights_only=False unpickles arbitrary objects. Only load trusted checkpoints.
        agent_meta = torch.load(p / "agent.pt", map_location=map_location, weights_only=False)
        if agent_meta.get("algo") != "SAC":
            raise ValueError(f"Checkpoint algo mismatch: expected SAC, got {agent_meta.get('algo')}")

        config = SACConfig(**agent_meta["config"])
        policy = SACPolicy.load(p / "policy.pt", map_location=map_location)
        agent = cls(
            policy=policy,
            config=config,
            device=str(map_location),
        )

        agent.actor_optimizer.load_state_dict(agent_meta["actor_optimizer_state_dict"])
        agent.critic_optimizer.load_state_dict(agent_meta["critic_optimizer_state_dict"])

        # Restore the learned coefficient in place so the entropy optimizer keeps referring to
        # the same tensor object.
        saved_coeff = agent_meta.get("entropy_coeff")
        if saved_coeff is not None and agent.entropy_optimizer is not None:
            with torch.no_grad():
                agent.entropy_coeff.copy_(torch.as_tensor(saved_coeff).to(agent.device))

        if "entropy_optimizer_state_dict" in agent_meta:
            if agent.entropy_optimizer is not None:
                agent.entropy_optimizer.load_state_dict(agent_meta["entropy_optimizer_state_dict"])
            else:
                print("Warning: checkpoint has entropy optimizer state but current agent does not have an entropy optimizer. Skipping loading entropy optimizer state.")
        return agent