Source code for prt_sim.common.grid_rendering

import numpy as np
import os
import pygame
from typing import Any, Dict, Tuple


class GridworldRender:
    """
    Renders a grid world using icons for the agents in the world.

    It is assumed that position (0, 0) is the top left of the window, and x is
    positive right and y is positive down.

    Args:
        grid_width (int): width of the grid, in cells
        grid_height (int): height of the grid, in cells
        window_size (Tuple[int, int]): size of the window in pixels
        agent_icons (Dict[str, str]): dictionary of agent names and icon file
            locations; each location is joined onto the directory containing
            this module before it is loaded
        render_mode (str): "human" draws to an on-screen pygame window;
            "rgb_array" draws to an off-screen surface and returns the pixels
        render_fps (int): frames per second for rendering in "human" mode
        window_title (str): title of the window
        background_color (Tuple[int, int, int]): RGB fill color drawn behind the grid
    """

    def __init__(self,
                 grid_width: int,
                 grid_height: int,
                 window_size: Tuple[int, int],
                 agent_icons: Dict[str, str],
                 render_mode: str = "human",
                 render_fps: int = 5,
                 window_title: str = "Gridworld",
                 background_color: Tuple[int, int, int] = (255, 255, 255),
                 ) -> None:
        self.grid_width = grid_width
        self.grid_height = grid_height
        self.window_size = window_size
        self.agent_icons = self._load_icons(agent_icons)
        self.render_mode = render_mode
        self.render_fps = render_fps
        self.window_title = window_title
        self.background_color = background_color

        # Integer division: any leftover pixels on the right/bottom edge of the
        # window are simply not covered by a cell.
        self.cell_size = (
            self.window_size[0] // self.grid_width,
            self.window_size[1] // self.grid_height,
        )

        # Created lazily on the first call to render().
        self.window_surface = None
        self.clock = None

    def close(self) -> None:
        """
        Safely closes pygame and the display.

        Does nothing if nothing has been rendered yet. The surface and clock
        are dropped afterwards so that calling close() twice is harmless and a
        later render() re-creates its resources instead of drawing with a
        pygame that has already been shut down.
        """
        if self.window_surface is not None:
            pygame.quit()
            self.window_surface = None
            self.clock = None

    def render(self,
               agent_positions: Dict[str, np.ndarray],
               ):
        """
        Renders the grid world and the agent icons from dictionary of agent names and positions.

        Args:
            agent_positions (Dict[str, np.ndarray]): dictionary of agent names and
                numpy array of (x,y) grid positions

        Returns:
            None in "human" mode. In "rgb_array" mode, a numpy array holding a
            copy of the rendered pixels with the first two axes of the pygame
            pixel array swapped.

        Raises:
            ValueError: if render_mode is neither "human" nor "rgb_array"
        """
        human = self.render_mode == "human"
        if not human and self.render_mode != "rgb_array":
            # Fail with a clear message instead of an AttributeError on a
            # surface that was never created.
            raise ValueError(
                f"Unsupported render_mode {self.render_mode!r}; "
                "expected 'human' or 'rgb_array'."
            )

        if self.window_surface is None:
            if human:
                pygame.init()
                pygame.display.init()
                pygame.display.set_caption(self.window_title)
                self.window_surface = pygame.display.set_mode(self.window_size)
            else:
                self.window_surface = pygame.Surface(self.window_size)

        if self.clock is None:
            self.clock = pygame.time.Clock()

        self._draw_grid()
        self._draw_agent_icons(agent_positions)

        if human:
            pygame.event.pump()
            pygame.display.flip()
            # Limits the loop to at most render_fps frames per second.
            self.clock.tick(self.render_fps)
            return None

        # np.array() copies the pixel data out of the surface; the transpose
        # swaps the first two axes of that copy.
        return np.transpose(
            np.array(pygame.surfarray.pixels3d(self.window_surface)), axes=(1, 0, 2)
        )

    def _load_icons(self, agent_icons: Dict[str, str]) -> Dict[str, Any]:
        """
        Loads icons from the dictionary of agent names and icon filenames.

        Args:
            agent_icons (Dict[str, str]): dictionary of agent names and icon file
                locations, joined onto the directory containing this module

        Returns:
            Dict[str, Any]: dictionary of agent names and the images returned by
                pygame.image.load
        """
        base_dir = os.path.dirname(__file__)
        return {
            agent: pygame.image.load(os.path.join(base_dir, icon))
            for agent, icon in agent_icons.items()
        }

    def _draw_grid(self) -> None:
        """
        Draws the grid world background and cell outlines.
        """
        # Fill the whole surface with the configured background color.
        self.window_surface.fill(self.background_color)

        cell_w, cell_h = self.cell_size
        for x in range(self.grid_width):
            for y in range(self.grid_height):
                rect = (x * cell_w, y * cell_h, cell_w, cell_h)
                # Black outline, 2 pixels wide, around every cell.
                pygame.draw.rect(self.window_surface, (0, 0, 0), rect, 2)

    def _draw_agent_icons(self, agent_positions: Dict[str, np.ndarray]) -> None:
        """
        Scales the agent icon to match the cell size and draws them in the grid world.

        Agents whose position is None are skipped.

        Args:
            agent_positions (Dict[str, np.ndarray]): dictionary of agent names and
                numpy array of (x,y) grid positions
        """
        for agent, position in agent_positions.items():
            if position is None:
                continue

            # Convert the grid coordinate to the pixel coordinate of the
            # cell's top-left corner.
            pos = (position[0] * self.cell_size[0], position[1] * self.cell_size[1])
            img = pygame.transform.scale(self.agent_icons[agent], self.cell_size)
            self.window_surface.blit(img, pos)