"""
Twin Delayed Deep Deterministic Policy Gradient (TD3)
"""
from pathlib import Path
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from typing import Optional, List, Tuple, Dict
from prt_rl.agent import Agent
from prt_rl.common.policies import NeuralPolicy, RandomPolicy
from prt_rl.env.interface import EnvParams, EnvironmentInterface
from prt_rl.common.loggers import Logger
from prt_rl.common.schedulers import ParameterScheduler
from prt_rl.common.progress_bar import ProgressBar
from prt_rl.common.evaluators import Evaluator
import prt_rl.common.utils as utils
import copy
import numpy as np
from dataclasses import dataclass, asdict
from prt_rl.common.collectors import Collector
from prt_rl.common.buffers import ReplayBuffer
from prt_rl.common.components.heads import ContinuousHead, QValueHead
from prt_rl.common.components.networks import QCritic
@dataclass
class TD3Config:
    """
    Configuration for the TD3 agent.

    The number of critics is not part of this configuration; it is set on the
    ``TD3Policy`` (``num_critics``).

    Args:
        buffer_size (int): Size of the replay buffer.
        min_buffer_size (int): Minimum size of the replay buffer before training starts.
        steps_per_batch (int): Number of steps to collect per batch.
        mini_batch_size (int): Size of the mini-batch sample for each gradient update.
        gradient_steps (int): Number of gradient steps to take per training iteration.
        learning_rate (float): Learning rate for the optimizer.
        gamma (float): Discount factor for future rewards.
        exploration_noise (float): Standard deviation of Gaussian noise added to actions for exploration.
        policy_noise (float): Standard deviation of noise added to the target policy's actions.
        noise_clip (float): Maximum absolute value of noise added to the target policy's actions.
        delay_freq (int): Frequency of delayed policy updates. Must be at least 1.
        tau (float): Polyak averaging factor for target networks.

    Raises:
        ValueError: If ``delay_freq`` is smaller than 1.
    """
    buffer_size: int = 100000
    min_buffer_size: int = 1000
    steps_per_batch: int = 1
    mini_batch_size: int = 256
    gradient_steps: int = 1
    learning_rate: float = 1e-3
    gamma: float = 0.99
    exploration_noise: float = 0.1
    policy_noise: float = 0.2
    noise_clip: float = 0.5
    delay_freq: int = 2
    tau: float = 0.005

    def __post_init__(self) -> None:
        # The training loop evaluates `num_gradient_steps % delay_freq`; a zero here
        # would only surface later as a ZeroDivisionError, so fail early with a clear message.
        if self.delay_freq < 1:
            raise ValueError(f"delay_freq must be a positive integer, got {self.delay_freq}")
