Source code for prt_rl.common.qtable

import torch


class QTable:
    r"""
    The Q table implements a matrix of state-action values.

    One table of shape (state_dim, action_dim) is kept per environment, so the
    underlying tensor has shape (batch_size, state_dim, action_dim). For
    example, if there are 3 states, 2 actions, and an initial value of 0.1 the
    Q table of a single environment will look like:

    +------+-------+--------+
    |      |   0   |   1    |
    +======+=======+========+
    |  0   |  0.1  |  0.1   |
    +------+-------+--------+
    |  1   |  0.1  |  0.1   |
    +------+-------+--------+
    |  2   |  0.1  |  0.1   |
    +------+-------+--------+

    Args:
        state_dim (int): Number of states
        action_dim (int): Number of actions
        batch_size (int): Batch size (number of environments). Default is 1.
        initial_value (float): Initial value for the entire Q table. Default is 0.0.
        track_visits (bool): If True, a Visit table will be created to track
            state-action visits. Default is False.
        device (str): Device to use. Default is 'cpu'.

    Example:
        .. code-block:: python

            import torch
            from prt_rl.common.qtable import QTable

            qtable = QTable(state_dim=3, action_dim=2)
            qtable.update_q_value(
                state=torch.tensor([[1]]),
                action=torch.tensor([[0]]),
                q_value=torch.tensor([[0.1]]),
            )
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 batch_size: int = 1,
                 initial_value: float = 0.0,
                 track_visits: bool = False,
                 device: str = 'cpu'
                 ) -> None:
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.batch_size = batch_size
        self.initial_value = initial_value
        self.track_visits = track_visits
        self.device = device

        table_shape = (self.batch_size, self.state_dim, self.action_dim)
        self.q_table = torch.full(table_shape, initial_value, dtype=torch.float32, device=device)

        # The visit table only exists when visits are being tracked.
        if self.track_visits:
            self.visit_table = torch.zeros(table_shape, dtype=torch.float32, device=device)

    @staticmethod
    def _batch_index(table: torch.Tensor) -> torch.Tensor:
        """Returns the indices 0..(# env - 1) for the first axis of ``table``, on the table's device."""
        return torch.arange(table.size(0), device=table.device)

    def _require_visit_table(self) -> torch.Tensor:
        """Returns the visit table, raising a descriptive AttributeError if it was never created."""
        try:
            return self.visit_table
        except AttributeError:
            raise AttributeError(
                "QTable has no visit table; construct it with track_visits=True to track visits."
            ) from None

    def to(self,
           device: str
           ) -> None:
        """
        Moves the Q table, and the Visit table if visits are tracked, to the specified device.

        Args:
            device (str): Device to move the Q table to.
        """
        self.device = device
        self.q_table = self.q_table.to(device)
        if self.track_visits:
            self.visit_table = self.visit_table.to(device)

    def get_action_values(self,
                          state: torch.Tensor
                          ) -> torch.Tensor:
        """
        Returns the state action values for a given state.

        Args:
            state (torch.Tensor): state value to get action values for with shape (# env, 1)

        Returns:
            torch.Tensor: action values for given state with shape (# env, # actions)
        """
        state = state.squeeze(-1)
        # Pair environment i with its own state so each row comes from its own table.
        return self.q_table[self._batch_index(self.q_table), state]

    def get_state_action_value(self,
                               state: torch.Tensor,
                               action: torch.Tensor
                               ) -> torch.Tensor:
        """
        Returns the value for the given state-action pair.

        Args:
            state (torch.Tensor): state value to get the value for with shape (# env, 1)
            action (torch.Tensor): action value to get the value for with shape (# env, 1)

        Returns:
            torch.Tensor: value for the given state-action pair with shape (# env, 1)
        """
        state = state.squeeze(-1)
        action = action.squeeze(-1)
        values = self.q_table[self._batch_index(self.q_table), state, action]
        return values.unsqueeze(-1)

    def get_visit_count(self,
                        state: torch.Tensor,
                        action: torch.Tensor
                        ) -> torch.Tensor:
        """
        Returns the number of visits for a given state-action pair.

        Args:
            state (torch.Tensor): state value to get the number of visits for with shape (# env, 1)
            action (torch.Tensor): action value to get the number of visits for with shape (# env, 1)

        Returns:
            torch.Tensor: number of visits for given state-action pair with shape (# env, 1)

        Raises:
            AttributeError: If the table was created without a Visit table.
        """
        visit_table = self._require_visit_table()
        state = state.squeeze(-1)
        action = action.squeeze(-1)
        counts = visit_table[self._batch_index(visit_table), state, action]
        return counts.unsqueeze(-1)

    def update_q_value(self,
                       state: torch.Tensor,
                       action: torch.Tensor,
                       q_value: torch.Tensor
                       ) -> None:
        """
        Updates the Q table for a given state-action pair with given q-value.

        Args:
            state (torch.Tensor): state value to update the Q table for with shape (# env, 1)
            action (torch.Tensor): action value to update the Q table for with shape (# env, 1)
            q_value (torch.Tensor): q-value to update the Q table for with shape (# env, 1)
        """
        state = state.squeeze(-1)
        action = action.squeeze(-1)
        q_value = q_value.squeeze(-1)

        # Advanced indexing writes one entry per environment in place.
        self.q_table[self._batch_index(self.q_table), state, action] = q_value

    def update_visits(self,
                      state: torch.Tensor,
                      action: torch.Tensor
                      ) -> None:
        """
        Increments the Visit table by one for a given state-action pair.

        Args:
            state (torch.Tensor): state value to update the Visit table for with shape (# env, 1)
            action (torch.Tensor): action value to update the Visit table for with shape (# env, 1)

        Raises:
            AttributeError: If the table was created without a Visit table.
        """
        visit_table = self._require_visit_table()
        state = state.squeeze(-1)
        action = action.squeeze(-1)
        # The batch index is unique per row, so the in-place add touches each entry once.
        visit_table[self._batch_index(visit_table), state, action] += 1.0