# Source code for prt_rl.model_free.ppo

"""
Proximal Policy Optimization (PPO)

Reference:
[1] https://arxiv.org/abs/1707.06347
"""
from dataclasses import dataclass, asdict
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from typing import Optional, List, Tuple, Dict
from prt_rl.agent import Agent
from prt_rl.env.interface import EnvironmentInterface
from prt_rl.common.collectors import Collector
from prt_rl.common.buffers import RolloutBuffer
from prt_rl.common.loggers import Logger
from prt_rl.common.schedulers import ParameterScheduler
from prt_rl.common.progress_bar import ProgressBar
from prt_rl.common.evaluators import Evaluator
import prt_rl.common.utils as utils

from prt_rl.common.policies import NeuralPolicy
from prt_rl.common.components.heads.interface import DistributionHead
from prt_rl.common.components.heads import ValueHead

# @todo Add support for KL stopping
# @todo Add support for value clipping

# Define the Algorithm config dataclass
@dataclass
class PPOConfig:
    """
    Hyperparameters for the PPO agent.

    Args:
        steps_per_batch (int): Number of steps requested from the collector for each rollout;
            also used as the rollout buffer capacity.
        mini_batch_size (int): Number of samples in each optimization mini-batch.
        learning_rate (float): Learning rate for the Adam optimizer.
        gamma (float): Discount factor applied to future rewards.
        epsilon (float): PPO clipping range; the probability ratio is clamped to
            [1 - epsilon, 1 + epsilon].
        gae_lambda (float): Lambda parameter for Generalized Advantage Estimation.
        entropy_coef (float): Weight of the entropy term in the total loss.
        value_coef (float): Weight of the value-function loss in the total loss.
        num_optim_steps (int): Number of passes (epochs) over the collected rollout per batch.
        normalize_advantages (bool): If True, advantages are normalized before the
            surrogate loss is computed.
    """
    steps_per_batch: int = 2_048          # rollout length requested from the collector
    mini_batch_size: int = 32             # samples per gradient step
    learning_rate: float = 3.0e-4         # Adam step size
    gamma: float = 0.99                   # reward discount
    epsilon: float = 0.1                  # ratio clip range
    gae_lambda: float = 0.95              # GAE bias/variance trade-off
    entropy_coef: float = 0.01            # entropy bonus weight
    value_coef: float = 0.5               # value loss weight
    num_optim_steps: int = 10             # epochs over each rollout
    normalize_advantages: bool = False    # toggle advantage normalization
