qtable
Classes
- class prt_rl.common.qtable.QTable(state_dim: int, action_dim: int, batch_size: int = 1, initial_value: float = 0.0, track_visits: bool = False, device: str = 'cpu')
The Q table implements a matrix of state-action values.
For example, with 3 states, 2 actions, and an initial value of 0.1, the Q table looks like this (rows are states, columns are actions):
State | Action 0 | Action 1
0     | 0.1      | 0.1
1     | 0.1      | 0.1
2     | 0.1      | 0.1
- Parameters:
state_dim (int) – Number of states
action_dim (int) – Number of actions
batch_size (int) – Batch size (number of environments).
initial_value (float) – Initial value for the entire Q table. Default is 0.0.
track_visits (bool) – If True, a Visit table will be created to track state-action visits. Default is False.
device (str) – Device to use. Default is ‘cpu’.
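The constructor parameters can be pictured with plain tensors. The sketch below is an assumption about the layout, not the library's implementation: `init_tables` is a hypothetical helper, the Q table is taken to be a single `(state_dim, action_dim)` float tensor shared by all environments (so `batch_size` is not modeled), and the optional visit table is an integer tensor of zeros.

```python
import torch


def init_tables(state_dim, action_dim, initial_value=0.0, track_visits=False, device="cpu"):
    """Hypothetical helper mirroring the documented constructor arguments.

    Assumption: one (state_dim, action_dim) table shared by all environments,
    so batch_size is deliberately not modeled here.
    """
    # Every state-action value starts at initial_value.
    q = torch.full((state_dim, action_dim), initial_value, dtype=torch.float32, device=device)
    # Visit counts start at zero and are only allocated when requested.
    visits = (
        torch.zeros((state_dim, action_dim), dtype=torch.long, device=device)
        if track_visits
        else None
    )
    return q, visits


# Reproduces the 3-state, 2-action example table with initial value 0.1.
q, visits = init_tables(3, 2, initial_value=0.1, track_visits=True)
print(q.shape, visits.shape)  # torch.Size([3, 2]) torch.Size([3, 2])
```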
Example
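import torch
from prt_rl.common.qtable import QTable
qtable = QTable(state_dim=3, action_dim=2)  # states 0-2, actions 0-1
qtable.update_q_value(state=torch.tensor([[1]]), action=torch.tensor([[1]]), q_value=torch.tensor([[0.1]]))  # shapes (# env, 1)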
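A Q table is usually read and written inside a learning update. As an illustration only (this page does not say which algorithm `QTable` is paired with), the sketch below performs one batched tabular Q-learning step on a plain states × actions tensor. The helper `q_learning_step` and its `reward`, `next_state`, `done`, `alpha`, and `gamma` arguments are hypothetical, not part of the documented API.

```python
import torch


def q_learning_step(q, state, action, reward, next_state, done, alpha=0.1, gamma=0.99):
    """Hypothetical one-step tabular Q-learning update for a batch of environments.

    q has shape (state_dim, action_dim); every other tensor has shape (num_envs, 1).
    """
    s, a, s_next = state.squeeze(-1), action.squeeze(-1), next_state.squeeze(-1)
    q_sa = q[s, a]                            # current estimates, shape (num_envs,)
    best_next = q[s_next].max(dim=-1).values  # greedy value of each next state
    not_done = 1.0 - done.squeeze(-1).float()  # no bootstrapping from terminal states
    target = reward.squeeze(-1) + gamma * best_next * not_done
    q[s, a] = q_sa + alpha * (target - q_sa)  # move each estimate toward its TD target


q = torch.zeros(3, 2)
q_learning_step(
    q,
    state=torch.tensor([[0]]),
    action=torch.tensor([[1]]),
    reward=torch.tensor([[1.0]]),
    next_state=torch.tensor([[2]]),
    done=torch.tensor([[False]]),
)
# q[0, 1] is now 0.1, i.e. 0.0 + 0.1 * (1.0 + 0.99 * 0.0 - 0.0)
```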
- get_action_values(state: Tensor) → Tensor
Returns the state-action values for a given state.
- Parameters:
state (torch.Tensor) – state value to get action values for with shape (# env, 1)
- Returns:
action values for given state with shape (# env, # actions)
- Return type:
torch.Tensor
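A sketch of the shape contract, assuming the Q values live in a plain `(state_dim, action_dim)` tensor. `get_action_values` here is a hypothetical free function, not the method itself: a `(# env, 1)` column of state indices is flattened and used to pick one row per environment, giving `(# env, # actions)`.

```python
import torch


def get_action_values(q: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one row of Q values per environment.

    q:     (state_dim, action_dim) table (assumed layout)
    state: (num_envs, 1) integer state indices
    returns (num_envs, action_dim)
    """
    # Drop the trailing singleton dimension, then index rows by state.
    return q[state.squeeze(-1)]


q = torch.arange(6, dtype=torch.float32).reshape(3, 2)  # rows are states 0..2
state = torch.tensor([[2], [0]])                         # two environments
values = get_action_values(q, state)
# values holds the rows for states 2 and 0: [[4., 5.], [0., 1.]]
```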
- get_state_action_value(state: Tensor, action: Tensor) → Tensor
Returns the value for the given state-action pair.
- Parameters:
state (torch.Tensor) – state value to get the value for with shape (# env, 1)
action (torch.Tensor) – action value to get the value for with shape (# env, 1)
- Returns:
value for the given state-action pair with shape (# env, 1)
- Return type:
torch.Tensor
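Under the same assumed states × actions layout, a single value per environment can be read with `torch.gather`, which keeps the `(# env, 1)` shape without any reshaping. `get_state_action_value` below is a hypothetical stand-in for the method.

```python
import torch


def get_state_action_value(q, state, action):
    """Hypothetical helper: q[state_i, action_i] for every environment i.

    state and action have shape (num_envs, 1); the result also has shape (num_envs, 1).
    """
    rows = q[state.squeeze(-1)]       # (num_envs, action_dim)
    return rows.gather(1, action)     # pick one column per row, keeps (num_envs, 1)


q = torch.arange(6, dtype=torch.float32).reshape(3, 2)
state = torch.tensor([[2], [0]])
action = torch.tensor([[1], [0]])
value = get_state_action_value(q, state, action)
# value has shape (2, 1): [[5.], [0.]]
```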
- get_visit_count(state: Tensor, action: Tensor) → Tensor
Returns the number of visits for a given state-action pair.
- Parameters:
state (torch.Tensor) – state value to get the number of visits for with shape (# env, 1)
action (torch.Tensor) – action value to get the number of visits for with shape (# env, 1)
- Returns:
number of visits for the given state-action pair with shape (# env, 1)
- Return type:
torch.Tensor
- to(device: str) → None
Moves the Q table to the specified device.
- Parameters:
device (str) – Device to move the Q table to.
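Because `to` returns `None`, an object like this has to rebind its internal tensors: `torch.Tensor.to` does not move a tensor in place but returns the converted tensor (or the same tensor if nothing changes). The class `TinyTable` and its `q` and `visits` attributes are hypothetical; they only illustrate that pattern.

```python
import torch


class TinyTable:
    """Hypothetical stand-in holding a Q table and an optional visit table."""

    def __init__(self, state_dim, action_dim, track_visits=False):
        self.q = torch.zeros(state_dim, action_dim)
        self.visits = (
            torch.zeros(state_dim, action_dim, dtype=torch.long) if track_visits else None
        )

    def to(self, device):
        # Tensor.to is not in-place, so each attribute is reassigned.
        self.q = self.q.to(device)
        if self.visits is not None:
            self.visits = self.visits.to(device)


t = TinyTable(3, 2, track_visits=True)
result = t.to("cpu")
```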
- update_q_value(state: Tensor, action: Tensor, q_value: Tensor) → None
Updates the Q table for a given state-action pair with the given Q-value.
- Parameters:
state (torch.Tensor) – state value to update the Q table for with shape (# env, 1)
action (torch.Tensor) – action value to update the Q table for with shape (# env, 1)
q_value (torch.Tensor) – Q-value to write into the Q table with shape (# env, 1)
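A batched write can be sketched with advanced indexing on a plain states × actions tensor. `update_q_value` below is a hypothetical free function and assumes the update overwrites the stored value. One caveat with this technique: if two environments target the same (state, action) pair in one call, PyTorch does not guarantee which write survives.

```python
import torch


def update_q_value(q, state, action, q_value):
    """Hypothetical helper: set q[state_i, action_i] = q_value_i for every environment i.

    state, action, and q_value all have shape (num_envs, 1).
    """
    q[state.squeeze(-1), action.squeeze(-1)] = q_value.squeeze(-1)


q = torch.zeros(3, 2)
update_q_value(
    q,
    state=torch.tensor([[1], [2]]),
    action=torch.tensor([[0], [1]]),
    q_value=torch.tensor([[0.5], [-1.0]]),
)
# q is now [[0., 0.], [0.5, 0.], [0., -1.]]
```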
- update_visits(state: Tensor, action: Tensor) → None
Updates the Visit table for a given state-action pair.
- Parameters:
state (torch.Tensor) – state value to update the Visit table for with shape (# env, 1)
action (torch.Tensor) – action value to update the Visit table for with shape (# env, 1)
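Counting visits in a batch has one subtlety: when several environments hit the same (state, action) pair in a single step, `visits[s, a] += 1` counts that pair only once. The sketch below, which assumes an integer visit table shaped like the Q table, uses `Tensor.index_put_` with `accumulate=True` so duplicates are summed, and reads counts back with `gather`. Both helper names are hypothetical; whether `QTable` handles duplicates this way is not stated on this page.

```python
import torch


def update_visits(visits, state, action):
    """Hypothetical helper: add one visit per environment, counting duplicate pairs separately."""
    idx = (state.squeeze(-1), action.squeeze(-1))
    ones = torch.ones_like(idx[0], dtype=visits.dtype)
    visits.index_put_(idx, ones, accumulate=True)


def get_visit_count(visits, state, action):
    """Hypothetical helper: visit count per environment, shape (num_envs, 1)."""
    return visits[state.squeeze(-1)].gather(1, action)


visits = torch.zeros(3, 2, dtype=torch.long)
state = torch.tensor([[1], [1], [0]])   # environments 0 and 1 share a pair
action = torch.tensor([[0], [0], [1]])
update_visits(visits, state, action)
counts = get_visit_count(visits, state, action)
# counts is [[2], [2], [1]]
```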