# Source code for prt_rl.env.wrappers.gymnasium_envs

"""
Wrapper for Gymnasium environments.
"""
import gymnasium as gym
import numpy as np
import torch
from typing import Optional, Tuple, List, Union, Dict, Any, Callable
from prt_rl.env.interface import EnvironmentInterface, EnvParams

class GymnasiumWrapper(EnvironmentInterface):
    """
    Wraps the Gymnasium environments in the Environment interface.

    Args:
        gym_name: Name of the Gymnasium environment. Mutually exclusive with ``env_factory``.
        env_factory: Callable that constructs a Gymnasium environment for each env index.
            It is called as ``env_factory(env_index, env_seed)`` where ``env_seed`` is
            ``seed + env_index`` (or None when ``seed`` is None). Mutually exclusive with
            ``gym_name``.
        num_envs: Number of parallel environments to create. With more than one environment
            they are wrapped in a ``gym.vector.SyncVectorEnv``.
        render_mode: Sets the rendering mode. Defaults to None. It is only forwarded to
            ``gym.make``; an ``env_factory`` must configure rendering itself.
        seed: Optional base seed. Environment ``i`` is reset and has its action and
            observation spaces seeded with ``seed + i``.
        device: Torch device on which returned tensors are created. Defaults to 'cpu'.
        **kwargs: Extra keyword arguments forwarded to ``gym.make`` (unused when
            ``env_factory`` is given).

    Raises:
        ValueError: If both or neither of ``gym_name`` and ``env_factory`` are provided.

    Examples:
        .. code-block:: python

            from prt_rl.env.wrappers import GymnasiumWrapper
            from prt_rl.common.policy import RandomPolicy

            env = GymnasiumWrapper(
                gym_name="CarRacing-v3",
                render_mode="rgb_array",
                continuous=True
            )

            # or use a factory (useful for domain randomization wrappers)
            vec_env = GymnasiumWrapper(
                env_factory=lambda env_index, seed: gym.make("Pendulum-v1"),
                num_envs=4
            )

            policy = RandomPolicy(env_params=env.get_parameters())

            state, info = env.reset()
            done = False
            while not done:
                action = policy.get_action(state)
                state, reward, done, info = env.step(action)
    """
    def __init__(self,
                 gym_name: Optional[str] = None,
                 env_factory: Optional[Callable[[int, Optional[int]], gym.Env]] = None,
                 num_envs: int = 1,
                 render_mode: Optional[str] = None,
                 seed: Optional[int] = None,
                 device: str = 'cpu',
                 **kwargs
                 ) -> None:
        super().__init__(render_mode, num_envs=num_envs)
        self.gym_name = gym_name
        self.env_factory = env_factory
        self.device = torch.device(device)

        if (self.gym_name is None) == (self.env_factory is None):
            raise ValueError("Provide exactly one of `gym_name` or `env_factory`.")

        def _make_env(env_index: int) -> gym.Env:
            # Each sub-environment gets its own reproducible seed derived from the base seed.
            env_seed = None if seed is None else seed + env_index
            if self.env_factory is not None:
                env = self.env_factory(env_index, env_seed)
            else:
                env = gym.make(self.gym_name, render_mode=render_mode, **kwargs)

            # Seed the environment and its spaces if a seed is provided
            if env_seed is not None:
                env.reset(seed=env_seed)
                env.action_space.seed(env_seed)
                env.observation_space.seed(env_seed)
            return env

        if self.num_envs == 1:
            self.env = _make_env(0)
            vectorized = False
        else:
            # Bind env_index per thunk; a bare ``lambda: _make_env(i)`` inside the
            # comprehension would late-bind ``i`` and build copies of the last env only.
            def make_env_fn(env_index: int):
                return lambda: _make_env(env_index)

            self.env = gym.vector.SyncVectorEnv([make_env_fn(i) for i in range(self.num_envs)])
            vectorized = True

        self.env_params = self._make_env_params(vectorized=vectorized)

    def get_parameters(self) -> EnvParams:
        """
        Returns the EnvParams object which contains information about the sizes of observations and actions needed for setting up RL agents.

        Returns:
            EnvParams: environment parameters object
        """
        return self.env_params

    def reset(self, seed: int | None = None) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Resets the environment to the initial state and returns the initial observation.

        Args:
            seed (int | None): Sets the random seed.

        Returns:
            Tuple: Tuple of tensors containing the initial observation and info dictionary.
                When ``render_mode == 'rgb_array'`` the info dict also holds the rendered
                frame under ``'rgb_array'`` with a leading axis added.
        """
        state, info = self.env.reset(seed=seed)
        state = self._process_observation(state)
        self._add_render_to_info(info)
        return state, info

    def reset_index(self, index: int, seed: int | None = None) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Resets a single environment selected by its index.

        Args:
            index (int): Index of the environment to reset, in ``[0, num_envs)``.
            seed (int | None): Optional seed passed to that environment's ``reset``.

        Returns:
            Tuple[torch.Tensor, Dict[str, Any]]: The new observation and info dict

        Raises:
            ValueError: If ``index`` is outside ``[0, num_envs)``.
        """
        # Valid indices are 0 .. num_envs - 1.
        if not 0 <= index < self.num_envs:
            raise ValueError(f"Index {index} is out of bounds for {self.num_envs} environments.")

        # With a single (non-vectorized) environment this is a full reset, which already
        # processes the observation and attaches any rendered frame.
        if self.num_envs == 1:
            return self.reset(seed=seed)

        # NOTE(review): this resets the sub-environment directly, which may bypass any
        # per-env bookkeeping SyncVectorEnv keeps (e.g. autoreset state) — verify. The
        # observation is processed as-is, so no batch dimension is added when
        # num_envs > 1; confirm the shape callers expect.
        state, info = self.env.envs[index].reset(seed=seed)
        state = self._process_observation(state)
        self._add_render_to_info(info)
        return state, info

    def step(self, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """
        Steps the simulation using the action tensor and returns the new trajectory.

        Args:
            action (torch.Tensor): Action tensor with shape (# envs, # actions). For discrete
                action spaces the last dimension holds the single integer action of each env.

        Returns:
            Tuple: The next state, the reward with shape (# envs, 1), the done flag
                (terminated or truncated) with shape (# envs, 1), and the info dictionary.
        """
        # NOTE(review): for Dict action spaces ``action_continuous`` is a list, which is
        # truthy whenever non-empty, so such actions take the continuous branch below.
        if not self.env_params.action_continuous:
            # Discrete actions send the raw integer value to the step function
            if self.num_envs == 1:
                # A single environment expects a plain integer action
                action = action.item()
            else:
                # Vectorized environments expect integer actions with shape (# envs,)
                action = action.cpu().numpy().squeeze(-1)
        else:
            action = action.detach().cpu().numpy()

            # If there is only one environment remove the first dimension
            if action.shape[0] == 1:
                action = action[0]

        next_state, reward, terminated, trunc, info = self.env.step(action)
        done = np.logical_or(terminated, trunc)

        # Reshape the reward and done to be (# envs, 1)
        if self.num_envs == 1:
            reward = torch.tensor([[reward]], dtype=torch.float, device=self.device)
            done = torch.tensor([[bool(done)]], dtype=torch.bool, device=self.device)
        else:
            reward = torch.tensor(reward, dtype=torch.float, device=self.device).unsqueeze(-1)
            done = torch.tensor(done, dtype=torch.bool, device=self.device).unsqueeze(-1)

        next_state = self._process_observation(next_state)
        self._add_render_to_info(info)
        return next_state, reward, done, info

    def close(self):
        """Closes the wrapped Gymnasium environment and returns its ``close()`` result."""
        return self.env.close()

    def _add_render_to_info(self, info: Dict[str, Any]) -> None:
        """
        Stores the current rendered frame in ``info['rgb_array']`` when rendering to RGB arrays.

        Args:
            info (Dict[str, Any]): Info dictionary to update in place.
        """
        if self.render_mode == 'rgb_array':
            rgb = self.env.render()
            # Add a leading axis to the rendered frame
            info['rgb_array'] = rgb[np.newaxis, ...]

    def _process_observation(self, observation: Union[np.ndarray, int]) -> torch.Tensor:
        """
        Processes the observation to ensure it is in the correct format.

        Integer scalars (Python ``int`` or NumPy integer, e.g. from Discrete spaces) are
        wrapped in a length-1 array. With a single environment a leading batch dimension
        is added. float64 observations are converted to float32.

        Args:
            observation (Union[np.ndarray, int]): The observation returned by the environment.

        Returns:
            torch.Tensor: The processed observation on ``self.device``.
        """
        # Treat Python and NumPy integer scalars alike, so a single env always yields (1, 1).
        if isinstance(observation, (int, np.integer)):
            observation = np.array([observation])

        observation = torch.tensor(observation, device=self.device)

        # Add a dimension if there is only 1 environment
        if self.num_envs == 1:
            observation = observation.unsqueeze(0)

        # If observation is float64 convert it to float32
        if observation.dtype == torch.float64:
            observation = observation.float()

        return observation

    def _make_env_params(self, vectorized: bool = False, ) -> EnvParams:
        """
        Creates the environment parameters based on the action and observation space of the environment.

        Args:
            vectorized (bool): If True, the environment is vectorized and the per-environment
                ``single_action_space`` / ``single_observation_space`` are used.

        Returns:
            EnvParams: The environment parameters object.

        Raises:
            NotImplementedError: If the action or observation space type is not supported.
        """
        if not vectorized:
            action_space = self.env.action_space
            observation_space = self.env.observation_space
        else:
            action_space = self.env.single_action_space
            observation_space = self.env.single_observation_space

        if isinstance(action_space, gym.spaces.Discrete):
            action_len, act_cont, act_min, act_max = self._get_params_from_discrete(action_space, is_action=True)
        elif isinstance(action_space, gym.spaces.Box):
            action_len, act_cont, act_min, act_max = self._get_params_from_box(action_space, is_action=True)
        elif isinstance(action_space, gym.spaces.Dict):
            action_len, act_cont, act_min, act_max = self._get_params_from_dict(action_space, is_action=True)
        else:
            raise NotImplementedError(f"{action_space} action space is not supported")

        if isinstance(observation_space, gym.spaces.Discrete):
            obs_shape, obs_cont, obs_min, obs_max = self._get_params_from_discrete(observation_space)
        elif isinstance(observation_space, gym.spaces.Box):
            obs_shape, obs_cont, obs_min, obs_max = self._get_params_from_box(observation_space)
        else:
            raise NotImplementedError(f"{observation_space} observation space is not supported")

        return EnvParams(
            action_len=action_len,
            action_continuous=act_cont,
            action_min=act_min,
            action_max=act_max,
            observation_shape=obs_shape,
            observation_continuous=obs_cont,
            observation_min=obs_min,
            observation_max=obs_max,
        )

    @staticmethod
    def _get_params_from_discrete(space: gym.spaces.Discrete, is_action: bool = False) -> Tuple[tuple | int, bool, int, int]:
        """
        Extracts the environment parameters from a discrete space.

        A ``Discrete(n, start)`` space contains the integers ``start, ..., start + n - 1``.

        Args:
            space (gym.spaces.Discrete): The space to extract parameters from.
            is_action (bool): If True, the shape is returned as the integer action length 1;
                otherwise as the tuple ``(1,)``.

        Returns:
            Tuple[tuple | int, bool, int, int]: tuple containing (space_shape, space_continuous,
                space_min, space_max) where the bounds are inclusive.
        """
        space_shape = 1 if is_action else (1,)
        # The largest value is offset by ``start``; ``n - 1`` alone is only right when start == 0.
        space_min = int(space.start)
        space_max = int(space.start + space.n - 1)
        return space_shape, False, space_min, space_max

    @staticmethod
    def _get_params_from_box(space: gym.spaces.Box, is_action: bool = False) -> Tuple[tuple | int, bool, List[float], List[float]]:
        """
        Extracts the environment parameters from a box space.

        Args:
            space (gym.spaces.Box): The space to extract parameters from.
            is_action (bool): If True and the box is 1-D, the shape is returned as an integer
                action length.

        Returns:
            Tuple[tuple | int, bool, List[float], List[float]]: tuple containing (space_shape,
                space_continuous, space_min, space_max); the bounds come from ``low.tolist()``
                and ``high.tolist()``.
        """
        space_shape = space.shape

        # Return an integer action length for 1-D box action spaces
        if is_action and len(space_shape) == 1:
            space_shape = space_shape[0]

        return space_shape, True, space.low.tolist(), space.high.tolist()

    @staticmethod
    def _get_params_from_dict(space: gym.spaces.Dict, is_action: bool = False) -> Tuple[int, List[bool], List[float], List[float]]:
        """
        Extracts the environment parameters from a dict space by concatenating all subspaces.

        Subspaces are concatenated in the Dict's iteration order. A Discrete subspace
        contributes one element. A Box subspace is flattened and contributes one element
        per entry, so multi-dimensional and scalar boxes are supported too.

        Args:
            space (gym.spaces.Dict): The space to extract parameters from.
            is_action (bool): Must be True; Dict observation spaces are not supported.

        Returns:
            Tuple[int, List[bool], List[float], List[float]]: tuple containing (space_shape,
                space_continuous, space_min, space_max) where space_shape is the total
                flattened action length.

        Raises:
            NotImplementedError: For Dict observation spaces or unsupported subspace types.
        """
        if not is_action:
            raise NotImplementedError("Dict observation spaces are not supported")

        action_len = 0
        action_conts: List[bool] = []
        action_mins: List[float] = []
        action_maxs: List[float] = []
        for subspace in space.spaces.values():
            if isinstance(subspace, gym.spaces.Discrete):
                sub_len, sub_cont, sub_min, sub_max = GymnasiumWrapper._get_params_from_discrete(subspace, is_action=True)
                sub_conts, sub_mins, sub_maxs = [sub_cont], [sub_min], [sub_max]
            elif isinstance(subspace, gym.spaces.Box):
                # Flatten so every box element gets its own length/continuity/bound entry
                sub_len = int(np.prod(subspace.shape))
                sub_conts = [True] * sub_len
                sub_mins = subspace.low.ravel().tolist()
                sub_maxs = subspace.high.ravel().tolist()
            else:
                raise NotImplementedError(f"{subspace} action space is not supported in Dict")

            action_len += sub_len
            action_conts.extend(sub_conts)
            action_mins.extend(sub_mins)
            action_maxs.extend(sub_maxs)

        return action_len, action_conts, action_mins, action_maxs