Source code for prt_sim.jhu.robot_game

import numpy as np
import os
from typing import Optional, Tuple
from prt_sim.jhu.base import BaseEnvironment
from prt_sim.common.grid_rendering import GridworldRender


class RobotGame(BaseEnvironment):
    """
    Robot Game is a discrete grid world navigated by Fred the robot.

    .. image:: /_static/robot_game.png
        :alt: Robot Game
        :width: 100%
        :align: center

    The grid is 4 cells wide and 3 cells tall. Positions are stored as ``[x, y]`` with ``y``
    increasing downwards (the ``up`` action decrements ``y``). Fred starts at ``[0, 2]``, the
    outlet (goal) is at ``[3, 0]``, the pit is at ``[3, 1]`` and the obstacle at ``[2, 1]``
    can never be entered.

    **Action space**: integer representing a discrete action described in the table below

    +-----+--------+
    | Num | Action |
    +=====+========+
    | 0   | up     |
    +-----+--------+
    | 1   | down   |
    +-----+--------+
    | 2   | left   |
    +-----+--------+
    | 3   | right  |
    +-----+--------+

    **Observation space**: integer representing the current grid space. Cells are numbered
    row by row and the obstacle cell is skipped, so the values are 0..10 (goal = 3, pit = 6,
    start = 7).

    **Reward**: +25 for reaching goal, -25 for falling in pit, -1 for every other location

    Examples:
        .. code-block:: python

            env = RobotGame()
            state = env.reset()
            done = False
            while not done:
                action = np.random.randint(env.get_number_of_actions())
                state, reward, done = env.execute_action(action)
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 5
    }

    # Discrete action identifiers accepted by execute_action().
    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3

    def __init__(self,
                 render_mode: Optional[str] = "rgb_array",
                 ) -> None:
        """
        Builds the grid world and its renderer.

        Args:
            render_mode (str, optional): "human" or "rgb_array" (see ``metadata``). Defaults to
                "rgb_array".
        """
        self.render_mode = render_mode
        self.grid_width = 4
        self.grid_height = 3
        self.world = np.zeros((self.grid_width, self.grid_height))
        # Position is [x, y]; Fred starts in the bottom-left corner.
        self.current_position = np.array([0, 2])

        self.window_surface = None
        self.clock = None

        icon_dir = os.path.dirname(__file__)
        self.gridworld_render = GridworldRender(
            grid_width=self.grid_width,
            grid_height=self.grid_height,
            window_size=(800, 800),
            render_mode=self.render_mode,
            render_fps=self.metadata['render_fps'],
            agent_icons={
                'fred': os.path.join(icon_dir, 'icons/fred.png'),
                'outlet': os.path.join(icon_dir, 'icons/outlet.png'),
                'pit': os.path.join(icon_dir, 'icons/pit.png'),
                'obstacle': os.path.join(icon_dir, 'icons/obstacle.png'),
            },
            window_title='Fred Robot Game',
            background_color=(198, 236, 254)
        )

        # Positions handed to the renderer; 'fred' is refreshed on every render() call.
        self.agent_positions = {
            'outlet': np.array([3, 0]),
            'pit': np.array([3, 1]),
            'obstacle': np.array([2, 1]),
            'fred': self.current_position,
        }

    def get_number_of_states(self) -> int:
        """
        Returns the number of states in the world.

        Returns:
            int: number of states
        """
        # 4 x 3 grid minus the single obstacle cell.
        return 11

    def get_number_of_actions(self) -> int:
        """
        Returns the number of actions in the world.

        Returns:
            int: number of actions
        """
        return 4

    def reset(self,
              seed: Optional[int] = None,
              randomize_start: Optional[bool] = False
              ) -> int:
        """
        Resets the world to the initial state, placing the robot at position ``[0, 2]``.

        Args:
            seed (int, optional): Random seed. Currently unused by this environment. Defaults to None.
            randomize_start (bool, optional): Whether to randomize the starting state. Not supported
                by this environment. Defaults to False.

        Returns:
            int: current state value (7 for the start position)

        Raises:
            AssertionError: if ``randomize_start`` is truthy.
        """
        # Raised explicitly (rather than via ``assert``) so the check survives ``python -O``.
        if randomize_start:
            raise AssertionError("Randomizing the start is not supported")

        self.current_position = np.array([0, 2])
        return self.get_state()

    def execute_action(self,
                       action: int
                       ) -> Tuple[int, float, bool]:
        """
        Executes the action and a step of the world.

        A move that would leave the grid or enter the obstacle leaves the robot where it is; the
        step still costs the default reward of -1.

        Args:
            action (int): robot action to take

        Returns:
            tuple: a tuple of the (state, reward, done)

        Raises:
            ValueError: if ``action`` is not one of UP, DOWN, LEFT or RIGHT.
        """
        done = False
        reward = -1

        # Work on a copy so an invalid move never touches the current position.
        new_position = self.current_position.copy()
        if action == self.UP:
            new_position[1] -= 1
        elif action == self.DOWN:
            new_position[1] += 1
        elif action == self.LEFT:
            new_position[0] -= 1
        elif action == self.RIGHT:
            new_position[0] += 1
        else:
            raise ValueError(f"Invalid action {action} must be [0..3]")

        if self._check_if_valid_action(new_position):
            self.current_position = new_position

        # Terminal states: 3 is the goal (outlet), 6 is the pit.
        state = self.get_state()
        if state == 3:
            reward = 25
            done = True
        elif state == 6:
            reward = -25
            done = True

        return state, float(reward), done

    def get_state(self) -> int:
        """
        Returns the current world state, which is the location of the robot.

        Cells are numbered row by row (``y * grid_width + x``). The obstacle occupies raw index 6,
        so every index above 5 is shifted down by one to keep the state values contiguous.

        Returns:
            int: current location of the robot
        """
        # int() converts the NumPy scalar so the return value matches the annotation.
        state = int(self.current_position[1] * self.grid_width + self.current_position[0])
        if state > 5:
            state -= 1
        return state

    def _check_if_valid_action(self, position: list[int]) -> bool:
        """
        Checks if the location of the robot after taking an action is valid.

        Args:
            position (list[int]): ``[x, y]`` location of the robot to check

        Returns:
            bool: True if the location is valid, False otherwise
        """
        # Agent is trying to go off the top or bottom of the grid
        if not 0 <= position[1] < self.grid_height:
            return False

        # Agent is trying to get off the left or right part of the grid
        if not 0 <= position[0] < self.grid_width:
            return False

        # The obstacle cell can never be entered
        if np.array_equal(position, np.array([2, 1])):
            return False

        return True

    def render(self):
        """
        Renders the current world.

        In 'human' mode the frame is drawn and nothing is returned; in 'rgb_array' mode the value
        produced by the renderer is returned. Any other render mode does nothing.
        """
        # Sync the robot's rendered position with the simulation state.
        self.agent_positions['fred'] = self.current_position

        if self.render_mode == 'human':
            self.gridworld_render.render(self.agent_positions)
        elif self.render_mode == 'rgb_array':
            return self.gridworld_render.render(self.agent_positions)
if __name__ == '__main__':
    # Demo: run a random policy for at most 20 steps, rendering after every step.
    game = RobotGame()
    game.reset()
    game.render()

    steps_remaining = 20
    episode_over = False
    while steps_remaining > 0 and not episode_over:
        # execute_action returns (state, reward, done); only the done flag matters here.
        episode_over = game.execute_action(np.random.randint(4))[2]
        game.render()
        steps_remaining -= 1