Source code for prt_rl.env.adapters.interface

from prt_rl.env.interface import EnvironmentInterface, EnvParams

[docs] class AdapterInterface(EnvironmentInterface): """ Interface class for environment adapters that adapt an environment to a different interface. Args: env (EnvironmentInterface): The environment to adapt """ def __init__(self, env: EnvironmentInterface): self.env = env self.env_params = self._adapt_params(env.get_parameters())
[docs] def get_parameters(self): return self.env_params
[docs] def reset(self, *args, **kwargs): obs, info = self.env.reset(*args, **kwargs) return self._adapt_obs(obs, info), self._adapt_info(None, obs, None, False, info)
[docs] def step(self, action): raw_action = self._adapt_action(action) obs, reward, done, info = self.env.step(raw_action) adapted_info = self._adapt_info(action, obs, reward, done, info) return self._adapt_obs(obs, info), self._adapt_reward(reward, info), done, adapted_info
# The following methods can be overridden by subclasses to adapt the parameters, observations, actions, rewards, and info dictionaries as needed. By default, they return the input unchanged. def _adapt_params(self, params: EnvParams): return params def _adapt_obs(self, obs, info): return obs def _adapt_action(self, action): return action def _adapt_reward(self, reward, info): return reward def _adapt_info(self, action, obs, reward, done, info): return info