pretrained

Classes

class prt_rl.common.policies.pretrained.SB3Policy(model_dir: str, model_type: str, env_name: str, device: str = 'cpu', **kwargs)

Stable Baselines3 (SB3) agent for reinforcement learning.

This agent wraps a pre-trained Stable Baselines3 model and uses it to predict actions from the current observation.

Note

To use this agent, install the prt-rl[sb3] extra, which includes the dependencies required for Stable Baselines3.

Parameters:
  • model_dir (str) – Path to the pre-trained model file.

  • model_type (str) – Type of the model (e.g., ‘ppo’, ‘dqn’, or ‘sac’).

  • env_name (str) – Name of the environment the pre-trained model was trained on.

  • device (str) – Device to run the model on (‘cpu’ or ‘cuda’). Default is ‘cpu’.

  • **kwargs – Additional keyword arguments to pass to the model loading function.

Reference:

[1] [Model Library](https://huggingface.co/sb3)
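
The wrapping pattern described above can be sketched without installing anything. This is only an illustration under stated assumptions: `StubModel` and `WrappedPolicy` are hypothetical names, plain Python lists stand in for tensors, and the stub's `predict` only mimics the `(action, state)` pair that Stable Baselines3 models return. It is not the actual implementation of SB3Policy.

```python
from typing import Any, Dict, List, Tuple


class StubModel:
    """Hypothetical stand-in for a loaded pre-trained model (not an SB3 class)."""

    def predict(self, obs: List[float], deterministic: bool = True) -> Tuple[int, None]:
        # Mimic the (action, state) pair returned by SB3-style `predict` methods.
        action = 1 if sum(obs) > 0 else 0
        return action, None


class WrappedPolicy:
    """Illustrative wrapper that delegates action selection to a pre-trained model."""

    def __init__(self, model: Any) -> None:
        self.model = model

    def act(self, obs: List[float], deterministic: bool = True) -> Tuple[int, Dict[str, Any]]:
        # Forward the observation to the wrapped model and discard its recurrent state.
        action, _state = self.model.predict(obs, deterministic=deterministic)
        # Return the action together with an (here empty) dictionary of extra outputs.
        return action, {}


policy = WrappedPolicy(StubModel())
action, info = policy.act([0.1, -0.3, 0.5])
print(action, info)
```

Keeping the model behind a single `act` method means calling code does not need to know which library produced the model.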

act(obs: Tensor, deterministic: bool = True) → Tuple[Tensor, Dict[str, Tensor]]

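Select an action based on the current observation.

Parameters:
  • obs (Tensor) – The current observation of the environment.

  • deterministic (bool) – Whether to select the action deterministically instead of sampling it. Default is True.

Returns:

A tuple containing the action to be taken and a dictionary of additional tensors.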
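
To make the deterministic flag and the tuple return value concrete, here is a small sketch for a discrete action space. It assumes the flag follows the usual convention (pick the most likely action when True, sample from the action distribution when False). The function `select_action`, the `"prob"` key, and the use of plain lists instead of tensors are all hypothetical and chosen only for illustration.

```python
import random
from typing import Dict, List, Optional, Tuple


def select_action(
    probs: List[float],
    deterministic: bool = True,
    rng: Optional[random.Random] = None,
) -> Tuple[int, Dict[str, float]]:
    """Hypothetical helper: choose a discrete action from action probabilities."""
    indices = range(len(probs))
    if deterministic:
        # Deterministic mode: always take the most probable action.
        action = max(indices, key=probs.__getitem__)
    else:
        # Stochastic mode: sample an action in proportion to its probability.
        rng = rng or random.Random()
        action = rng.choices(indices, weights=probs, k=1)[0]
    # Return the action plus a dictionary of extra information, mirroring the tuple shape of act().
    return action, {"prob": probs[action]}


greedy_action, greedy_info = select_action([0.1, 0.7, 0.2], deterministic=True)
sampled_action, sampled_info = select_action([0.1, 0.7, 0.2], deterministic=False)
```

Deterministic selection is usually the sensible default when evaluating a pre-trained model, since it makes rollouts repeatable for a given observation sequence.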