"""Source code for prt_rl.env.adapters.action_augmented_observation."""

import torch
from prt_rl.env.interface import EnvironmentInterface
from prt_rl.env.adapters.interface import AdapterInterface

class ActionAugmentedObservationAdapter(AdapterInterface):
    """
    Adapter that augments the observation with the previous action taken.

    The observation is concatenated with the previous action along the last
    dimension, so a 1D observation of size ``obs_dim`` becomes
    ``obs_dim + action_len``. Only environments with continuous action spaces
    and 1D observation spaces are supported.

    Args:
        env (EnvironmentInterface): The environment to adapt

    Raises:
        ValueError: If the environment's action space is not continuous, or its
            observation space is not 1D.
    """

    def __init__(self, env: EnvironmentInterface) -> None:
        params = env.get_parameters()
        if not params.action_continuous:
            raise ValueError("ActionAugmentedObservationAdapter only supports environments with continuous action spaces.")
        if len(params.observation_shape) != 1:
            raise ValueError("ActionAugmentedObservationAdapter only supports environments with 1D observation spaces.")
        # Number of action components appended to every observation.
        self.action_dim = params.action_len
        # Most recent action seen by _adapt_action; None until the first step.
        # NOTE(review): nothing in this class clears it, so unless
        # AdapterInterface resets it, the first observation after an environment
        # reset carries the last action of the previous episode — confirm.
        self.previous_action = None
        # Attributes are set before super().__init__, presumably because the base
        # constructor may call _adapt_params (which reads self.action_dim) — confirm.
        super().__init__(env)

    @staticmethod
    def _bounds_to_list(value, length):
        """Return bounds as a per-dimension list, broadcasting a scalar to ``length`` entries."""
        if isinstance(value, (list, tuple)):
            return list(value)
        if getattr(value, "ndim", 0) > 0:
            # Array-like bounds (e.g. numpy arrays or torch tensors) are already
            # per-dimension; convert them to a plain list so they can be concatenated.
            return list(value.tolist())
        return [value] * length

    def _adapt_params(self, params):
        """
        Extend the environment parameters to describe the augmented observation.

        The observation shape grows by ``action_dim`` and the action bounds are
        appended to the observation bounds. Scalar bounds are broadcast to one
        entry per dimension; list, tuple, and array-like bounds are used as
        per-dimension values.

        Args:
            params: Environment parameters; modified in place.

        Returns:
            The same ``params`` object with updated ``observation_shape``,
            ``observation_min`` and ``observation_max``.
        """
        original_obs_dim = params.observation_shape[0]
        params.observation_shape = (original_obs_dim + self.action_dim,)

        observation_min = self._bounds_to_list(params.observation_min, original_obs_dim)
        observation_max = self._bounds_to_list(params.observation_max, original_obs_dim)
        action_min = self._bounds_to_list(params.action_min, self.action_dim)
        action_max = self._bounds_to_list(params.action_max, self.action_dim)

        params.observation_min = observation_min + action_min
        params.observation_max = observation_max + action_max
        return params

    def _adapt_action(self, action):
        """
        Remember the action so it can be appended to the next observation.

        Args:
            action: The action passed into the adapter; it is stored as received,
                before the base class adapts it.

        Returns:
            The result of the base class's ``_adapt_action``.
        """
        self.previous_action = action
        return super()._adapt_action(action)

    def _adapt_obs(self, obs, info):
        """
        Append the previous action to the observation along the last dimension.

        Args:
            obs (torch.Tensor): Observation; ``obs.shape[0]`` is treated as the
                batch size (presumably shaped ``(batch, obs_dim)`` — confirm).
            info: Not used by this adapter.

        Returns:
            torch.Tensor: ``obs`` with the previous action concatenated on the
            last dimension.
        """
        if self.previous_action is None:
            # No action has been taken yet: start from an all-zero action.
            batch_size = obs.shape[0]
            self.previous_action = torch.zeros((batch_size, self.action_dim), device=obs.device)
        # The stored action can live on a different device than the observation
        # (e.g. policy on GPU, environment on CPU); torch.cat needs a common device.
        previous_action = self.previous_action.to(obs.device)
        return torch.cat([obs, previous_action], dim=-1)