"""Visualizers for displaying frames (module ``prt_rl.common.visualizers``)."""

from abc import ABC, abstractmethod
import numpy as np
import pygame


class Visualizer(ABC):
    """Base interface for frame visualizers.

    Every hook below is a no-op and none is abstract, so this class can be
    instantiated directly as a "display nothing" visualizer. Subclasses
    override whichever hooks they need.
    """

    def start(self):
        """Prepare the visualizer for use. Does nothing by default."""
        return None

    def stop(self):
        """Release any visualizer resources. Does nothing by default."""
        return None

    def show(self, frame: np.ndarray) -> None:
        """Display a single frame. The base implementation ignores it.

        Args:
            frame: The frame to display.
        """
        return None
def _frame_to_rgb(frame: np.ndarray) -> np.ndarray:
    """Return ``frame`` in (H, W, C) layout, expanding grayscale to 3 channels.

    (H, W) and (H, W, 1) frames are replicated across three channels. Any
    other 3-D frame is returned unchanged. Handing pygame a 2-D array would
    make ``pygame.surfarray.make_surface`` build an 8-bit palettised surface,
    and pygame's default 8-bit palette is not a grayscale ramp, so gray
    frames would be shown in false colour.

    Raises:
        ValueError: If ``frame`` is neither 2-D nor 3-D.
    """
    if frame.ndim == 2:
        frame = frame[:, :, np.newaxis]
    if frame.ndim != 3:
        raise ValueError(
            f"Expected a frame of shape (H, W) or (H, W, C), got {frame.shape}"
        )
    if frame.shape[-1] == 1:
        frame = np.repeat(frame, 3, axis=2)
    return frame


class PygameVisualizer(Visualizer):
    """Show frames in a pygame window, throttled to at most ``fps`` per second.

    Call ``start()`` before ``show()``. Call ``stop()`` to close pygame. The
    pair may be repeated, because ``stop()`` forgets the window it created.

    Args:
        fps: Frame-rate cap passed to ``pygame.time.Clock.tick`` on every
            ``show()``.
        caption: Title set on the pygame window.
    """

    def __init__(self,
                 fps: int = 50,
                 caption: str = 'Visualizer',
                 ) -> None:
        self.fps = fps
        self.caption = caption
        self.clock = None        # pygame.time.Clock, created in start()
        self.window_size = None  # (width, height) of the current window
        self.screen = None       # display surface returned by set_mode()

    def start(self) -> None:
        """Initialise pygame and its display, set the caption, create the clock."""
        pygame.init()
        pygame.display.init()
        pygame.display.set_caption(self.caption)
        self.clock = pygame.time.Clock()

    def stop(self) -> None:
        """Shut pygame down and drop the cached window state."""
        pygame.quit()
        # pygame.quit() invalidates the display surface. Clearing the cache
        # makes the first show() after a later start() create a fresh window
        # instead of blitting onto a surface that no longer exists.
        self.clock = None
        self.window_size = None
        self.screen = None

    def show(self, frame: np.ndarray) -> None:
        """Draw one frame, process pending window events, and throttle to ``fps``.

        Args:
            frame: Image in (H, W, C), (H, W, 1) or (H, W) layout (row-major,
                height first). Values are passed straight to
                ``pygame.surfarray.make_surface``; presumably uint8 in
                [0, 255], which is not checked here.

        Raises:
            ValueError: If ``frame`` is neither 2-D nor 3-D.
        """
        frame = _frame_to_rgb(frame)
        height, width = frame.shape[:2]

        # Create the window on the first frame, and re-create it whenever
        # the frame size changes so the whole frame stays visible.
        if self.window_size != (width, height):
            self.window_size = (width, height)
            self.screen = pygame.display.set_mode(self.window_size)

        # surfarray indexes pixels as (x, y), so swap rows and columns.
        surface = pygame.surfarray.make_surface(frame.swapaxes(0, 1))
        self.screen.blit(surface, (0, 0))

        # Pump events so the OS does not mark the window as unresponsive.
        pygame.event.pump()
        pygame.display.flip()
        self.clock.tick(self.fps)