sarsa

Classes

SARSAAgent

SARSA (State-Action-Reward-State-Action) on-policy temporal difference control algorithm.

SARSAConfig

Configuration parameters for the SARSA algorithm.

SARSAPolicy

SARSA Policy implementation using a tabular Q-table.

class prt_rl.exact.sarsa.SARSAAgent(policy: SARSAPolicy, config: SARSAConfig, *, device: str = 'cpu')

SARSA (State-Action-Reward-State-Action) on-policy temporal difference control algorithm.

SARSA is an on-policy reinforcement learning algorithm that learns action values Q(s,a) by bootstrapping from the action a’ that the current policy actually selects in the next state s’. The update rule is Q(s,a) ← Q(s,a) + α[r + γQ(s’,a’) - Q(s,a)], where α is the learning rate, γ is the discount factor, and (s, a, r, s’, a’) is the SARSA tuple that gives the algorithm its name.
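The update rule can be sketched in plain Python with a nested list as the Q-table. This is a minimal illustration, not prt_rl's implementation: the function name `sarsa_update` and the `done` flag are assumptions, and treating a terminal transition as having no future value is the usual convention rather than something these docs state.

```python
from typing import List


def sarsa_update(
    q: List[List[float]],
    s: int,
    a: int,
    r: float,
    s_next: int,
    a_next: int,
    alpha: float = 0.1,
    gamma: float = 0.99,
    done: bool = False,
) -> float:
    """Apply one tabular SARSA update to q in place and return the TD error."""
    # Bootstrap from Q(s', a'), where a' is the action the policy actually chose.
    # Assumption: a terminal transition (done=True) contributes no future value.
    bootstrap = 0.0 if done else gamma * q[s_next][a_next]
    td_error = r + bootstrap - q[s][a]
    q[s][a] += alpha * td_error
    return td_error


# Example: 2 states x 2 actions, with Q(1, 1) = 2.0 already learned.
q = [[0.0, 0.0], [0.0, 2.0]]
delta = sarsa_update(q, s=0, a=1, r=1.0, s_next=1, a_next=1)
# TD error = 1.0 + 0.99 * 2.0 - 0.0 ≈ 2.98, so Q(0, 1) ≈ 0.1 * 2.98 ≈ 0.298
print(delta, q[0][1])
```

Because a’ comes from the same policy that is being improved (including its exploratory choices), SARSA learns the value of that policy rather than of the greedy one.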

Parameters:
  • policy – The SARSA policy containing the Q-table and decision function.

  • config – Configuration parameters including learning rate (alpha) and discount factor (gamma).

  • device – The device to use for computations (“cpu” or “cuda”). Default is “cpu”.

Variables:
  • policy – The SARSA policy instance.

  • config – The configuration parameters.

act(obs, deterministic=False)

Select an action for the given observation.

Parameters:
  • obs – The current state/observation.

  • deterministic – If True, selects the action with the highest Q-value. If False, uses the policy’s decision function for action selection.

Returns:

A tuple containing the selected action and additional information.

train(env: EnvironmentInterface, total_steps: int, schedulers: List[ParameterScheduler] | None = None, logger: Logger | None = None, evaluator: Evaluator | None = None, show_progress: bool = True) → None

Train the SARSA agent in the given environment.

The agent performs on-policy learning by collecting experience one step at a time, computing the next action according to the current policy, and updating Q-values using the SARSA update rule: Q(s,a) ← Q(s,a) + α[r + γQ(s’,a’) - Q(s,a)].

Parameters:
  • env – The environment to train in.

  • total_steps – Total number of training steps to perform.

  • schedulers – Optional list of parameter schedulers to update during training (e.g., for epsilon decay in epsilon-greedy policies).

  • logger – Optional logger for recording training metrics. If None, creates a default logger.

  • evaluator – Optional evaluator for periodic policy evaluation during training.

  • show_progress – If True, displays a progress bar with training metrics.

Returns:

None. Updates the policy’s Q-table in place.
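The training procedure described above can be sketched end to end in plain Python on a toy problem. Everything here is a hypothetical illustration, not prt_rl's implementation: the `ChainEnv` corridor, the `epsilon_greedy` and `train_sarsa` helpers, and the linear epsilon decay that stands in for a `ParameterScheduler`. It also uses gamma = 0.9 instead of the 0.99 default so that the toy Q-values separate clearly.

```python
import random
from typing import List, Tuple


class ChainEnv:
    """Hypothetical corridor: states 0..n-1, action 1 moves right, action 0 moves left.

    Reaching the last state pays reward 1.0 and ends the episode.
    """

    def __init__(self, n_states: int = 4) -> None:
        self.n_states = n_states
        self.state = 0

    def reset(self) -> int:
        self.state = 0
        return self.state

    def step(self, action: int) -> Tuple[int, float, bool]:
        move = 1 if action == 1 else -1
        self.state = min(max(self.state + move, 0), self.n_states - 1)
        done = self.state == self.n_states - 1
        return self.state, (1.0 if done else 0.0), done


def epsilon_greedy(q_row: List[float], epsilon: float, rng: random.Random) -> int:
    # Explore uniformly with probability epsilon, otherwise take the first argmax.
    if rng.random() < epsilon:
        return rng.randrange(len(q_row))
    return max(range(len(q_row)), key=q_row.__getitem__)


def train_sarsa(env: ChainEnv, total_steps: int, alpha: float = 0.1,
                gamma: float = 0.9, n_actions: int = 2, seed: int = 0) -> List[List[float]]:
    rng = random.Random(seed)
    q = [[0.0] * n_actions for _ in range(env.n_states)]
    eps_start, eps_end = 1.0, 0.05
    decay_steps = max(1, int(0.8 * total_steps))

    s = env.reset()
    a = epsilon_greedy(q[s], eps_start, rng)
    for step in range(total_steps):
        # Linear epsilon decay, standing in for a ParameterScheduler.
        epsilon = eps_start + min(step / decay_steps, 1.0) * (eps_end - eps_start)

        s_next, r, done = env.step(a)
        # On-policy: pick a' with the same epsilon-greedy policy, then bootstrap
        # from Q(s', a'); a terminal transition has no future value.
        a_next = epsilon_greedy(q[s_next], epsilon, rng)
        target = r if done else r + gamma * q[s_next][a_next]
        q[s][a] += alpha * (target - q[s][a])

        if done:
            s = env.reset()
            a = epsilon_greedy(q[s], epsilon, rng)
        else:
            s, a = s_next, a_next
    return q


q = train_sarsa(ChainEnv(), total_steps=5000)
greedy = [max(range(2), key=q[s].__getitem__) for s in range(3)]
print(greedy)  # greedy action in each non-terminal state
```

The next action a’ is chosen once and then reused as the action actually executed on the following step, which is what makes the update on-policy.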

class prt_rl.exact.sarsa.SARSAConfig(gamma: float = 0.99, alpha: float = 0.1)

Configuration parameters for the SARSA algorithm.

Variables:
  • gamma (float) – Discount factor for future rewards. A value between 0 and 1 that determines how much the agent values future rewards compared to immediate rewards. Default is 0.99.

  • alpha (float) – Learning rate for Q-value updates. Controls how much the Q-values are adjusted based on new experience. Default is 0.1.
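A few lines of arithmetic make the two defaults concrete. This sketch only evaluates standard formulas with the documented default values; the helper names are hypothetical, and 1 / (1 - gamma) is a common rule of thumb for the effective horizon, not a quantity prt_rl computes.

```python
# Documented SARSAConfig defaults.
GAMMA = 0.99
ALPHA = 0.1


def reward_weight(gamma: float, k: int) -> float:
    """Weight of a reward received k steps in the future: gamma ** k."""
    return gamma ** k


def effective_horizon(gamma: float) -> float:
    """Rule of thumb for how far ahead the agent effectively looks: 1 / (1 - gamma)."""
    return 1.0 / (1.0 - gamma)


def remaining_gap(alpha: float, n: int) -> float:
    """Fraction of the distance to a fixed target left after n updates: (1 - alpha) ** n.

    Follows from Q_{n+1} = Q_n + alpha * (T - Q_n), i.e. T - Q_{n+1} = (1 - alpha) * (T - Q_n).
    """
    return (1.0 - alpha) ** n


print(reward_weight(GAMMA, 100))  # ≈ 0.366
print(reward_weight(0.9, 100))    # ≈ 2.7e-05
print(effective_horizon(GAMMA))   # ≈ 100
print(remaining_gap(ALPHA, 10))   # ≈ 0.349
```

With gamma = 0.99 a reward 100 steps away still carries about a third of its weight, while with alpha = 0.1 roughly a third of the gap to a fixed target remains after ten updates.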

class prt_rl.exact.sarsa.SARSAPolicy(qtable: Tensor, decision_function: DecisionFunction)

SARSA Policy implementation using a tabular Q-table.

This policy stores state-action values in a table and uses a decision function to select actions. It supports both stochastic action selection (during training) and deterministic action selection (during evaluation).

Parameters:
  • qtable – A 2D tensor of shape (num_states, num_actions) containing Q-values for each state-action pair.

  • decision_function – A function that takes action values and returns an action (e.g., epsilon-greedy, softmax).

Raises:

ValueError – If qtable is not a 2D tensor.

Variables:
  • decision_function – The decision function used for action selection.

  • table – The Q-table inherited from TabularPolicy.
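The two requirements above, a 2D Q-table and a decision function that maps action values to actions, can be sketched in PyTorch. The helpers `check_qtable`, `epsilon_greedy`, and `softmax_sample` are hypothetical; prt_rl's `DecisionFunction` interface may differ. The sketch assumes action values arrive as a (B, num_actions) tensor and actions are returned as (B, 1) integer indices.

```python
import torch


def check_qtable(qtable: torch.Tensor) -> torch.Tensor:
    # Mirrors the documented check: the Q-table must be (num_states, num_actions).
    if qtable.dim() != 2:
        raise ValueError(
            f"qtable must be 2D (num_states, num_actions), got shape {tuple(qtable.shape)}"
        )
    return qtable


def epsilon_greedy(action_values: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Take argmax actions, replacing each with a uniform random one w.p. epsilon."""
    greedy = action_values.argmax(dim=-1, keepdim=True)  # (B, 1)
    random_actions = torch.randint(
        action_values.shape[-1], greedy.shape, device=action_values.device
    )
    explore = torch.rand(greedy.shape, device=action_values.device) < epsilon
    return torch.where(explore, random_actions, greedy)


def softmax_sample(action_values: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Sample actions with probability proportional to exp(Q / temperature)."""
    probs = torch.softmax(action_values / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # (B, 1)


q = check_qtable(torch.tensor([[0.1, 0.9, 0.3], [2.0, -1.0, 0.0]]))
print(epsilon_greedy(q, epsilon=0.0).view(-1).tolist())  # [1, 0]
```

Setting epsilon to 0 (or the temperature close to 0) recovers greedy selection, which is why a single decision function can serve both training and evaluation.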

act(obs: Tensor, deterministic: bool = False) → Tuple[Tensor, Dict[str, Tensor]]

Select an action based on the observation.

Parameters:
  • obs – The current state/observation as a tensor of shape (B, S).

  • deterministic – If True, selects the action with the highest Q-value. If False, uses the decision function for action selection.

Returns:

A tuple containing

  • action – The selected action as a tensor of shape (B, A).

  • info – A dictionary whose ‘q_value’ key holds the Q-value of the selected action.

Return type:

Tuple[Tensor, Dict[str, Tensor]]
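The deterministic/stochastic split and the ‘q_value’ entry can be sketched as a standalone function. `tabular_act` and `always_first` are hypothetical, and the sketch assumes each observation row holds a single integer state index (shape (B, 1)) and that actions are returned as (B, 1) indices; prt_rl may handle shapes differently.

```python
from typing import Callable, Dict, Tuple

import torch

# Hypothetical decision-function type: (B, num_actions) values -> (B, 1) actions.
DecisionFn = Callable[[torch.Tensor], torch.Tensor]


def tabular_act(qtable: torch.Tensor, obs: torch.Tensor, decision_fn: DecisionFn,
                deterministic: bool = False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    # Assumption: obs holds one integer state index per row, e.g. shape (B, 1).
    action_values = qtable[obs.long().reshape(-1)]            # (B, num_actions)
    if deterministic:
        action = action_values.argmax(dim=-1, keepdim=True)   # greedy, (B, 1)
    else:
        action = decision_fn(action_values)                   # e.g. epsilon-greedy
    q_value = action_values.gather(-1, action)                # Q(s, a) per row, (B, 1)
    return action, {"q_value": q_value}


def always_first(action_values: torch.Tensor) -> torch.Tensor:
    # Toy decision function used only for the demo: always picks action 0.
    return torch.zeros(action_values.shape[0], 1, dtype=torch.long)


qtable = torch.tensor([[0.0, 1.0], [5.0, 2.0], [0.5, 0.5]])
obs = torch.tensor([[0], [1]])
action, info = tabular_act(qtable, obs, always_first, deterministic=True)
print(action.view(-1).tolist(), info["q_value"].view(-1).tolist())  # [1, 0] [1.0, 5.0]
```

`gather` reads the value of whichever action was chosen, so the same code reports Q(s,a) whether the action came from the argmax or from the decision function.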

get_action_values(obs: Tensor) → Tensor

Get all action values (Q-values) for a given state.

Parameters:

obs – The state/observation as a tensor.

Returns:

A tensor containing Q-values for all possible actions in the given state.
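Looking up all action values is a row lookup in the (num_states, num_actions) table. A shape detail worth keeping in mind, shown in this hypothetical sketch (`action_values_for` is not a prt_rl function, and integer state indices are assumed): indexing with a (B, 1) index tensor keeps the extra dimension, so flattening the indices first yields the expected (B, num_actions) result.

```python
import torch

qtable = torch.arange(12, dtype=torch.float32).view(4, 3)  # 4 states x 3 actions


def action_values_for(qtable: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    # Assumption: obs contains integer state indices, shape (B, 1) or (B,).
    return qtable[obs.long().reshape(-1)]


obs = torch.tensor([[2], [0]])         # batch of state indices, shape (B, 1)
print(qtable[obs].shape)               # torch.Size([2, 1, 3]): extra dim is kept
print(action_values_for(qtable, obs))  # rows 2 and 0: [[6, 7, 8], [0, 1, 2]]
```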

get_qvalue(obs: Tensor, action: Tensor) → Tensor

Get the Q-value for a specific state-action pair.

Parameters:
  • obs – The state/observation as a tensor.

  • action – The action as a tensor.

Returns:

The Q-value for the given state-action pair.

set_qvalue(state: Tensor, action: Tensor, qval: Tensor)

Update the Q-value for a specific state-action pair.

Parameters:
  • state – The state as a tensor.

  • action – The action as a tensor.

  • qval – The new Q-value to set for the state-action pair.
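Reading and writing single entries is enough to express one SARSA step as read, compute target, write. The helpers `get_q` and `set_q` below are hypothetical stand-ins for `get_qvalue` and `set_qvalue`, assuming integer state and action indices; the note on duplicate index pairs reflects PyTorch's documented behavior for non-accumulating indexed assignment.

```python
import torch


def get_q(qtable: torch.Tensor, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    # Assumption: obs and action hold integer indices, one per batch row.
    return qtable[obs.long().reshape(-1), action.long().reshape(-1)]


def set_q(qtable: torch.Tensor, state: torch.Tensor, action: torch.Tensor,
          qval: torch.Tensor) -> None:
    # In-place write. If one (state, action) pair appears twice in a batch,
    # PyTorch does not define which of the values ends up stored.
    qtable[state.long().reshape(-1), action.long().reshape(-1)] = qval.reshape(-1).to(qtable.dtype)


qtable = torch.zeros(3, 2)
qtable[1, 1] = 2.0                                   # pretend Q(1, 1) is already learned
s, a = torch.tensor([[0]]), torch.tensor([[1]])
s_next, a_next = torch.tensor([[1]]), torch.tensor([[1]])
r, gamma, alpha = 1.0, 0.99, 0.1

q_sa = get_q(qtable, s, a)
target = r + gamma * get_q(qtable, s_next, a_next)
set_q(qtable, s, a, q_sa + alpha * (target - q_sa))  # one SARSA update
print(get_q(qtable, s, a))                           # ≈ 0.298
```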

snapshot() → Dict[str, Any]

Return a serializable snapshot.

Subclasses can extend this by snap = super().snapshot(); snap[…] = ….
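The extension pattern named above can be illustrated with two stand-in classes. Both class names and the extra `epsilon` field are hypothetical; the sketch only shows the call-the-parent-then-add-keys idiom and keeps the snapshot as plain, copyable data so it stays serializable.

```python
import json
from typing import Any, Dict, List


class TabularPolicySketch:
    """Hypothetical stand-in for a base class that provides snapshot()."""

    def __init__(self, table: List[List[float]]) -> None:
        self.table = table

    def snapshot(self) -> Dict[str, Any]:
        # Copy plain Python data so the snapshot stays serializable and detached.
        return {"table": [row[:] for row in self.table]}


class EpsilonPolicySketch(TabularPolicySketch):
    """Hypothetical subclass that adds one field to the parent's snapshot."""

    def __init__(self, table: List[List[float]], epsilon: float) -> None:
        super().__init__(table)
        self.epsilon = epsilon

    def snapshot(self) -> Dict[str, Any]:
        snap = super().snapshot()       # keep everything the parent records
        snap["epsilon"] = self.epsilon  # then add subclass-specific state
        return snap


policy = EpsilonPolicySketch([[0.0, 1.0], [0.5, 0.2]], epsilon=0.1)
print(json.dumps(policy.snapshot()))
```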