class TD3Policy(NeuralPolicy):
    """
    TD3 Policy

    Bundles a deterministic actor with ``num_critics`` Q-critics and a target copy of each.
    The actor is ``nn.Sequential(network, actor_head)``. Every critic is a ``QCritic`` built from its own
    deep copy of the critic backbone and of ``critic_head``, so the critics share no weights with each
    other or with the actor. Target networks start as deep copies of their online counterparts.

    Args:
        network (nn.Module): Backbone of the actor. When ``critic_network`` is None, a deep copy of it is also
            used as the critic backbone.
        actor_head (ContinuousHead): Head stacked on top of ``network`` to produce the action.
        critic_head (QValueHead): Head template for the critics; each critic receives its own deep copy.
        action_min (Tensor): Lower bound used when clamping actions.
        action_max (Tensor): Upper bound used when clamping actions.
        num_critics (int): Number of critic networks to use. Default is 2.
        exploration_noise (float): Standard deviation of the Gaussian noise added by ``act`` when it is not
            called deterministically. Default is 0.1.
        critic_network (Optional[nn.Module]): Backbone template for the critics. If None, a deep copy of
            ``network`` is used.
    """
    def __init__(self,
                 network: nn.Module,
                 actor_head: ContinuousHead,
                 critic_head: QValueHead,
                 *,
                 action_min: Tensor,
                 action_max: Tensor,
                 num_critics: int = 2,
                 exploration_noise: float = 0.1,
                 critic_network: Optional[nn.Module] = None,
                 ) -> None:
        super().__init__()
        self.num_critics = num_critics
        self.action_min = action_min
        self.action_max = action_max
        self.exploration_noise = exploration_noise

        # Create a unified actor network and target actor network
        self.actor = nn.Sequential(network, actor_head)
        self.target_actor = copy.deepcopy(self.actor)

        # Create a separate critic network backbone by default if one is not provided
        critic_network = critic_network if critic_network is not None else copy.deepcopy(network)

        # Create critics and target critics; every critic gets private copies of backbone and head
        self.critics = nn.ModuleList()
        self.target_critics = nn.ModuleList()
        for _ in range(self.num_critics):
            critic = QCritic(copy.deepcopy(critic_network), copy.deepcopy(critic_head))
            target_critic = copy.deepcopy(critic)
            self.critics.append(critic)
            self.target_critics.append(target_critic)

    @torch.no_grad()
    def act(self,
            obs: torch.Tensor,
            deterministic: bool = False
            ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Select an action for the given observation.

        Args:
            obs (torch.Tensor): The current observation of the environment.
            deterministic (bool): If False, Gaussian noise with standard deviation ``exploration_noise`` is
                added to the actor output.

        Returns:
            Tuple[torch.Tensor, Dict[str, torch.Tensor]]: The action clamped to ``[action_min, action_max]``
            and an empty info dictionary.
        """
        action = self.actor(obs)
        if not deterministic:
            # Add noise to the action for exploration
            noise = utils.gaussian_noise(mean=0, std=self.exploration_noise, shape=action.shape, device=self.device)
            action = action + noise

        # Ensure action is within bounds.
        # NOTE(review): the clamp is applied to deterministic actions as well; the original nesting was
        # reconstructed from a listing without indentation - confirm this matches the intended behaviour.
        action = action.clamp(self.action_min.to(self.device), self.action_max.to(self.device))
        return action, {}

    def target_actor_action(self, obs: Tensor, policy_noise: float, noise_clip: float, action_shape) -> torch.Tensor:
        """
        Compute the target actor's action with added noise for policy smoothing.

        Args:
            obs (torch.Tensor): The current observation of the environment.
            policy_noise (float): Standard deviation of noise added to the target policy's actions.
            noise_clip (float): Maximum absolute value of noise added to the target policy's actions.
            action_shape: Shape of the action tensor.

        Returns:
            torch.Tensor: The action computed by the target actor with added noise, clipped to action bounds.
        """
        # Generate additive Gaussian noise and clip it to the specified range
        noise = utils.gaussian_noise(mean=0, std=policy_noise, shape=action_shape, device=self.device)
        noise_clipped = noise.clamp(-noise_clip, noise_clip)

        # Get target actor action plus noise and clip to action bounds
        action = self.target_actor(obs) + noise_clipped
        action = action.clamp(self.action_min.to(self.device), self.action_max.to(self.device))
        return action

    def get_q_values(self,
                     obs: torch.Tensor,
                     action: torch.Tensor,
                     index: Optional[int] = None
                     ) -> torch.Tensor:
        """
        Get Q-values for the given state-action pairs from all critics or from a single one.

        Args:
            obs (torch.Tensor): Current observation tensor.
            action (torch.Tensor): Action tensor.
            index (Optional[int]): If given, only the critic at this position is evaluated. If None, every
                critic is evaluated.

        Returns:
            torch.Tensor: With ``index=None``, the critic outputs stacked along dim 1 - shape (B, C, 1) where
            C is the number of critics. Otherwise the raw output of the selected critic.
        """
        if index is None:
            q_values = [critic(obs, action) for critic in self.critics]
            q_values = torch.stack(q_values, dim=1)  # Shape (B, C, 1) where C is the number of critics
        else:
            q_values = self.critics[index](obs, action)
        return q_values

    def get_target_q_values(self,
                            obs: torch.Tensor,
                            action: torch.Tensor,
                            ) -> torch.Tensor:
        """
        Get target Q-values from all target critics for the given state-action pairs.

        Args:
            obs (torch.Tensor): Current observation tensor.
            action (torch.Tensor): Action tensor.

        Returns:
            torch.Tensor: Tensor containing target Q-values from all critics. Shape (B, C, 1) where C is the number of critics.
        """
        q_values = [critic(obs, action) for critic in self.target_critics]
        q_values = torch.stack(q_values, dim=1)  # Shape (B, C, 1) where C is the number of critics
        return q_values
class TD3Agent(Agent):
    """
    Twin Delayed Deep Deterministic Policy Gradient (TD3)

    This class implements the TD3 algorithm, which is an off-policy actor-critic algorithm for continuous action spaces.

    Args:
        policy (TD3Policy): The TD3 policy holding the actor, the critics and their target networks.
        config (Optional[TD3Config]): Configuration for the TD3 agent. If None, a fresh default ``TD3Config`` is created.
        device (str): Device to run the agent on ('cpu' or 'cuda'). Default is 'cpu'.
    """
    def __init__(self,
                 policy: TD3Policy,
                 config: Optional[TD3Config] = None,
                 *,
                 device: str = 'cpu',
                 ) -> None:
        super().__init__()
        # Build the default per instance: a default evaluated in the signature would be a single
        # mutable TD3Config object shared by every agent constructed without an explicit config.
        self.config = config if config is not None else TD3Config()
        self.device = torch.device(device)
        self.policy = policy
        self.policy.to(self.device)

        # Separate optimizers so the actor can be updated less often than the critics
        self.actor_optimizer = torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.learning_rate)
        self.critic_optimizer = torch.optim.Adam(self.policy.critics.parameters(), lr=self.config.learning_rate)

    @torch.no_grad()
    def act(self, obs: Tensor, deterministic: bool = False) -> Tensor:
        """
        Perform an action based on the current state.

        Args:
            obs (torch.Tensor): The current observation of the environment.
            deterministic (bool): If True, the agent will select actions deterministically.

        Returns:
            torch.Tensor: The action to be taken.
        """
        action, _ = self.policy.act(obs, deterministic=deterministic)
        return action

    def train(self,
              env: EnvironmentInterface,
              total_steps: int,
              schedulers: Optional[List[ParameterScheduler]] = None,
              logger: Optional[Logger] = None,
              evaluator: Optional[Evaluator] = None,
              show_progress: bool = True
              ) -> None:
        """
        Run the TD3 training loop.

        The replay buffer is first filled with experience from a random policy until it holds
        ``min_buffer_size`` samples. Afterwards each iteration collects ``steps_per_batch`` steps with the
        current policy and performs ``gradient_steps`` critic updates; the actor and all target networks are
        updated on every ``delay_freq``-th gradient step. Steps collected during the random warm-up count
        towards ``total_steps``.

        Args:
            env: The environment to interact with.
            total_steps: Total number of steps to train the agent.
            schedulers: Optional list of parameter schedulers.
            logger: Optional logger for logging training progress.
            evaluator: Evaluator for evaluating the agent's performance.
            show_progress: If True, show a progress bar during training.
        """
        logger = logger or Logger()
        evaluator = evaluator or Evaluator()
        progress_bar = ProgressBar(total_steps=total_steps) if show_progress else None

        num_steps = 0
        num_gradient_steps = 0
        # Mean actor loss of the most recent iteration that actually updated the actor.
        # The actor is only updated on delayed steps, so many iterations produce no actor loss at all.
        last_actor_loss: Optional[float] = None

        # Make collector and flatten the experience so the shape is (B, ...)
        collector = Collector(env=env, logger=logger, flatten=True)
        replay_buffer = ReplayBuffer(capacity=self.config.buffer_size, device=self.device)

        # Collect initial experience until replay buffer has enough samples for training with random policy
        random_policy = RandomPolicy(env.get_parameters())
        while replay_buffer.get_size() < self.config.min_buffer_size:
            experience = collector.collect_experience(policy=random_policy, num_steps=self.config.steps_per_batch)
            replay_buffer.add(experience)
            num_steps += experience['state'].shape[0]
            if show_progress:
                progress_bar.update(current_step=num_steps, desc="Collecting initial experience...")

        while num_steps < total_steps:
            self._update_schedulers(schedulers, num_steps)

            # Collect experience dictionary with shape (B, ...)
            experience = collector.collect_experience(policy=self.policy, num_steps=self.config.steps_per_batch)
            num_steps += experience['state'].shape[0]

            # Store experience in replay buffer
            replay_buffer.add(experience)

            actor_losses = []
            critics_losses = []
            for _ in range(self.config.gradient_steps):
                num_gradient_steps += 1

                # Sample a batch of experiences from the replay buffer
                batch = replay_buffer.sample(batch_size=self.config.mini_batch_size)
                critics_losses.append(self._update_critics(batch))

                # Delayed policy update
                if num_gradient_steps % self.config.delay_freq == 0:
                    actor_losses.append(self._update_actor_and_targets(batch))

            if actor_losses:
                last_actor_loss = float(np.mean(actor_losses))

            if show_progress:
                tracker = collector.get_metric_tracker()
                # Shown as nan until the first actor update has happened
                shown_actor_loss = float('nan') if last_actor_loss is None else last_actor_loss
                progress_bar.update(current_step=num_steps, desc=f"Episode Reward: {tracker.last_episode_reward:.2f}"
                                                                 f" Episode Length: {tracker.last_episode_length}"
                                                                 f" Episode number: {tracker.episode_count}"
                                                                 f" Actor Loss: {shown_actor_loss:.4f}"
                                    )

            if logger.should_log(num_steps):
                # Skip scalars that have no value yet instead of logging nan / hitting an unbound variable
                if last_actor_loss is not None:
                    logger.log_scalar('actor_loss', last_actor_loss, num_steps)
                if critics_losses:
                    logger.log_scalar('critic_loss', float(np.mean(critics_losses)), num_steps)

            # NOTE(review): evaluation is invoked every iteration (not only when logging); the nesting was
            # reconstructed from a listing without indentation - confirm.
            evaluator.evaluate(agent=self.policy, iteration=num_steps)

        evaluator.close()

    def _update_critics(self, batch: Dict[str, Tensor]) -> float:
        """
        Take one gradient step on all critics using the clipped double-Q target.

        Args:
            batch: Sampled transitions; the keys 'state', 'action', 'reward', 'next_state' and 'done' are read.

        Returns:
            float: The summed MSE loss over all critics, evaluated before the optimizer step.
        """
        # The target y is shared by every critic's loss, so it is computed without gradients; otherwise
        # backpropagating through y more than once would raise an error.
        with torch.no_grad():
            # Compute the policy's next action with noise, clipped so it does not exceed action bounds - [B, A]
            next_action = self.policy.target_actor_action(
                obs=batch['next_state'],
                policy_noise=self.config.policy_noise,
                noise_clip=self.config.noise_clip,
                action_shape=batch['action'].shape
            )

            # Compute the Q-Values for all the critics - shape (B, C, 1) -> (B, C)
            next_q_values = self.policy.get_target_q_values(batch['next_state'], next_action).squeeze(-1)

            # Use the minimum Q-Value across critics for the target - shape (B, 1)
            next_q_values = torch.min(next_q_values, dim=1, keepdim=True)[0]

            # Compute the target Q-Value
            # assumes batch['reward'] and batch['done'] are (B, 1) so y matches the critic output - TODO confirm
            y = batch['reward'] + self.config.gamma * (1 - batch['done'].float()) * next_q_values

        # Sum the losses across all critics (prediction first, target second)
        qs = [self.policy.get_q_values(batch['state'].detach(), batch['action'].detach(), index=i)
              for i in range(self.policy.num_critics)]
        critic_loss = sum(F.mse_loss(q, y) for q in qs)
        loss_value = critic_loss.item()

        # Take a gradient step on the critics
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        return loss_value

    def _update_actor_and_targets(self, batch: Dict[str, Tensor]) -> float:
        """
        Take one gradient step on the actor, then Polyak-update every target network.

        Args:
            batch: Sampled transitions; only the key 'state' is read.

        Returns:
            float: The actor loss, evaluated before the optimizer step.
        """
        # Compute actor loss: maximise the first critic's value of the actor's own action
        actor_loss = -self.policy.get_q_values(obs=batch['state'], action=self.policy.actor(batch['state']), index=0).mean()
        loss_value = actor_loss.item()

        # Take a gradient step on the actor. backward() also leaves gradients on the first critic's
        # parameters; they are cleared by critic_optimizer.zero_grad() before the next critic step.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update target networks
        utils.polyak_update(self.policy.target_actor, self.policy.actor, tau=self.config.tau)
        for target_critic, critic in zip(self.policy.target_critics, self.policy.critics):
            utils.polyak_update(target_critic, critic, tau=self.config.tau)
        return loss_value

    def _save_impl(self, path: Path) -> None:
        """
        Writes a self-contained checkpoint directory.

        Layout:
            path/
                agent.pt
                policy.pt

        ``agent.pt`` holds the algorithm tag, a format version, the config as a plain dict and both
        optimizer state dicts. ``policy.pt`` is written by the policy's own ``save``.

        Args:
            path (Path): Directory to write to; created (with parents) if it does not exist.
        """
        path.mkdir(parents=True, exist_ok=True)
        payload = {
            "algo": "TD3",
            "agent_format_version": 1,
            "config": asdict(self.config),
            "actor_optimizer_state_dict": self.actor_optimizer.state_dict(),
            "critic_optimizer_state_dict": self.critic_optimizer.state_dict(),
        }
        torch.save(payload, path / "agent.pt")
        self.policy.save(path / "policy.pt")

    @classmethod
    def load(cls, path: str | Path, map_location: str | torch.device = "cpu") -> "TD3Agent":
        """
        Loads the checkpoint and returns a fully-constructed TD3Agent.

        Args:
            path (str | Path): Checkpoint directory containing ``agent.pt`` and ``policy.pt``.
            map_location (str | torch.device): Device the checkpoint is mapped to; also used as the agent's device.

        Returns:
            TD3Agent: Agent with restored config, policy and optimizer states.

        Raises:
            ValueError: If the checkpoint's ``algo`` entry is not ``"TD3"``.
        """
        p = Path(path)
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, which can execute code.
        # Only load checkpoints from trusted sources.
        agent_meta = torch.load(p / "agent.pt", map_location=map_location, weights_only=False)
        if agent_meta.get("algo") != "TD3":
            raise ValueError(f"Checkpoint algo mismatch: expected TD3, got {agent_meta.get('algo')}")

        config = TD3Config(**agent_meta["config"])
        policy = TD3Policy.load(p / "policy.pt", map_location=map_location)
        agent = cls(
            policy=policy,
            config=config,
            device=str(map_location),
        )

        # Optimizers are created in __init__ against the loaded policy's parameters; restore their state on top
        actor_opt_state = agent_meta["actor_optimizer_state_dict"]
        critic_opt_state = agent_meta["critic_optimizer_state_dict"]
        agent.actor_optimizer.load_state_dict(actor_opt_state)
        agent.critic_optimizer.load_state_dict(critic_opt_state)
        return agent