Source code for prt_rl.agent

"""
Agent Interface for implementing new agents.
"""
from abc import ABC, abstractmethod
from pathlib import Path
import tempfile
import torch
from typing import Optional, Union, List
from prt_rl.common.schedulers import ParameterScheduler
from prt_rl.common.loggers import Logger
import prt_rl.common.utils as utils


class Agent(ABC):
    """
    Base class for all agents in the PRT-RL framework.

    Subclasses implement ``act``, ``_save_impl`` and ``load``. The class also
    provides helper classmethods shared by concrete agents:

    - updating parameter schedulers,
    - setting optimizer learning rates,
    - computing Generalized Advantage Estimation over a rollout.

    Args:
        device (str): Torch device string for the agent. It is stored as a
            ``torch.device`` in ``self.device``.
    """
    def __init__(self,
                 device: str = "cpu"
                 ) -> None:
        self.device = torch.device(device)

    @torch.no_grad()
    def act(self,
            obs: torch.Tensor,
            deterministic: bool = False
            ) -> torch.Tensor:
        """
        Perform an action based on the current state.

        Note:
            A subclass override replaces this decorated function. Overrides
            that want gradient tracking disabled must apply
            ``@torch.no_grad()`` themselves.

        Args:
            obs (torch.Tensor): The current observation from the environment.
            deterministic (bool): If True, the agent will select actions deterministically.

        Returns:
            torch.Tensor: The action to be taken.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("The act method must be implemented by subclasses.")

    def save(self, path: Optional[Union[str, Path]] = None) -> Path:
        """
        Save the agent into a directory and return that directory.

        Args:
            path (Optional[Union[str, Path]]): Target directory. It is created,
                including parents, if it does not exist. If None, a new
                temporary directory with prefix "prt_rl_ckpt_" is used.

        Returns:
            Path: The directory the agent was saved to.

        Note:
            An auto-created temporary directory stays alive only through
            ``self._last_tmp_checkpoint``. A later ``save()`` call without a
            path replaces that reference. The earlier
            ``tempfile.TemporaryDirectory`` object can then be destroyed, and
            its directory removed. Pass an explicit ``path`` for checkpoints
            that must persist.
        """
        if path is None:
            tmp = tempfile.TemporaryDirectory(prefix="prt_rl_ckpt_")
            path = Path(tmp.name)
            # keep reference so directory is not deleted immediately
            self._last_tmp_checkpoint = tmp
        else:
            path = Path(path)

        path.mkdir(parents=True, exist_ok=True)
        self._save_impl(path)
        return path

    @abstractmethod
    def _save_impl(self, path: Path) -> None:
        """
        Write the agent's state into the existing directory ``path``.

        ``save`` has already created ``path`` before calling this method.
        """
        raise NotImplementedError("The _save_impl method must be implemented by subclasses.")

    @classmethod
    @abstractmethod
    def load(cls, path: str | Path, map_location: str | torch.device = "cpu") -> "Agent":
        """
        Construct an agent from a directory written by ``save``.

        Args:
            path (str | Path): Checkpoint directory.
            map_location (str | torch.device): Device to load tensors onto.

        Returns:
            Agent: The restored agent.
        """
        raise NotImplementedError("The load method must be implemented by subclasses.")

    # --------------
    # Helper Methods
    # --------------
    @classmethod
    def _update_schedulers(cls,
                           schedulers: Optional[List[ParameterScheduler]] = None,
                           step: int = 0,
                           logger: Logger | None = None
                           ) -> None:
        """
        Update a list of parameter schedulers to the current step.

        Args:
            schedulers (Optional[List[ParameterScheduler]]): Schedulers to
                update. Each scheduler is advanced with
                ``update(current_step=step)``. None or an empty list is a no-op.
            step (int): The current step to update the schedulers to.
            logger (Logger | None): Optional logger. If given and
                ``logger.should_log(step)`` is true, each scheduler's value
                after the update is logged with ``log_scalar``. The scalar is
                named ``scheduler.parameter_name`` and uses ``iteration=step``.

        Returns:
            None
        """
        if not schedulers:
            return

        # Decide once per step so every scheduler gets the same logging
        # decision, even if ``should_log`` is not a pure function.
        log_this_step = (logger is not None) and logger.should_log(step)

        for scheduler in schedulers:
            scheduler.update(current_step=step)
            if log_this_step:
                logger.log_scalar(name=scheduler.parameter_name,
                                  value=scheduler.get_value(),
                                  iteration=step)

    @classmethod
    def _update_optimizer(cls,
                          optimizer: object,
                          learning_rate: float
                          ) -> None:
        """
        Set the learning rate of every parameter group in an optimizer.

        Args:
            optimizer (object): Optimizer object (e.g., torch.optim.Optimizer)
                with a ``param_groups`` attribute. That attribute is a list of
                dicts that have an 'lr' key.
            learning_rate (float): New learning rate to set.

        Returns:
            None
        """
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    @classmethod
    def _compute_gae(cls,
                     experience: dict,
                     gamma: float,
                     gae_lambda: float
                     ) -> dict:
        """
        Compute GAE advantages and returns, then flatten the experience batch.

        The input dict is not modified. A new dict is returned.

        Args:
            experience (dict): Dictionary of tensors with at least these keys:
                'reward' (torch.Tensor): Rewards, shape (T, N, 1).
                'value' (torch.Tensor): State values, shape (T, N, 1).
                'done' (torch.Tensor): Done flags, shape (T, N, 1).
                'last_value_est' (torch.Tensor): Value estimates for the final
                    state, shape (N, 1).
            gamma (float): Discount factor.
            gae_lambda (float): GAE lambda.

        Returns:
            dict: A new dict holding every input key except 'last_value_est',
            plus these keys:
                'advantages' (torch.Tensor): Estimated advantages, detached.
                'returns' (torch.Tensor): TD(lambda) returns, detached.
            In every tensor the first two dimensions are merged, in row-major
            order: (T, N, ...) -> (T*N, ...).

        Note:
            The flatten always merges the first two dimensions. An already
            flat 2-D tensor of shape (B, 1) therefore comes out with shape (B,).
        """
        # Compute Advantages and Returns under the current policy
        advantages, returns = utils.generalized_advantage_estimates(
            rewards=experience['reward'],
            values=experience['value'],
            dones=experience['done'],
            last_values=experience['last_value_est'],
            gamma=gamma,
            gae_lambda=gae_lambda,
        )

        # Build a shallow copy so the caller's dict is not mutated.
        # 'last_value_est' is only needed for bootstrapping above, so drop it.
        batch = {k: v for k, v in experience.items() if k != 'last_value_est'}
        batch['advantages'] = advantages.detach()
        batch['returns'] = returns.detach()

        # Flatten the experience batch (T, N, ...) -> (T*N, ...)
        return {k: v.reshape(-1, *v.shape[2:]) for k, v in batch.items()}