# Define the Policy Interface
class PPOPolicy(NeuralPolicy):
    """
    Actor-critic policy used by PPO.

    Architecture:
        - ``network``: trunk mapping observations to a latent. It always feeds ``actor_head``,
          and also feeds ``critic_head`` when no separate critic trunk is given.
        - ``actor_head``: distribution head used to sample actions and to score their
          log-probabilities and entropy.
        - ``critic_network`` (optional): separate trunk for the critic.
        - ``critic_head``: maps the critic latent to a state value.

    Args:
        network (nn.Module): Actor trunk (shared with the critic when ``critic_network`` is None).
        actor_head (DistributionHead): Action-distribution head.
        critic_head (ValueHead): State-value head.
        critic_network (nn.Module | None): Optional separate critic trunk. Default is None.
    """
    def __init__(self,
                 *,
                 network: nn.Module,
                 actor_head: DistributionHead,
                 critic_head: ValueHead,
                 critic_network: Optional[nn.Module] = None,
                 ) -> None:
        super().__init__()
        self.network = network
        self.actor_head = actor_head
        self.critic_head = critic_head
        self.critic_network = critic_network

    @torch.no_grad()
    def act(self,
            obs: torch.Tensor,
            deterministic: bool = False
            ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Sample an action and return it together with rollout information.

        Args:
            obs (torch.Tensor): Batch of observations.
            deterministic (bool): Forwarded to ``actor_head.sample``.

        Returns:
            Tuple of ``(action, info)``. ``info`` has the keys:
                - "log_prob": log-probability of the sampled action (expected shape (B,1))
                - "value": critic estimate for ``obs`` (expected shape (B,1))
        """
        actor_latent = self.network(obs)
        action, log_prob, _ = self.actor_head.sample(actor_latent, deterministic=deterministic)

        # The critic reuses the actor trunk's latent unless it has its own trunk.
        if self.critic_network is not None:
            critic_latent = self.critic_network(obs)
        else:
            critic_latent = actor_latent
        value = self.critic_head(critic_latent)
        return action, {"log_prob": log_prob, "value": value}

    def forward(self,
                obs: torch.Tensor,
                deterministic: bool = False
                ) -> torch.Tensor:
        """
        Return only the action, so the policy can be used like a plain nn.Module.

        Collectors should call ``act()`` instead to also receive the info dict.
        """
        action, _ = self.act(obs, deterministic=deterministic)
        return action

    def evaluate_actions(self,
                         obs: torch.Tensor,
                         action: torch.Tensor
                         ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Score previously taken actions under the current parameters (used during PPO optimization).

        Args:
            obs (torch.Tensor): Batch of observations.
            action (torch.Tensor): Actions taken in ``obs``.

        Returns:
            value: critic estimate (expected shape (B,1))
            log_prob: log-probability of ``action`` (expected shape (B,1))
            entropy: entropy of the action distribution (expected shape (B,1))
        """
        # The actor head must always read the actor trunk's latent, exactly as in act().
        # Feeding it the critic trunk's latent would score actions under a different
        # distribution from the one that sampled them, corrupting the PPO ratio.
        actor_latent = self.network(obs)
        if self.critic_network is not None:
            critic_latent = self.critic_network(obs)
        else:
            critic_latent = actor_latent

        value = self.critic_head(critic_latent)
        log_prob = self.actor_head.log_prob(actor_latent, action)
        entropy = self.actor_head.entropy(actor_latent)
        return value, log_prob, entropy

    def get_state_value(self,
                        obs: torch.Tensor,
                        ) -> torch.Tensor:
        """
        Return the critic's state value for the given observations.

        Args:
            obs (torch.Tensor): Batch of observations, e.g. (B, obs_dim).

        Returns:
            torch.Tensor: State values (expected shape (B,1)).
        """
        if self.critic_network is not None:
            latent = self.critic_network(obs)
        else:
            latent = self.network(obs)
        return self.critic_head(latent)
# Make the Agent
class PPOAgent(Agent):
    """
    Proximal Policy Optimization (PPO) agent.

    Args:
        policy (PPOPolicy): Actor-critic policy to train; it is moved to ``device``.
        config (PPOConfig | None): Hyperparameters. If None, a fresh ``PPOConfig()`` with
            default values is created for this agent.
        device (str): Device to run the computations on ('cpu' or 'cuda').
    """
    def __init__(self,
                 policy: PPOPolicy,
                 config: Optional[PPOConfig] = None,
                 *,
                 device: str = 'cpu',
                 ) -> None:
        # Build a fresh config per agent. A `PPOConfig()` default in the signature is evaluated
        # once and shared by every agent constructed without an explicit config, so any in-place
        # change to it (presumably e.g. schedulers adjusting learning_rate) would leak across agents.
        self.config = config if config is not None else PPOConfig()
        self.policy = policy.to(device)
        super().__init__(device=device)

        # Configure optimizers
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.config.learning_rate)

    @torch.no_grad()
    def act(self, obs: Tensor, deterministic: bool = False) -> Tensor:
        """Return the policy's action for ``obs`` (the info dict from ``policy.act`` is discarded)."""
        action, _ = self.policy.act(obs, deterministic=deterministic)
        return action

    def train(self,
              env: EnvironmentInterface,
              total_steps: int,
              schedulers: Optional[List[ParameterScheduler]] = None,
              logger: Optional[Logger] = None,
              evaluator: Optional[Evaluator] = None,
              show_progress: bool = True
              ) -> None:
        """
        Train the PPO agent.

        Each iteration collects ``steps_per_batch`` steps, computes GAE advantages and returns,
        runs ``num_optim_steps`` epochs of clipped-PPO mini-batch updates, then logs and evaluates.

        Args:
            env (EnvironmentInterface): The environment to train on.
            total_steps (int): Total number of steps to train for.
            schedulers (Optional[List[ParameterScheduler]]): Parameter schedulers (e.g. learning rate).
            logger (Optional[Logger]): Logger for training metrics.
            evaluator (Optional[Evaluator]): Evaluator for performance evaluation.
            show_progress (bool): If True, show a progress bar during training.
        """
        logger = logger or Logger()
        evaluator = evaluator or Evaluator()

        if show_progress:
            progress_bar = ProgressBar(total_steps=total_steps)
        num_steps = 0

        # Make collector and do not flatten the experience so the shape is (T, N, ...)
        collector = Collector(env=env, logger=logger, flatten=False)
        rollout_buffer = RolloutBuffer(capacity=self.config.steps_per_batch, device=self.device)

        while num_steps < total_steps:
            self._update_schedulers(schedulers=schedulers, step=num_steps)
            self._update_optimizer(self.optimizer, self.config.learning_rate)

            # Collect experience dictionary with shape (T, N, ...)
            experience = collector.collect_experience(policy=self.policy, num_steps=self.config.steps_per_batch)

            # Compute advantages and returns under the current (pre-update) policy
            experience = self._compute_gae(experience, gamma=self.config.gamma, gae_lambda=self.config.gae_lambda)

            # NOTE(review): with flatten=False the experience is (T, N, ...); if _compute_gae keeps
            # that layout, shape[0] counts time steps rather than T*N transitions when N > 1.
            # Confirm against Collector/_compute_gae.
            num_steps += experience['state'].shape[0]

            rollout_buffer.add(experience)
            metrics = self._optimize_rollout(rollout_buffer)
            # On-policy: discard the rollout once it has been used for optimization
            rollout_buffer.clear()

            if show_progress:
                tracker = collector.get_metric_tracker()
                progress_bar.update(current_step=num_steps,
                                    desc=f"Episode Reward: {tracker.last_episode_reward:.2f}, "
                                         f"Episode Length: {tracker.last_episode_length}, "
                                         f"Loss: {np.mean(metrics['loss']):.4f},")

            if logger.should_log(num_steps):
                for name in ('clip_loss', 'entropy_loss', 'value_loss', 'loss'):
                    logger.log_scalar(name, np.mean(metrics[name]), num_steps)

            evaluator.evaluate(agent=self.policy, iteration=num_steps)

        evaluator.close()

    def _optimize_rollout(self, rollout_buffer: RolloutBuffer) -> Dict[str, List[float]]:
        """
        Run ``num_optim_steps`` epochs of clipped-PPO mini-batch updates over ``rollout_buffer``.

        Returns:
            Dict mapping 'clip_loss', 'entropy_loss', 'value_loss' and 'loss' to the list of
            per-mini-batch scalar values recorded during optimization.
        """
        history: Dict[str, List[float]] = {'clip_loss': [], 'entropy_loss': [], 'value_loss': [], 'loss': []}
        for _ in range(self.config.num_optim_steps):
            for batch in rollout_buffer.get_batches(batch_size=self.config.mini_batch_size):
                loss, clip_loss, entropy_loss, value_loss = self._ppo_loss(batch)

                history['clip_loss'].append(clip_loss.item())
                history['entropy_loss'].append(entropy_loss.item())
                history['value_loss'].append(value_loss.item())
                history['loss'].append(loss.item())

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
                self.optimizer.step()
        return history

    def _ppo_loss(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Compute the clipped PPO objective for one mini-batch.

        Returns:
            (total_loss, clip_loss, entropy_loss, value_loss)
        """
        # The behaviour policy's log-probs, the advantages and the returns are treated as constants
        old_log_prob = batch['log_prob']
        advantages = batch['advantages']
        returns = batch['returns']

        if self.config.normalize_advantages:
            advantages = utils.normalize_advantages(advantages)

        # Log-probability and entropy of the stored actions under the current policy
        new_value_est, new_log_prob, entropy = self.policy.evaluate_actions(batch['state'], batch['action'])

        # Probability ratio between the current and the behaviour policy
        ratio = torch.exp(new_log_prob - old_log_prob)

        # Clipped surrogate objective (negated, since the optimizer minimizes)
        eps = self.config.epsilon
        unclipped = advantages * ratio
        clipped = advantages * torch.clamp(ratio, 1 - eps, 1 + eps)
        clip_loss = -torch.min(unclipped, clipped).mean()

        # Maximizing entropy == minimizing its negation
        entropy_loss = -entropy.mean()

        # NOTE(review): F.mse_loss broadcasts mismatched shapes (e.g. (B,1) vs (B,)) with only a
        # warning; confirm batch['returns'] has the same shape as the critic output.
        value_loss = F.mse_loss(new_value_est, returns)

        loss = clip_loss + self.config.entropy_coef * entropy_loss + self.config.value_coef * value_loss
        return loss, clip_loss, entropy_loss, value_loss

    def _save_impl(self, path: Path) -> None:
        """
        Write a self-contained checkpoint directory.

        Layout:
            path/
                agent.pt    (algo tag, format version, config dict, optimizer state)
                policy.pt   (written by ``self.policy.save``)
        """
        path.mkdir(parents=True, exist_ok=True)

        payload = {
            "algo": "PPO",
            "agent_format_version": 1,
            "config": asdict(self.config),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        torch.save(payload, path / "agent.pt")

        self.policy.save(path / "policy.pt")

    @classmethod
    def load(cls, path: str | Path, map_location: str | torch.device = "cpu") -> "PPOAgent":
        """
        Load a checkpoint written by ``_save_impl`` and return a fully-constructed PPOAgent.

        Args:
            path (str | Path): Checkpoint directory.
            map_location (str | torch.device): Device to load tensors onto; also used as the agent device.

        Raises:
            ValueError: If the checkpoint was not written by a PPO agent.
        """
        p = Path(path)

        # SECURITY: weights_only=False unpickles arbitrary Python objects. Only load
        # checkpoints from trusted sources.
        agent_meta = torch.load(p / "agent.pt", map_location=map_location, weights_only=False)

        if agent_meta.get("algo") != "PPO":
            raise ValueError(f"Checkpoint algo mismatch: expected PPO, got {agent_meta.get('algo')}")

        config = PPOConfig(**agent_meta["config"])

        policy = PPOPolicy.load(p / "policy.pt", map_location=map_location)

        agent = cls(
            policy=policy,
            config=config,
            device=str(map_location),
        )

        opt_state = agent_meta["optimizer_state_dict"]
        agent.optimizer.load_state_dict(opt_state)

        return agent