qtable

Classes

class prt_rl.common.qtable.QTable(state_dim: int, action_dim: int, batch_size: int = 1, initial_value: float = 0.0, track_visits: bool = False, device: str = 'cpu')

The Q table implements a matrix of state-action values.

For example, with 3 states, 2 actions, and an initial value of 0.1, the Q table will look like:

state | action 0 | action 1
0     | 0.1      | 0.1
1     | 0.1      | 0.1
2     | 0.1      | 0.1

Parameters:
  • state_dim (int) – Number of states

  • action_dim (int) – Number of actions

  • batch_size (int) – Batch size (number of environments). Default is 1.

  • initial_value (float) – Initial value for the entire Q table. Default is 0.0.

  • track_visits (bool) – If True, a Visit table will be created to track state-action visits. Default is False.

  • device (str) – Device to use. Default is 'cpu'.

Example

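import torch
from prt_rl.common.qtable import QTable

qtable = QTable(state_dim=3, action_dim=2)
# Each argument is a tensor of shape (# env, 1); the action index must be below action_dim.
qtable.update_q_value(state=torch.tensor([[1]]), action=torch.tensor([[1]]), q_value=torch.tensor([[0.1]]))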
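As a rough mental model of the matrix described above, the sketch below builds a table of the same shape with plain Python lists. It is not how prt_rl stores the table (the class works on torch tensors), and the helper `make_table` is hypothetical, introduced only for illustration.

```python
def make_table(state_dim, action_dim, initial_value=0.0):
    """Hypothetical helper: one row per state, one column per action."""
    # Build each row separately so that rows do not share storage.
    return [[initial_value] * action_dim for _ in range(state_dim)]

table = make_table(state_dim=3, action_dim=2, initial_value=0.1)
for state, row in enumerate(table):
    print(state, row)
```

Indexing `table[state][action]` then reads the value of one state-action pair.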

get_action_values(state: Tensor) → Tensor

Returns the action values for a given state.

Parameters:

state (torch.Tensor) – state value to get action values for with shape (# env, 1)

Returns:

action values for given state with shape (# env, # actions)

Return type:

torch.Tensor
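The shape contract above, (# env, 1) in and (# env, # actions) out, can be illustrated with a small stand-in. This is a sketch using nested Python lists rather than torch tensors, and it assumes a single table shared by all environments; neither detail is confirmed by the documentation, and `action_values` is a hypothetical name.

```python
# Stand-in table: 3 states x 2 actions (values chosen for illustration).
q = [[0.0, 0.5],
     [1.0, 0.2],
     [0.3, 0.9]]

def action_values(q, states):
    """states has shape (n_env, 1); the result has shape (n_env, n_actions)."""
    # Each environment contributes one state index, which selects one row.
    return [list(q[s[0]]) for s in states]

states = [[2], [0]]  # two environments, currently in states 2 and 0
values = action_values(q, states)
print(values)  # [[0.3, 0.9], [0.0, 0.5]]
```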

get_state_action_value(state: Tensor, action: Tensor) → Tensor

Returns the value for the given state-action pair.

Parameters:
  • state (torch.Tensor) – state value to get the value for with shape (# env, 1)

  • action (torch.Tensor) – action value to get the value for with shape (# env, 1)

Returns:

value for the given state-action pair with shape (# env, 1)

Return type:

torch.Tensor
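To make the per-environment lookup concrete, the following sketch pairs each state index with its action index and returns one value per environment. It uses nested Python lists in place of torch tensors and assumes one table shared across environments; `state_action_value` is a hypothetical stand-in, not the library's code.

```python
# Stand-in table: 3 states x 2 actions (values chosen for illustration).
q = [[0.0, 0.5],
     [1.0, 0.2],
     [0.3, 0.9]]

def state_action_value(q, states, actions):
    """states and actions have shape (n_env, 1); so does the result."""
    return [[q[s[0]][a[0]]] for s, a in zip(states, actions)]

states = [[2], [0]]   # environment 0 is in state 2, environment 1 in state 0
actions = [[1], [0]]  # environment 0 took action 1, environment 1 took action 0
print(state_action_value(q, states, actions))  # [[0.9], [0.0]]
```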

get_visit_count(state: Tensor, action: Tensor) → Tensor

Returns the number of visits for a given state-action pair.

Parameters:
  • state (torch.Tensor) – state value to get the number of visits for with shape (# env, 1)

  • action (torch.Tensor) – action value to get the number of visits for with shape (# env, 1)

Returns:

number of visits for the given state-action pair with shape (# env, 1)

Return type:

torch.Tensor

to(device: str) → None

Moves the Q table to the specified device.

Parameters:

device (str) – Device to move the Q table to.

update_q_value(state: Tensor, action: Tensor, q_value: Tensor) → None

Updates the Q table for a given state-action pair with the given q-value.

Parameters:
  • state (torch.Tensor) – state value to update the Q table for with shape (# env, 1)

  • action (torch.Tensor) – action value to update the Q table for with shape (# env, 1)

  • q_value (torch.Tensor) – q-value to update the Q table for with shape (# env, 1)
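One common way a caller might produce the q-value passed to this method is a tabular Q-learning step. The sketch below is only an illustration on a plain-list stand-in for the table: the step size `alpha`, discount `gamma`, and the helper `q_learning_value` are assumptions, and it assumes the method overwrites the stored entry rather than adding to it.

```python
# Stand-in table: 3 states x 2 actions, plain Python lists instead of torch.
q = [[0.0, 0.0],
     [0.0, 0.0],
     [0.0, 2.0]]

alpha, gamma = 0.5, 0.9  # step size and discount (illustrative values)

def q_learning_value(q, state, action, reward, next_state):
    """Return the new value for (state, action) after one Q-learning step."""
    old = q[state][action]
    best_next = max(q[next_state])  # greedy value of the next state
    return old + alpha * (reward + gamma * best_next - old)

state, action, reward, next_state = 0, 1, 1.0, 2
new_value = q_learning_value(q, state, action, reward, next_state)

# Overwrite the entry, the role update_q_value is described as playing.
q[state][action] = new_value
print(new_value)  # approximately 1.4
```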

update_visits(state: Tensor, action: Tensor) → None

Updates the Visit table for a given state-action pair.

Parameters:
  • state (torch.Tensor) – state value to update the Visit table for with shape (# env, 1)

  • action (torch.Tensor) – action value to update the Visit table for with shape (# env, 1)
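The Visit table can be pictured as a second matrix of counters with the same layout as the Q table. The sketch below is a minimal stand-in using nested Python lists, assuming one counter per state-action pair shared across environments; `record_visits` and `count_visits` are hypothetical names that mirror the roles of `update_visits` and `get_visit_count`.

```python
# Stand-in Visit table: 3 states x 2 actions, all counters start at zero.
visits = [[0] * 2 for _ in range(3)]

def record_visits(visits, states, actions):
    """Increment the counter of each (state, action) pair; inputs are (n_env, 1)."""
    for s, a in zip(states, actions):
        visits[s[0]][a[0]] += 1

def count_visits(visits, states, actions):
    """Return the counters for the given pairs with shape (n_env, 1)."""
    return [[visits[s[0]][a[0]]] for s, a in zip(states, actions)]

record_visits(visits, [[1], [1]], [[0], [0]])  # both environments visit (1, 0)
record_visits(visits, [[1], [2]], [[0], [1]])  # then (1, 0) and (2, 1)
print(count_visits(visits, [[1], [2]], [[0], [1]]))  # [[3], [1]]
```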