Source code for prt_rl.common.policies.pretrained

import torch
from torch import Tensor
from typing import Dict, Tuple

[docs] class SB3Policy: """ Stable Baselines3 (SB3) agent for reinforcement learning. This agent wraps a pre-trained model from Stable Baselines3 and uses it to predict actions based on the current state. Note: You must install prt-rl[sb3] to use this agent, which includes the necessary dependencies for Stable Baselines3. Args: model_dir (str): Path to the pre-trained model file. model_type (str): Type of the model (e.g., 'ppo', 'dqn', 'sac', etc.). device (str): Device to run the model on ('cpu' or 'cuda'). Default is 'cpu'. **kwargs: Additional keyword arguments to pass to the model loading function. Reference: [1] [Model Library](https://huggingface.co/sb3) """ def __init__(self, model_dir: str, model_type: str, env_name: str, device: str = "cpu", **kwargs ) -> None: self.device = torch.device(device) try: from sb3_contrib import ARS, QRDQN, TQC, TRPO, CrossQ, RecurrentPPO from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.base_class import BaseAlgorithm from rl_zoo3.load_from_hub import download_from_hub from rl_zoo3.utils import get_model_path ALGOS: dict[str, type[BaseAlgorithm]] = { "a2c": A2C, "ddpg": DDPG, "dqn": DQN, "ppo": PPO, "sac": SAC, "td3": TD3, # SB3 Contrib, "ars": ARS, "crossq": CrossQ, "qrdqn": QRDQN, "tqc": TQC, "trpo": TRPO, "ppo_lstm": RecurrentPPO, } if model_type.lower() not in ALGOS: raise ValueError(f"Model type '{model_type}' is not supported. Supported types are: {list(ALGOS.keys())}") self.model_type = model_type.lower() self.model_class = ALGOS[self.model_type] except ImportError: raise ImportError("Please install prt-rl with stable-baselines3 support with pip install prt-rl[sb3].") # Get the model path and if it does not exist, download it from the hub try: _, self.model_path, _ = get_model_path(exp_id=0, folder=model_dir, algo=self.model_type, env_name=env_name) except (AssertionError, ValueError) as e: download_from_hub( algo=self.model_type, env_name=env_name, exp_id=0, folder=model_dir, organization="sb3", repo_name=None, force=False ) _, self.model_path, _ = get_model_path(exp_id=0, folder=model_dir, algo=self.model_type, env_name=env_name) # Load the model self.model = self.model_class.load(self.model_path, device=self.device, **kwargs)
[docs] @torch.no_grad() def act(self, obs: torch.Tensor, deterministic: bool = True) -> Tuple[Tensor, Dict[str, Tensor]]: """ Perform an action based on the current observation. Args: obs: The current observation of the environment. Returns: The action to be taken. """ # Move state to model device obs = obs.to(self.device) # Wrap state in a batch if needed (for SB3, it expects NumPy) if isinstance(obs, torch.Tensor): obs = obs.cpu().numpy() action, _ = self.model.predict(obs, deterministic=deterministic) if len(action.shape) == 1: # Discrete actions are returned with shape (batch_size,) action = torch.tensor(action, device=self.device).unsqueeze(-1) else: action = torch.tensor(action, device=self.device) return action, {}