Source code for prt_rl.common.policies.random
"""
Random Policy that samples actions uniformly from the action space.
"""
import torch
from torch import Tensor
from typing import Union, Dict, Tuple
from prt_rl.env.interface import EnvParams, MultiAgentEnvParams
[docs]
class RandomPolicy:
"""
Implements a policy that uniformly samples random actions.
This policy implements the Policy protocol so it can be used with any Collector or Evaluator in the PRT-RL framework.
Args:
env_params (EnvParams): environment parameters
"""
def __init__(self,
env_params: Union[EnvParams | MultiAgentEnvParams],
) -> None:
self.env_params = env_params
[docs]
@torch.no_grad()
def act(self,
obs: torch.Tensor,
deterministic: bool = False
) -> Tuple[Tensor, Dict[str, Tensor]]:
"""
Randomly samples an action from action space.
Returns:
TensorDict: Tensordict with the "action" key added
"""
if deterministic:
raise ValueError("RandomPolicy does not support deterministic actions. Set deterministic=False to sample random actions.")
if isinstance(self.env_params, EnvParams):
ashape = (obs.shape[0], self.env_params.action_len)
params = self.env_params
elif isinstance(self.env_params, MultiAgentEnvParams):
ashape = (obs.shape[0], self.env_params.num_agents, self.env_params.agent.action_len)
params = self.env_params.agent
else:
raise ValueError("env_params must be a EnvParams or MultiAgentEnvParams")
if not params.action_continuous:
# Add 1 to the high value because randint samples between low and 1 less than the high: [low,high)
action = torch.randint(low=params.action_min, high=params.action_max + 1,
size=ashape)
else:
action = torch.rand(size=ashape)
# Scale the random [0,1] actions to the action space [min,max]
max_actions = torch.tensor(params.action_max).unsqueeze(0)
min_actions = torch.tensor(params.action_min).unsqueeze(0)
action = action * (max_actions - min_actions) + min_actions
return action, {}