Source code for prt_rl.common.policies.human

from enum import Enum, auto
import threading
import torch
from torch import Tensor
from typing import Dict, Tuple, Union

[docs] class GameControllerPolicy: """ The game controller policy allows interactive control of an agent with discrete or continuous actions. For continuous actions, the key_action_map maps a game controller input to an action index rather than a value. For example, 'JOY_RIGHT_X': 1 would map the x direction of the right joystick to action index 1. Notes: You don't want to use blocking with continuous actions because this would result in jerky agents. I also need to consider half joystick moves. For example if th input is speed you don't want to hold a joystick down to go nowhere. I should accept a default value for the actions as well. Args: env_params (EnvParams): environment parameters key_action_map : mapping from key string to action value blocking (bool, optional): Whether the policy blocks at each step. Defaults to True. Raises: ImportError: If inputs is not installed. AssertionError: If a game controller is not found """
[docs] class Key(Enum): BUTTON_A = auto() BUTTON_B = auto() BUTTON_X = auto() BUTTON_Y = auto() BUTTON_LB = auto() BUTTON_LT = auto() BUTTON_RB = auto() BUTTON_RT = auto() BUTTON_START = auto() BUTTON_BACK = auto() BUTTON_DPAD_UP = auto() BUTTON_DPAD_DOWN = auto() BUTTON_DPAD_LEFT = auto() BUTTON_DPAD_RIGHT = auto() BUTTON_JOY_RIGHT = auto() BUTTON_JOY_LEFT = auto() JOYSTICK_LEFT_X = auto() JOYSTICK_LEFT_Y = auto() JOYSTICK_RIGHT_X = auto() JOYSTICK_RIGHT_Y = auto()
# JOYSTICK_LT = auto() # JOYSTICK_RT = auto() EVENT_TYPE_TO_KEY_MAP = { 'BTN_THUMB': Key.BUTTON_A, 'BTN_THUMB2': Key.BUTTON_B, 'BTN_TRIGGER': Key.BUTTON_X, 'BTN_TOP': Key.BUTTON_Y, 'BTN_TOP2': Key.BUTTON_LB, 'BTN_BASE': Key.BUTTON_LT, 'BTN_PINKIE': Key.BUTTON_RB, 'BTN_BASE2': Key.BUTTON_RT, 'BTN_BASE4': Key.BUTTON_START, 'BTN_BASE3': Key.BUTTON_BACK, 'BTN_DPAD_UP': Key.BUTTON_DPAD_UP, 'BTN_DPAD_DOWN': Key.BUTTON_DPAD_DOWN, 'BTN_DPAD_LEFT': Key.BUTTON_DPAD_LEFT, 'BTN_DPAD_RIGHT': Key.BUTTON_DPAD_RIGHT, 'BTN_BASE5': Key.BUTTON_JOY_LEFT, 'BTN_BASE6': Key.BUTTON_JOY_RIGHT, 'ABS_X': Key.JOYSTICK_LEFT_X, 'ABS_Y': Key.JOYSTICK_LEFT_Y, 'ABS_Z': Key.JOYSTICK_RIGHT_X, 'ABS_RZ': Key.JOYSTICK_RIGHT_Y, } def __init__(self, key_action_map: Union[Dict[Key, int], Dict[Key, int | str]], blocking: bool = True, ) -> None: try: import inputs self.inputs = inputs except ImportError as e: raise ImportError( "The 'inputs' library is required for GameController but is not installed. " "Please install it using 'pip install inputs'." ) from e self.key_action_map = key_action_map self.blocking = blocking self.continuous = self.env_params.action_continuous self.joystick_min = -1.0 self.joystick_max = 1.0 # Check if a game controller is found gamepads = self.inputs.devices.gamepads if not gamepads: raise AssertionError("No game controller found") # Start read thread if the policy is non-blocking if not self.blocking: self.listener_thread = None self.running = False self.latest_values = torch.zeros(*self.env_params.action_shape) self.lock = threading.Lock() self._start_listener()
[docs] @torch.no_grad() def act(self, obs: torch.Tensor, deterministic: bool = True) -> Tuple[Tensor, Dict[str, Tensor]]: """ Gets a game controller input and maps it to the action space. Args: obs (TensorDict): A tensor representing the current state of the environment. Returns: A TensorDict with the "action" key added. """ if not deterministic: raise ValueError("GameControllerAgent does not support non-deterministic actions. Set deterministic=True.") assert obs.shape[0] == 1, "GameController only supports batch size 1 for now." # Get the data type for the action values if self.continuous: ttype = torch.float32 else: ttype = torch.int if self.blocking: key = None while key not in self.key_action_map: key = self.EVENT_TYPE_TO_KEY_MAP[self._wait_for_inputs()] action = self.key_action_map[key] if isinstance(action, int): action_val = torch.tensor([self.key_action_map[key]], dtype=ttype) else: raise ValueError(f"Unsupported action {action}") else: # Non-blocking: use the latest key presses # Grab the latest values and updated current # Scale the joystick value to the action range for continuous with self.lock: action_val = self.latest_values.clone() return action_val.unsqueeze(0), {}
def _start_listener(self): """ Starts an event listening thread that captures game controller inputs and updates the latest action values. """ self.running = True def event_loop(): while self.running: events = self.inputs.get_gamepad() for event in events: match event.ev_type: case "Key": pass case "Absolute": joy_name = event.code joy_value = event.state # Convert the joystick name to a Key if joy_name in self.EVENT_TYPE_TO_KEY_MAP.keys(): joy_key = self.EVENT_TYPE_TO_KEY_MAP[joy_name] # Get the action map from the Key if joy_key in self.key_action_map.keys(): action_map = self.key_action_map[joy_key] # Normalize joystick value to [-1, 1] # Note: the X direction neutral value is 127, but the Y direction neutral value is 128 if joy_name == 'ABS_X' or joy_name == 'ABS_Z': norm_joy_value = (joy_value - 127.0) / 127.0 else: norm_joy_value = -(joy_value - 128.0) / 128.0 # Process action and value self._process_joystick(action_map, norm_joy_value) case "Misc": # Ignore MISC messages pass case "Sync": # Ignore Sync messages pass case _: print(f"Unknown key: {event.ev_type}") self.listener_thread = threading.Thread(target=event_loop, daemon=True) self.listener_thread.start() def _process_joystick(self, action_map_values: Union[int, Tuple[int, str]], norm_value: float, ) -> None: """ Updates the latest action values using the normalized joystick value based on the action map parameters. Args: action_map_values (Union[int, Tuple[int, str]]): Action map values. norm_value (float): Normalized joystick value between [-1, 1] """ if isinstance(action_map_values, int): action_index = action_map_values action_param = None elif len(action_map_values) == 2: action_index, action_param = action_map_values else: raise ValueError(f"Unsupported action map {action_map_values}") # Normalized joystick value ranges from [-1.0, 1.0] joystick_min = -1.0 joystick_max = 1.0 if action_param == 'positive': # Change the joystick range from [0, 1] and clip the negative norm values to 0 joystick_min = 0.0 norm_value = max(joystick_min, norm_value) if action_param == 'negative': # Change the joystick range from [-1.0, 0] and clip the positive norm value to 0 joystick_max = 0.0 norm_value = min(joystick_max, norm_value) # Get the action min/max for this index and scale action value from [joy_min, joy_max] to [action_min, action_max] action_min = self.env_params.action_min[action_index] action_max = self.env_params.action_max[action_index] action_value = ((norm_value - joystick_min) / (joystick_max - joystick_min)) * (action_max - action_min) + action_min # Update latest action values with self.lock: self.latest_values[action_index] = action_value def _wait_for_inputs(self) -> str: """ Blocking listener that captures a single event and returns the value. Returns: String value of the Key pressed. """ assert not self.continuous, "Blocking GameController only supports discrete actions." key_val = None while key_val is None: events = self.inputs.get_gamepad() for event in events: match event.ev_type: case "Key": # Only return the action when the key is pressed and ignore the release if event.state == 1: key_val = event.code case "Absolute": # Read the DPAD buttons print(f"Code: {event.code} State: {event.state}") if event.code == 'ABS_HAT0X': if event.state == 1: key_val = "BTN_DPAD_RIGHT" if event.state == -1: key_val = "BTN_DPAD_LEFT" if event.code == 'ABS_HAT0Y': if event.state == 1: key_val = "BTN_DPAD_DOWN" if event.state == -1: key_val = "BTN_DPAD_UP" pass case "Misc": # Ignore MISC messages pass case "Sync": # Ignore Sync messages pass case _: print(f"Unknown key: {event.ev_type}") return key_val
[docs] class KeyboardPolicy: """ The keyboard policy allows interactive control of the agent using keyboard input. Notes: I could modify this to implement "sticky" keys, so in non-blocking the last key pressed stays the action until a new key is pressed. Alternatively, you could set a default value and the action goes back to a default when the key is released. Args: env_params (EnvParams): environment parameters key_action_map (Dict[str, int]): mapping from key string to action value blocking (bool): If blocking is True the simulation will wait for keyboard input at each step (synchronous), otherwise the simulation will not block and use the most up-to-date value (asynchronous). Default is True. Example: from prt_rl.utils.policies import KeyboardPolicy policy = KeyboardPolicy( env_params=env.get_parameters(), key_action_map={ 'up': 0, 'down': 1, 'left': 2, 'right': 3, }, blocking=True ) action_td = policy.get_action(state_td) """ def __init__(self, key_action_map: Dict[str, int], blocking: bool = True, ) -> None: # Check if pynput is installed try: from pynput import keyboard except ImportError as e: raise ImportError( "The 'pynput' library is required for KeyboardPolicy but is not installed. " "Please install it using 'pip install pynput'." ) from e self.keyboard = keyboard self.key_action_map = key_action_map self.blocking = blocking self.latest_key = None self.listener_thread = None if not self.blocking: self._start_listener()
[docs] @torch.no_grad() def act(self, obs: torch.Tensor, deterministic: bool = True ) -> Tuple[Tensor, Dict[str, Tensor]]: """ Gets a keyboard press and maps it to the action space. Args: obs (torch.Tensor): A tensor representing the current state of the environment. deterministic (bool): If True, the policy will not sample random actions. Defaults to True. Returns: torch.Tensor: A tensor with the action value based on the key pressed. """ if not deterministic: raise ValueError("KeyboardAgent does not support non-deterministic actions. Set deterministic=True.") assert obs.shape[0] == 1, "KeyboardPolicy Only supports batch size 1 for now." if self.blocking: key_string = '' # Keep reading keys until a valid key in the map is received while key_string not in self.key_action_map: key_string = self._wait_for_key_press() action_val = self.key_action_map[key_string] else: # Non-blocking: use the latest key press key_string = self.latest_key if key_string not in self.key_action_map: # If no valid key press, use a default action or skip action_val = 0 # Example: default or no-op action else: action_val = self.key_action_map[key_string] self.latest_key = None # Reset the latest key so another key has to be pressed. return torch.tensor([[action_val]]), {}
def _start_listener(self): """ Starts a background thread to listen for key presses. """ def listen_for_keys(): def on_press(key): try: if isinstance(key, self.keyboard.KeyCode): self.latest_key = key.char elif isinstance(key, self.keyboard.Key): self.latest_key = key.name except Exception as e: print(f"Error in key press listener: {e}") with self.keyboard.Listener(on_press=on_press, suppress=True) as listener: listener.join() self.listener_thread = threading.Thread(target=listen_for_keys, daemon=True) self.listener_thread.start() def _wait_for_key_press(self) -> str: """ Blocking method to wait for keyboard press. Returns: str: String name of the pressed key """ # A callback function to handle key presses def on_press(key): nonlocal key_pressed key_pressed = key # Store the pressed key return False # Stop the listener after a key is pressed key_pressed = None # Start the listener in blocking mode # Supressing keys keeps them from being passed on to the rest of the computer with self.keyboard.Listener(on_press=on_press, suppress=True) as listener: listener.join() # Get string value of KeyCodes and special Keys if isinstance(key_pressed, self.keyboard.KeyCode): key_pressed = key_pressed.char elif isinstance(key_pressed, self.keyboard.Key): key_pressed = key_pressed.name else: raise ValueError(f"Unrecognized key pressed type: {type(key_pressed)}") return key_pressed