monte_carlo

Classes

  • MonteCarloAgent – On-policy first-visit Monte Carlo algorithm.

  • MonteCarloConfig – Configuration for MonteCarloAgent.

  • MonteCarloPolicy – Monte Carlo policy implementation using a tabular Q-table.

class prt_rl.exact.monte_carlo.MonteCarloAgent(policy: MonteCarloPolicy, config: MonteCarloConfig, *, device: str = 'cpu')

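On-policy first-visit Monte Carlo algorithm. Each state-action pair visited in an episode is updated once, using the return that follows its first visit in that episode:

\[
\begin{aligned}
Q(S_t,A_t) &\leftarrow Q(S_t,A_t) + \frac{1}{N}\Big[\sum_{k=0}^{\infty}\gamma^k R_{t+k+1} - Q(S_t,A_t)\Big] \\
q_s &\leftarrow q_s + \frac{1}{n}\big[G_t - q_s\big]
\end{aligned}
\]

where \(G_t = \sum_{k=0}^{\infty}\gamma^k R_{t+k+1}\) is the discounted return following time step t, γ is the discount factor, and N (written n in the shorthand second line) is the number of first visits to (S_t, A_t) so far.
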
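The second line of the update is the incremental form of a running average: with n counting the first visits to a state-action pair, applying it after every visit leaves q_s equal to the mean of the returns observed so far. A minimal plain-Python check of that equivalence (the returns are illustrative numbers, not output from the library):

```python
def incremental_mean(returns):
    """Apply q <- q + (1/n) * (G - q) to each return in turn, starting from q = 0."""
    q = 0.0
    for n, g in enumerate(returns, start=1):
        q += (g - q) / n  # step size 1/n shrinks as visits accumulate
    return q


returns = [1.0, 0.0, 2.0, 1.0]  # illustrative first-visit returns for one (s, a)
q_incremental = incremental_mean(returns)
q_average = sum(returns) / len(returns)
print(q_incremental, q_average)  # 1.0 1.0
```

Because the step size is exactly 1/N, no separate list of past returns has to be stored; one counter per state-action pair is enough.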
act(obs, deterministic=False)

Select an action for the given observation.

Parameters:
  • obs – The current state/observation.

  • deterministic – If True, selects the action with highest Q-value. If False, uses the policy’s decision function for action selection.

Returns:

A tuple containing the selected action and additional information.

train(env: EnvironmentInterface, total_steps: int, schedulers: List[ParameterScheduler] | None = None, logger: Logger | None = None, evaluator: Evaluator | None = None, show_progress: bool = True) → None

Train the Monte Carlo agent in the given environment.

Parameters:
  • env – The environment to train in.

  • total_steps – Total number of training steps to perform.

  • schedulers – Optional list of parameter schedulers to update during training (e.g., for epsilon decay in epsilon-greedy policies).

  • logger – Optional logger for recording training metrics. If None, creates a default logger.

  • evaluator – Optional evaluator for periodic policy evaluation during training.

  • show_progress – If True, displays a progress bar with training metrics.

Returns:

None. Updates the policy’s Q-table in place.
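A minimal sketch of the per-episode first-visit update that a training loop like this applies, assuming a finished episode is recorded as a list of (state, action, reward) tuples and Q-values are kept in a dict. This is an illustration of the technique, not the library's code: `first_visit_update` is a hypothetical helper, and the real method stores values in a tensor Q-table and also handles step counting, schedulers, logging, and evaluation.

```python
from collections import defaultdict


def first_visit_update(episode, q, counts, gamma=0.99):
    """Update q in place from one finished episode (hypothetical helper).

    episode: list of (state, action, reward) tuples; reward at index t is R_{t+1}.
    q, counts: dicts keyed by (state, action).
    """
    # Time step at which each (state, action) pair first occurs in the episode.
    first_visit = {}
    for t, (s, a, _) in enumerate(episode):
        first_visit.setdefault((s, a), t)

    g = 0.0
    # Accumulate the discounted return backwards: G_t = R_{t+1} + gamma * G_{t+1}.
    for t in reversed(range(len(episode))):
        s, a, r = episode[t]
        g = r + gamma * g
        if first_visit[(s, a)] == t:  # later visits to the same pair are ignored
            counts[(s, a)] += 1
            q[(s, a)] += (g - q[(s, a)]) / counts[(s, a)]


q = defaultdict(float)
counts = defaultdict(int)
# Toy episode: (0, 1) occurs at t=0 and t=2; only t=0 is its first visit.
episode = [(0, 1, 0.0), (1, 0, 0.0), (0, 1, 1.0)]
first_visit_update(episode, q, counts, gamma=0.5)
print(dict(q), dict(counts))
```

Walking the episode backwards lets one pass compute every G_t, while the precomputed first-visit index keeps the repeated pair (0, 1) from being updated twice.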

class prt_rl.exact.monte_carlo.MonteCarloConfig(gamma: float = 0.99)

Configuration for MonteCarloAgent.

Parameters:
  • gamma – Discount factor γ used when computing returns.
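To show what gamma controls, here is a small sketch of the discounted return from the update rule, assuming gamma is the γ in that equation. `discounted_return` is a hypothetical helper, not part of the library:

```python
def discounted_return(rewards, gamma):
    """G_0 = sum_k gamma**k * R_{k+1}, where rewards[k] holds R_{k+1}."""
    return sum(gamma ** k * r for k, r in enumerate(rewards))


rewards = [1.0, 1.0, 1.0, 1.0]
print(discounted_return(rewards, gamma=1.0))   # 4.0
print(discounted_return(rewards, gamma=0.5))   # 1.875
print(discounted_return(rewards, gamma=0.99))  # close to 3.940399
```

With the default 0.99, rewards a few steps away are discounted only slightly; smaller values make the agent weight near-term rewards more heavily.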

class prt_rl.exact.monte_carlo.MonteCarloPolicy(qtable: Tensor, decision_function: DecisionFunction)

Monte Carlo Policy implementation using a tabular Q-table.

This policy stores state-action values in a table and uses a decision function to select actions. It supports both stochastic action selection (during training) and deterministic action selection (during evaluation).

Parameters:
  • qtable – A 2D tensor of shape (num_states, num_actions) containing Q-values for each state-action pair.

  • decision_function – A function that takes action values and returns an action (e.g., epsilon-greedy, softmax).

Raises:

ValueError – If qtable is not a 2D tensor.

Variables:
  • decision_function – The decision function used for action selection.

  • table – The Q-table inherited from TabularPolicy.
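The decision function is described only as something that maps action values to an action. The sketch below gives two plain-Python stand-ins for the examples named above (epsilon-greedy and softmax). Both are hypothetical: they operate on one row of Q-values as a list, and the library's actual DecisionFunction interface may differ.

```python
import math
import random


def epsilon_greedy(action_values, epsilon=0.1, rng=random):
    """With probability epsilon pick a uniform random action, otherwise the argmax."""
    if rng.random() < epsilon:
        return rng.randrange(len(action_values))
    return max(range(len(action_values)), key=lambda a: action_values[a])


def softmax(action_values, tau=1.0, rng=random):
    """Sample an action with probability proportional to exp(Q / tau)."""
    m = max(action_values)  # subtract the max so exp() cannot overflow
    weights = [math.exp((q - m) / tau) for q in action_values]
    return rng.choices(range(len(action_values)), weights=weights, k=1)[0]


q_row = [0.1, 0.7, 0.2]  # Q-values of one state (illustrative)
rng = random.Random(0)
print(epsilon_greedy(q_row, epsilon=0.0, rng=rng))  # 1: purely greedy when epsilon == 0
print(softmax(q_row, tau=0.5, rng=rng))
```

Decaying epsilon (or tau) over training, for example with a parameter scheduler, moves either function toward the greedy choice that deterministic selection makes directly.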

act(obs: Tensor, deterministic: bool = False) → Tuple[Tensor, Dict[str, Tensor]]

Select an action based on the observation.

Parameters:
  • obs – The current state/observation as a tensor of shape (B, S).

  • deterministic – If True, selects the action with highest Q-value. If False, uses the decision function for action selection.

Returns:

A tuple containing:

  • action: The selected action as a tensor of shape (B, A).

  • info: A dictionary with a ‘q_value’ key containing the Q-value of the selected action.

Return type:

Tuple[Tensor, Dict[str, Tensor]]
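A minimal sketch of the selection logic described for act, assuming a nested-list Q-table and a batch of integer state indices. `act` here is a hypothetical stand-in, not the library method, which works on tensors:

```python
def act(qtable, states, decision_function, deterministic=False):
    """Hypothetical stand-in for MonteCarloPolicy.act on a list-of-lists Q-table.

    qtable: num_states x num_actions nested list of Q-values.
    states: batch of integer state indices.
    Returns (actions, info) where info["q_value"] holds Q(s, a) for each chosen action.
    """
    actions, q_values = [], []
    for s in states:
        row = qtable[s]
        if deterministic:
            a = max(range(len(row)), key=lambda i: row[i])  # highest Q-value
        else:
            a = decision_function(row)  # e.g. epsilon-greedy or softmax
        actions.append(a)
        q_values.append(row[a])
    return actions, {"q_value": q_values}


qtable = [[0.0, 1.0], [2.0, 0.5]]
actions, info = act(qtable, [0, 1], decision_function=lambda row: 0, deterministic=True)
print(actions, info)  # [1, 0] {'q_value': [1.0, 2.0]}
```

With deterministic=True the decision function is bypassed entirely, which is why the lambda that always returns 0 has no effect on the printed result.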

get_action_values(obs: Tensor) → Tensor

Get all action values (Q-values) for a given state.

Parameters:

obs – The state/observation as a tensor.

Returns:

A tensor of shape (action_dim,) containing Q-values for all possible actions in the given state.

get_qvalue(obs: Tensor, action: Tensor) → Tensor

Get the Q-value for a specific state-action pair.

Parameters:
  • obs – The state/observation as a tensor.

  • action – The action as a tensor.

Returns:

The Q-value for the given state-action pair, as a tensor of shape (1,).

set_qvalue(obs: Tensor, action: Tensor, qval: Tensor)

Update the Q-value for a specific state-action pair.

Parameters:
  • obs – The state/observation as a tensor of shape (1,).

  • action – The action as a tensor of shape (1,).

  • qval – The new Q-value to set for the state-action pair, as a tensor of shape (1,).
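The three accessors amount to reading a row, reading an entry, and writing an entry of the (num_states, num_actions) table. A minimal sketch with a nested list, assuming states and actions are plain integer indices; the helper names mirror the methods above, but these are hypothetical stand-ins, not the library's tensor implementation:

```python
# Q-table with 3 states and 2 actions; a comprehension avoids aliasing the rows.
qtable = [[0.0, 0.0] for _ in range(3)]


def get_action_values(qtable, state):
    """Row of Q-values for one state, of length num_actions."""
    return qtable[state]


def get_qvalue(qtable, state, action):
    """Single entry Q(state, action)."""
    return qtable[state][action]


def set_qvalue(qtable, state, action, qval):
    """Overwrite Q(state, action) in place."""
    qtable[state][action] = qval


set_qvalue(qtable, 2, 1, 0.75)
print(get_qvalue(qtable, 2, 1))      # 0.75
print(get_action_values(qtable, 2))  # [0.0, 0.75]
```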

snapshot() → Dict[str, Any]

Return a serializable snapshot.

Subclasses can extend this by calling snap = super().snapshot() and then adding their own entries with snap[…] = ….
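The extension pattern can be sketched with two hypothetical classes. `BasePolicy` and `EpisodeCountingPolicy` are illustrations only, not classes from prt_rl, and they store the table as a nested list so the snapshot stays a plain, serializable dict:

```python
from typing import Any, Dict, List


class BasePolicy:
    """Hypothetical base class whose snapshot() returns a plain dict."""

    def __init__(self, table: List[List[float]]):
        self.table = table

    def snapshot(self) -> Dict[str, Any]:
        # Copy the rows so later edits to self.table don't change the snapshot.
        return {"table": [row[:] for row in self.table]}


class EpisodeCountingPolicy(BasePolicy):
    """Hypothetical subclass that adds its own field to the snapshot."""

    def __init__(self, table: List[List[float]], episodes: int):
        super().__init__(table)
        self.episodes = episodes

    def snapshot(self) -> Dict[str, Any]:
        snap = super().snapshot()  # start from the parent's fields
        snap["episodes"] = self.episodes  # then add the subclass's own state
        return snap


policy = EpisodeCountingPolicy([[0.0, 1.0]], episodes=5)
print(policy.snapshot())  # {'table': [[0.0, 1.0]], 'episodes': 5}
```

Calling super().snapshot() first means each subclass only adds what it owns, so fields introduced higher up the hierarchy are never silently dropped.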