Source code for prt_rl.env.adapters.pixel_observation
import torch
from typing import Literal
from prt_rl.env.interface import EnvironmentInterface
from prt_rl.env.adapters.interface import AdapterInterface
[docs]
class PixelObservationAdapter(AdapterInterface):
"""
Adapater takes the 'rgb_array' pixels from the info dictionary and makes that the observation. The original observation is added back into the info dictionary under the 'state' key.
Assumes pixel observation are [H, W, C] numpy arrays of type uint8 and converts to [C, H, W] torch tensors normalized to [0, 1].
Note: This adapter assumes the base environment has render_mode set to "rgb_array".
Args:
env (EnvironmentInterface): The environment to adapt
"""
def __init__(self,
env: EnvironmentInterface,
pixel_type: Literal["uint8", "float32"] = "uint8"
):
self.pixel_type = pixel_type
_, info = env.reset()
self.image_shape = info['rgb_array'][0].shape
super().__init__(env)
def _adapt_params(self, params):
# Update the observation shape with the image shape
params.observation_shape = self.image_shape
params.observation_continuous = True
params.observation_min = 0.0
params.observation_max = 1.0
return params
def _adapt_info(self, action, obs, reward, done, info):
# Remove 'rgb_array' and add 'state' from the observation
new_info = {k: v for k, v in info.items() if k != 'rgb_array'}
new_info['state'] = obs
return new_info
def _adapt_obs(self, obs, info):
pixels = torch.from_numpy(info['rgb_array'][0]).permute(2, 0, 1) # [C, H, W]
pixels = pixels.unsqueeze(0).to(obs.device) # add batch dimension back in
if self.pixel_type == "float32":
pixels = pixels.to(torch.float32) / 255.0 # normalize to [0, 1]
return pixels