Source code for prt_rl.env.adapters.historical_observation

import torch
from prt_rl.env.interface import EnvironmentInterface
from prt_rl.env.adapters.interface import AdapterInterface


[docs] class HistoricalObservationAdapter(AdapterInterface): """ Adapter that augments observations with a fixed-length observation history and optional action history. Args: env (EnvironmentInterface): The environment to adapt num_steps (int): Number of observations to stack in the augmented observation. include_actions (bool): If True, include previous actions between stacked observations. append_last_action (bool): If True and include_actions is True, append the most recent action to the end of the observation stack. Example (num_steps=3): - False: [o_{t-2}, o_{t-1}, o_t] - True: [o_{t-2}, a_{t-2}, o_{t-1}, a_{t-1}, o_t] - True + append_last_action=True: [o_{t-2}, a_{t-2}, o_{t-1}, a_{t-1}, o_t, a_{t-1}] """ def __init__(self, env: EnvironmentInterface, num_steps: int = 4, include_actions: bool = True, append_last_action: bool = False, ) -> None: params = env.get_parameters() if len(params.observation_shape) != 1: raise ValueError("HistoricalObservationAdapter only supports environments with 1D observation spaces.") if num_steps < 1: raise ValueError("num_steps must be >= 1.") if include_actions and not params.action_continuous: raise ValueError("HistoricalObservationAdapter with include_actions=True only supports continuous action spaces.") self.num_steps = num_steps self.include_actions = include_actions self.append_last_action = append_last_action self.action_dim = params.action_len self.observation_history = [] self.action_history = [] self.last_action = None super().__init__(env) def _adapt_params(self, params): original_obs_dim = params.observation_shape[0] num_action_slots = max(self.num_steps - 1, 0) if self.include_actions else 0 if self.include_actions and self.append_last_action: num_action_slots += 1 adapted_obs_dim = original_obs_dim * self.num_steps + self.action_dim * num_action_slots params.observation_shape = (adapted_obs_dim,) if isinstance(params.observation_min, list): observation_min = params.observation_min else: observation_min = [params.observation_min] * original_obs_dim if isinstance(params.observation_max, list): observation_max = params.observation_max else: observation_max = [params.observation_max] * original_obs_dim if len(observation_min) != original_obs_dim: raise ValueError(f"Expected observation_min length {original_obs_dim}, got {len(observation_min)}.") if len(observation_max) != original_obs_dim: raise ValueError(f"Expected observation_max length {original_obs_dim}, got {len(observation_max)}.") if self.include_actions: if isinstance(params.action_min, list): action_min = params.action_min else: action_min = [params.action_min] * self.action_dim if isinstance(params.action_max, list): action_max = params.action_max else: action_max = [params.action_max] * self.action_dim if len(action_min) != self.action_dim: raise ValueError(f"Expected action_min length {self.action_dim}, got {len(action_min)}.") if len(action_max) != self.action_dim: raise ValueError(f"Expected action_max length {self.action_dim}, got {len(action_max)}.") adapted_min = [] adapted_max = [] for idx in range(self.num_steps): adapted_min.extend(observation_min) adapted_max.extend(observation_max) if idx < self.num_steps - 1: adapted_min.extend(action_min) adapted_max.extend(action_max) if self.append_last_action: adapted_min.extend(action_min) adapted_max.extend(action_max) else: adapted_min = observation_min * self.num_steps adapted_max = observation_max * self.num_steps params.observation_min = adapted_min params.observation_max = adapted_max return params
[docs] def reset(self, *args, **kwargs): # Clear temporal buffers at episode reset. self.observation_history = [] self.action_history = [] self.last_action = None return super().reset(*args, **kwargs)
def _adapt_action(self, action): """Store the previous action.""" if self.include_actions: self.last_action = action self.action_history.append(action) if len(self.action_history) > self.num_steps - 1: self.action_history.pop(0) return super()._adapt_action(action) def _adapt_obs(self, obs, info): """Store the current observation and return the stacked history.""" self.observation_history.append(obs) if len(self.observation_history) > self.num_steps: self.observation_history.pop(0) padded_obs_history = [torch.zeros_like(obs)] * (self.num_steps - len(self.observation_history)) + self.observation_history if not self.include_actions: return torch.cat(padded_obs_history, dim=-1) batch_size = obs.shape[0] padded_action_history = [torch.zeros((batch_size, self.action_dim), dtype=obs.dtype, device=obs.device)] * (self.num_steps - 1 - len(self.action_history)) + self.action_history parts = [] for idx, hist_obs in enumerate(padded_obs_history): parts.append(hist_obs) if idx < self.num_steps - 1: parts.append(padded_action_history[idx]) if self.append_last_action: if self.last_action is None: last_action = torch.zeros((batch_size, self.action_dim), dtype=obs.dtype, device=obs.device) else: last_action = self.last_action parts.append(last_action) augmented_obs = torch.cat(parts, dim=-1) return augmented_obs