"""Source code for prt_rl.common.recorders."""

from abc import ABC
import imageio
import numpy as np
from prt_rl.common.buffers import ReplayBuffer


class Recorder(ABC):
    """
    Base class for recorders.

    Every hook defined here is a no-op that returns ``None``; subclasses
    override only the hooks they care about.
    """

    def reset(self) -> None:
        """
        Resets the recorder. The base implementation does nothing.
        """
        return None

    def record_info(self, info: dict) -> None:
        """
        Records information from the environment, such as rewards or other metrics.

        The base implementation ignores ``info``; subclasses may override it.
        """
        return None

    def record_experience(self, experience: dict) -> None:
        """
        Records experience data, such as state, action, reward, and next state.

        The base implementation ignores ``experience``; subclasses may override it.
        """
        return None

    def close(self) -> None:
        """
        Finalizes the recorder. The base implementation does nothing.
        """
        return None
class GifRecorder(Recorder):
    """
    Captures rgb_array data and creates a gif.

    Frames are taken from the ``'rgb_array'`` entry of the info dictionaries
    passed to :meth:`record_info` and are written to ``filename`` when
    :meth:`close` is called.

    Args:
        filename (str): Filename to save the gif.
        fps (int): frames per second
        loop (bool): Whether to loop the GIF after it runs. Defaults to True.
    """

    def __init__(self,
                 filename: str,
                 fps: int = 10,
                 loop: bool = True
                 ) -> None:
        self.filename = filename
        self.fps = fps
        self.loop = loop
        self.frames = []  # captured (H, W, C) frames, in recording order

    def reset(self) -> None:
        """
        Resets the buffer of frames
        """
        self.frames = []

    def record_info(self, info: dict) -> None:
        """
        Captures a frame from ``info['rgb_array']`` when that key is present.

        Args:
            info (dict): Info dictionary. Only element ``[0]`` of the
                ``'rgb_array'`` entry is captured (the first environment when
                several are rendered together). Other keys are ignored.
        """
        if 'rgb_array' in info:
            # Get the frame from the first environment if there is more than one
            rgb_frame = info['rgb_array'][0]
            self._capture_frame(rgb_frame)

    def _capture_frame(self, frame: np.ndarray) -> None:
        """
        Captures a frame to be saved to the GIF.

        Args:
            frame (np.ndarray): Numpy rgb array with format (H, W, C).
                Grayscale frames of shape (H, W) or (H, W, 1) are expanded
                to three identical channels.
        """
        # Copy so that later in-place writes to the caller's array (e.g. a
        # reused render buffer) cannot alter frames that were already captured.
        frame = np.array(frame, copy=True)
        if frame.ndim == 2:
            # Grayscale (H, W) -> (H, W, 3)
            frame = np.stack([frame] * 3, axis=-1)
        elif frame.ndim == 3 and frame.shape[-1] == 1:
            # Single-channel (H, W, 1) -> (H, W, 3)
            frame = np.repeat(frame, 3, axis=-1)
        self.frames.append(frame)

    def close(self) -> None:
        """
        Saves the captured frames as a GIF to ``self.filename``.

        Nothing is written when no frames have been captured.
        """
        if not self.frames:
            # Nothing was recorded; don't ask imageio to encode a zero-frame GIF.
            return
        # Per the Pillow GIF writer docs, loop=0 repeats forever.
        # NOTE(review): some viewers treat a loop count of 1 as "repeat once"
        # (i.e. play twice); confirm against the imageio backend in use.
        num_loops = 0 if self.loop else 1
        imageio.mimsave(self.filename, self.frames, fps=self.fps, loop=num_loops)
class ExperienceRecorder(Recorder):
    """
    Records experience data such as state, action, reward, and next state.
    This can be used for training or analysis later.

    Experiences are added to a ``ReplayBuffer`` and the buffer is saved to
    ``filename`` when :meth:`close` is called.

    Args:
        filename (str): Path passed to ``ReplayBuffer.save`` on close.
        capacity (int): ``capacity`` argument for the underlying ReplayBuffer
            (presumably the maximum number of stored experiences — confirm
            against ReplayBuffer). Defaults to 1_000_000, the previously
            hard-coded value.
    """

    def __init__(self, filename: str, *, capacity: int = 1_000_000) -> None:
        self.filename = filename
        self.buffer = ReplayBuffer(capacity=capacity)

    def reset(self) -> None:
        """
        Calls ``clear()`` on the underlying ReplayBuffer.
        """
        # NOTE(review): if callers invoke reset() at every episode boundary,
        # experience from earlier episodes is presumably discarded before
        # close() saves the buffer — confirm this is intended.
        self.buffer.clear()

    def record_experience(self, experience: dict) -> None:
        """
        Adds one experience dictionary to the buffer.

        Args:
            experience (dict): Experience data forwarded unchanged to
                ``ReplayBuffer.add``.
        """
        self.buffer.add(experience=experience)

    def close(self) -> None:
        """
        Saves the buffer to ``self.filename`` via ``ReplayBuffer.save``.
        """
        self.buffer.save(self.filename)