qlearning
Classes
QLearningAgent – Q-Learning trainer.
QLearningPolicy – Q-Learning Policy implementation using a tabular Q-table.
- class prt_rl.exact.qlearning.QLearningAgent(policy: QLearningPolicy, config: QLearningConfig, *, device: str = 'cpu')
Q-Learning trainer.
\[Q(s,a)\]
- Parameters:
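policy (QLearningPolicy) – The tabular policy whose Q-table is updated during training.
config (QLearningConfig) – Trainer configuration.
device (str) – Device string, such as 'cpu'. Keyword-only; defaults to 'cpu'.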
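For orientation, the textbook one-step tabular Q-learning update (Watkins) that a trainer of this kind applies after each transition (s, a, r, s') is given below. The step size α and discount factor γ are the conventional symbols; whether QLearningConfig exposes them, and under which names, is an assumption not confirmed by this page.

```latex
\[
Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \right]
\]
```

When s' is terminal, the bootstrap term γ max_a' Q(s', a') is conventionally dropped, so the target is just r.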
- act(obs, deterministic=False)
Select an action for the given observation.
- Parameters:
obs – The current state/observation.
deterministic – If True, selects the action with highest Q-value. If False, uses the policy’s decision function for action selection.
- Returns:
A tuple containing the selected action and additional information.
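The deterministic/stochastic split in act() can be pictured with the framework-free sketch below, which works on nested Python lists. It assumes an epsilon-greedy decision function (one of the examples the QLearningPolicy documentation names) and assumes, as QLearningPolicy.act documents, that the additional information is the selected action's Q-value. argmax, make_epsilon_greedy and act here are hypothetical stand-ins, not prt_rl APIs.

```python
import random
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for a DecisionFunction: maps one row of Q-values to an action index.
DecisionFn = Callable[[List[float]], int]


def argmax(values: List[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def make_epsilon_greedy(epsilon: float, rng: random.Random) -> DecisionFn:
    """Explore uniformly with probability epsilon, otherwise act greedily."""
    def decide(values: List[float]) -> int:
        if rng.random() < epsilon:
            return rng.randrange(len(values))
        return argmax(values)
    return decide


def act(qtable: List[List[float]], state: int, decide: DecisionFn,
        deterministic: bool = False) -> Tuple[int, Dict[str, float]]:
    """Return (action, info): greedy if deterministic, else delegate to the decision function."""
    values = qtable[state]
    action = argmax(values) if deterministic else decide(values)
    return action, {"q_value": values[action]}


qtable = [[0.0, 1.0, 0.5], [2.0, 0.0, 0.1]]
explore_always = make_epsilon_greedy(epsilon=1.0, rng=random.Random(0))
print(act(qtable, 0, explore_always, deterministic=True))  # (1, {'q_value': 1.0})
```

With deterministic=True the decision function is bypassed entirely, so even an always-exploring epsilon of 1.0 yields the greedy action.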
- train(env: EnvironmentInterface, total_steps: int, schedulers: List[ParameterScheduler] | None = None, logger: Logger | None = None, evaluator: Evaluator | None = None, show_progress: bool = True) -> None
Train the Q-Learning agent in the given environment.
- Parameters:
env – The environment to train in.
total_steps – Total number of training steps to perform.
schedulers – Optional list of parameter schedulers to update during training (e.g., for epsilon decay in epsilon-greedy policies).
logger – Optional logger for recording training metrics. If None, creates a default logger.
evaluator – Optional evaluator for periodic policy evaluation during training.
show_progress – If True, displays a progress bar with training metrics.
- Returns:
None. Updates the policy’s Q-table in place.
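The loop that train() describes can be sketched, under loud assumptions, as plain tabular Q-learning on a toy problem. ChainEnv, linear_epsilon, alpha, gamma and the list-based Q-table are hypothetical stand-ins for EnvironmentInterface, a ParameterScheduler, the QLearningConfig hyperparameters and the policy's tensor; the logger, evaluator and show_progress options are omitted.

```python
import random
from typing import List, Tuple


class ChainEnv:
    """Hypothetical toy environment: states 0..n-1, action 0 = left, 1 = right.
    Reaching state n-1 yields reward 1.0 and ends the episode."""

    def __init__(self, n: int = 5) -> None:
        self.n = n
        self.state = 0

    def reset(self) -> int:
        self.state = 0
        return self.state

    def step(self, action: int) -> Tuple[int, float, bool]:
        move = 1 if action == 1 else -1
        self.state = min(max(self.state + move, 0), self.n - 1)  # walls at both ends
        done = self.state == self.n - 1
        return self.state, (1.0 if done else 0.0), done


def linear_epsilon(step: int, total_steps: int, start: float = 1.0, end: float = 0.05) -> float:
    """Hypothetical scheduler: epsilon decays linearly from start to end over training."""
    frac = min(step / max(total_steps - 1, 1), 1.0)
    return start + frac * (end - start)


def train(env: ChainEnv, total_steps: int, alpha: float = 0.5, gamma: float = 0.9,
          seed: int = 0) -> List[List[float]]:
    """Tabular Q-learning with an epsilon-greedy behaviour policy."""
    rng = random.Random(seed)
    qtable = [[0.0, 0.0] for _ in range(env.n)]
    state = env.reset()
    for step in range(total_steps):
        epsilon = linear_epsilon(step, total_steps)  # the "scheduler" update
        row = qtable[state]
        if rng.random() < epsilon:
            action = rng.randrange(2)  # explore
        else:
            action = max(range(2), key=lambda a: row[a])  # exploit (ties -> action 0)
        next_state, reward, done = env.step(action)
        # Bootstrap from max_a' Q(s', a') unless the episode has just ended.
        target = reward if done else reward + gamma * max(qtable[next_state])
        row[action] += alpha * (target - row[action])  # in-place update of Q(s, a)
        state = env.reset() if done else next_state
    return qtable


q = train(ChainEnv(n=5), total_steps=2000)
greedy_actions = [max(range(2), key=lambda a: row[a]) for row in q[:-1]]
print(greedy_actions)  # greedy action in each non-terminal state (1 = right)
```

The episode is reset as soon as the terminal state is reached, so the terminal row of the table is never updated and stays at zero, consistent with dropping the bootstrap term at episode end.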
- class prt_rl.exact.qlearning.QLearningPolicy(qtable: Tensor, decision_function: DecisionFunction)
Q-Learning Policy implementation using a tabular Q-table.
This policy stores state-action values in a table and uses a decision function to select actions. It supports both stochastic action selection (during training) and deterministic action selection (during evaluation).
- Parameters:
qtable – A 2D tensor of shape (num_states, num_actions) containing Q-values for each state-action pair.
decision_function – A function that takes action values and returns an action (e.g., epsilon-greedy, softmax).
- Raises:
ValueError – If qtable is not a 2D tensor.
- Variables:
decision_function – The decision function used for action selection.
table – The Q-table inherited from TabularPolicy.
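The constructor contract (a 2D (num_states, num_actions) table, and a ValueError otherwise) and a softmax decision function, one of the examples named above, can be sketched without a tensor library. TabularPolicySketch, make_softmax, softmax_probs and the temperature tau are hypothetical names; the real class checks a Tensor's dimensionality, while this sketch checks nested lists.

```python
import math
import random
from typing import Callable, List, Sequence

DecisionFn = Callable[[List[float]], int]


def softmax_probs(values: List[float], tau: float) -> List[float]:
    """Boltzmann probabilities exp(q / tau) / sum(exp(q / tau)), shifted by the max for stability."""
    m = max(values)
    weights = [math.exp((v - m) / tau) for v in values]
    total = sum(weights)
    return [w / total for w in weights]


def make_softmax(tau: float, rng: random.Random) -> DecisionFn:
    """Hypothetical softmax decision function: sample an action index from softmax_probs."""
    def decide(values: List[float]) -> int:
        probs = softmax_probs(values, tau)
        return rng.choices(range(len(values)), weights=probs, k=1)[0]
    return decide


class TabularPolicySketch:
    """Hypothetical list-based stand-in for QLearningPolicy's constructor contract."""

    def __init__(self, qtable: Sequence[Sequence[float]], decision_function: DecisionFn) -> None:
        # Reject anything that is not a non-empty, rectangular (num_states, num_actions) table.
        if not qtable or not all(isinstance(row, (list, tuple)) for row in qtable):
            raise ValueError("qtable must be 2D: (num_states, num_actions)")
        if len({len(row) for row in qtable}) != 1 or len(qtable[0]) == 0:
            raise ValueError("qtable rows must all have the same, non-zero length")
        self.table = [list(map(float, row)) for row in qtable]
        self.decision_function = decision_function


policy = TabularPolicySketch([[0.0, 1.0], [0.5, 0.5]], make_softmax(tau=0.5, rng=random.Random(1)))
print(softmax_probs(policy.table[1], tau=0.5))  # equal Q-values -> [0.5, 0.5]
try:
    TabularPolicySketch([0.0, 1.0], policy.decision_function)  # 1D input
except ValueError as err:
    print("rejected:", err)
```

Lower tau makes the softmax sharper (closer to greedy); higher tau makes it closer to uniform.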
- act(obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Dict[str, Tensor]]
Select an action based on the observation.
- Parameters:
obs – The current state/observation as a tensor of shape (B, S).
deterministic – If True, selects the action with highest Q-value. If False, uses the decision function for action selection.
- Returns:
A tuple containing
action: The selected action as a tensor of shape (B, A).
info: A dictionary with a ‘q_value’ key containing the Q-value of the selected action.
- Return type:
Tuple[Tensor, Dict[str, Tensor]]
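A batched act() can be sketched on nested lists, assuming each observation row holds a single integer state index (S = 1) and each action row a single index (A = 1). batched_act, greedy and always_first are hypothetical; always_first is a trivial decision function used only so the non-deterministic branch is reproducible.

```python
from typing import Callable, Dict, List, Tuple

DecisionFn = Callable[[List[float]], int]


def greedy(row: List[float]) -> int:
    """Index of the highest Q-value (the first one on ties)."""
    return max(range(len(row)), key=lambda a: row[a])


def batched_act(qtable: List[List[float]], obs: List[List[int]], decide: DecisionFn,
                deterministic: bool = False) -> Tuple[List[List[int]], Dict[str, List[List[float]]]]:
    """Hypothetical sketch of QLearningPolicy.act on a batch.

    obs has shape (B, 1): each row holds one integer state index.
    Returns actions of shape (B, 1) and info["q_value"] of shape (B, 1).
    """
    actions: List[List[int]] = []
    q_values: List[List[float]] = []
    for (state,) in obs:
        row = qtable[state]
        action = greedy(row) if deterministic else decide(row)
        actions.append([action])
        q_values.append([row[action]])  # Q-value of the action actually chosen
    return actions, {"q_value": q_values}


def always_first(row: List[float]) -> int:
    """Hypothetical trivial decision function: always pick action 0."""
    return 0


qtable = [[0.0, 1.0, 0.5], [2.0, 0.0, 0.1], [0.3, 0.3, 0.9]]
obs = [[2], [0], [1]]
print(batched_act(qtable, obs, always_first, deterministic=True))
# ([[2], [1], [0]], {'q_value': [[0.9], [1.0], [2.0]]})
print(batched_act(qtable, obs, always_first))
# ([[0], [0], [0]], {'q_value': [[0.3], [0.0], [2.0]]})
```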
- get_action_values(obs: Tensor) -> Tensor
Get all action values (Q-values) for a given state.
- Parameters:
obs – The state/observation as a tensor.
- Returns:
A tensor of shape (action_dim,) containing the Q-values of all possible actions in the given state.
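get_action_values() amounts to reading one row of the table. The sketch below assumes an integer state index and returns a copy of the row; the function mirrors the method's name, but the list-based implementation and its bounds check are hypothetical.

```python
from typing import List


def get_action_values(qtable: List[List[float]], state: int) -> List[float]:
    """Hypothetical sketch: Q(s, a) for every action a, i.e. row `state`, shape (action_dim,)."""
    if not 0 <= state < len(qtable):
        raise IndexError(f"state {state} outside 0..{len(qtable) - 1}")
    return list(qtable[state])  # copy, so callers cannot mutate the table by accident


qtable = [[0.0, 1.0, 0.5], [2.0, 0.0, 0.1]]
values = get_action_values(qtable, 1)
print(values, max(values))  # [2.0, 0.0, 0.1] 2.0  (the row and its greedy value)
```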
- get_qvalue(obs: Tensor, action: Tensor) -> Tensor
Get the Q-value for a specific state-action pair.
- Parameters:
obs – The state/observation as a tensor.
action – The action as a tensor.
- Returns:
The Q-value for the given state-action pair, as a tensor of shape (1,).
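get_qvalue() is a single table lookup, and together with the row maximum it gives the one-step Q-learning TD error r + γ max_a' Q(s', a') - Q(s, a). The sketch below is a hypothetical list-based illustration; td_error and its arguments are not part of the documented API.

```python
from typing import List


def get_qvalue(qtable: List[List[float]], state: int, action: int) -> List[float]:
    """Hypothetical sketch: Q(s, a) wrapped in a one-element list to mirror the (1,) shape."""
    return [qtable[state][action]]


def td_error(qtable: List[List[float]], state: int, action: int, reward: float,
             next_state: int, gamma: float, done: bool) -> float:
    """One-step Q-learning TD error; the bootstrap term is dropped at episode end."""
    bootstrap = 0.0 if done else gamma * max(qtable[next_state])
    return reward + bootstrap - get_qvalue(qtable, state, action)[0]


qtable = [[0.0, 1.0], [2.0, 0.5]]
print(get_qvalue(qtable, 1, 0))  # [2.0]
print(td_error(qtable, 0, 1, reward=1.0, next_state=1, gamma=0.5, done=False))  # 1.0 + 1.0 - 1.0 = 1.0
```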