sarsa
Classes
SARSAAgent – SARSA (State-Action-Reward-State-Action) on-policy temporal-difference control algorithm.
SARSAConfig – Configuration parameters for the SARSA algorithm.
SARSAPolicy – SARSA policy implementation using a tabular Q-table.
- class prt_rl.exact.sarsa.SARSAAgent(policy: SARSAPolicy, config: SARSAConfig, *, device: str = 'cpu')
SARSA (State-Action-Reward-State-Action) on-policy temporal difference control algorithm.
SARSA is an on-policy reinforcement learning algorithm that learns action values Q(s,a) by updating them based on the next action a′ actually taken by the current policy. The update rule is: Q(s,a) ← Q(s,a) + α[r + γQ(s′,a′) - Q(s,a)], where (s,a,r,s′,a′) is the SARSA tuple, α is the learning rate, and γ is the discount factor.
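As a plain-Python sketch of the update rule above (not the prt_rl implementation), the hypothetical function `sarsa_update` applies one SARSA step to a Q-table stored as a list of lists. The `done` flag, which drops the bootstrap term when s′ is terminal, is a common convention added here as an assumption; the rule above does not mention terminal states.

```python
def sarsa_update(q, s, a, r, s_next, a_next, alpha=0.1, gamma=0.99, done=False):
    """Apply one SARSA update to q[s][a] in place and return the TD error.

    q is a list of lists indexed as q[state][action].
    """
    # Assumption: a terminal s' contributes no future value.
    bootstrap = 0.0 if done else gamma * q[s_next][a_next]
    td_error = r + bootstrap - q[s][a]
    q[s][a] += alpha * td_error
    return td_error


# Worked example: Q(0,1) = 0.5, Q(1,0) = 2.0, r = 1, alpha = 0.1, gamma = 0.9.
q = [[0.0, 0.5], [2.0, 0.0]]
td = sarsa_update(q, s=0, a=1, r=1.0, s_next=1, a_next=0, alpha=0.1, gamma=0.9)
# target = 1 + 0.9 * 2.0 = 2.8, TD error = 2.8 - 0.5 = 2.3,
# new Q(0,1) = 0.5 + 0.1 * 2.3 = 0.73
```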
- Parameters:
policy – The SARSA policy containing the Q-table and decision function.
config – Configuration parameters including learning rate (alpha) and discount factor (gamma).
device – The device to use for computations (“cpu” or “cuda”). Default is “cpu”.
- Variables:
policy – The SARSA policy instance.
config – The configuration parameters.
- act(obs, deterministic=False)
Select an action for the given observation.
- Parameters:
obs – The current state/observation.
deterministic – If True, selects the action with highest Q-value. If False, uses the policy’s decision function for action selection.
- Returns:
A tuple containing the selected action and additional information.
- train(env: EnvironmentInterface, total_steps: int, schedulers: List[ParameterScheduler] | None = None, logger: Logger | None = None, evaluator: Evaluator | None = None, show_progress: bool = True) → None
Train the SARSA agent in the given environment.
The agent performs on-policy learning by collecting experience one step at a time, selecting the next action a′ according to the current policy, and updating Q-values using the SARSA update rule: Q(s,a) ← Q(s,a) + α[r + γQ(s′,a′) - Q(s,a)].
- Parameters:
env – The environment to train in.
total_steps – Total number of training steps to perform.
schedulers – Optional list of parameter schedulers to update during training (e.g., for epsilon decay in epsilon-greedy policies).
logger – Optional logger for recording training metrics. If None, creates a default logger.
evaluator – Optional evaluator for periodic policy evaluation during training.
show_progress – If True, displays a progress bar with training metrics.
- Returns:
None. Updates the policy’s Q-table in place.
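The training loop described above can be sketched in plain Python on a hypothetical five-state chain environment. `ChainEnv`, `epsilon_greedy`, and `train_sarsa` are illustrative names, not part of prt_rl; the linear epsilon decay stands in for a `ParameterScheduler`, whose actual interface is not shown here. The on-policy detail to notice is that a′ is drawn from the same ε-greedy policy and is the action actually executed on the following step.

```python
import random


class ChainEnv:
    """Hypothetical 1-D chain: states 0..n-1, start at 0, goal at n-1.

    Action 0 moves left, action 1 moves right. Reaching the goal ends the
    episode with reward 1.0; every other step gives reward 0.0.
    """

    def __init__(self, n_states=5):
        self.n_states = n_states
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def step(self, action):
        if action == 1:
            self.state = min(self.state + 1, self.n_states - 1)
        else:
            self.state = max(self.state - 1, 0)
        done = self.state == self.n_states - 1
        return self.state, (1.0 if done else 0.0), done


def epsilon_greedy(q_row, epsilon, rng):
    # Explore with probability epsilon, otherwise act greedily (random tie-break).
    if rng.random() < epsilon:
        return rng.randrange(len(q_row))
    best = max(q_row)
    return rng.choice([a for a, v in enumerate(q_row) if v == best])


def train_sarsa(env, total_steps, num_actions=2, alpha=0.1, gamma=0.9,
                eps_start=1.0, eps_end=0.05, seed=0):
    rng = random.Random(seed)
    q = [[0.0] * num_actions for _ in range(env.n_states)]
    s = env.reset()
    a = epsilon_greedy(q[s], eps_start, rng)
    for step in range(total_steps):
        # Linear epsilon decay, standing in for a parameter scheduler.
        epsilon = eps_start + (eps_end - eps_start) * step / max(total_steps - 1, 1)
        s_next, r, done = env.step(a)
        if done:
            target = r  # terminal state: no bootstrap term
        else:
            # On-policy: a' comes from the same epsilon-greedy policy and is
            # the action actually executed on the next step.
            a_next = epsilon_greedy(q[s_next], epsilon, rng)
            target = r + gamma * q[s_next][a_next]
        q[s][a] += alpha * (target - q[s][a])
        if done:
            s = env.reset()
            a = epsilon_greedy(q[s], epsilon, rng)
        else:
            s, a = s_next, a_next
    return q


q = train_sarsa(ChainEnv(5), total_steps=5000)
greedy = [max(range(2), key=lambda act: q[s][act]) for s in range(4)]
```

After training, the greedy action in every non-terminal state should be 1 (move right), since the only reward lies at the right end of the chain.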
- class prt_rl.exact.sarsa.SARSAConfig(gamma: float = 0.99, alpha: float = 0.1)
Configuration parameters for the SARSA algorithm.
- Variables:
gamma (float) – Discount factor for future rewards. A value between 0 and 1 that determines how much the agent values future rewards compared to immediate rewards. Default is 0.99.
alpha (float) – Learning rate for Q-value updates. Controls how much the Q-values are adjusted based on new experience. Default is 0.1.
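A small worked illustration of what the two parameters mean, using hypothetical helper names in plain Python: γ weights a reward k steps ahead by γ^k, and α is the fraction of the remaining gap to the target that each update closes, so n updates toward a fixed target T starting from 0 reach T(1 - (1 - α)^n).

```python
def discounted_return(rewards, gamma=0.99):
    # G = r_0 + gamma * r_1 + gamma^2 * r_2 + ..., accumulated back to front.
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g


def value_after_updates(target, n, alpha=0.1, q0=0.0):
    # Repeatedly apply q <- q + alpha * (target - q) with a fixed target.
    q = q0
    for _ in range(n):
        q += alpha * (target - q)
    return q


g = discounted_return([1.0, 1.0, 1.0], gamma=0.99)  # 1 + 0.99 + 0.9801 = 2.9701
v = value_after_updates(1.0, n=10, alpha=0.1)       # 1 - 0.9**10, about 0.6513
```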
- class prt_rl.exact.sarsa.SARSAPolicy(qtable: Tensor, decision_function: DecisionFunction)
SARSA Policy implementation using a tabular Q-table.
This policy stores state-action values in a table and uses a decision function to select actions. It supports both stochastic action selection (during training) and deterministic action selection (during evaluation).
- Parameters:
qtable – A 2D tensor of shape (num_states, num_actions) containing Q-values for each state-action pair.
decision_function – A function that takes action values and returns an action (e.g., epsilon-greedy, softmax).
- Raises:
ValueError – If qtable is not a 2D tensor.
- Variables:
decision_function – The decision function used for action selection.
table – The Q-table inherited from TabularPolicy.
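A plain-Python sketch of the pieces described above, under the assumption that a decision function can be modeled as any callable mapping one row of action values to an action index. `TabularQPolicy`, `make_epsilon_greedy`, and `make_softmax` are hypothetical names, and prt_rl's `DecisionFunction` interface may differ; lists of lists stand in for the 2D tensor. The constructor mirrors the documented `ValueError` for a table that is not two-dimensional.

```python
import math
import random


def make_epsilon_greedy(epsilon, rng):
    def decide(action_values):
        # Random action with probability epsilon, else the highest-valued one.
        if rng.random() < epsilon:
            return rng.randrange(len(action_values))
        return max(range(len(action_values)), key=action_values.__getitem__)
    return decide


def make_softmax(tau, rng):
    def decide(action_values):
        # Sample proportionally to exp(Q / tau); subtract the max for stability.
        m = max(action_values)
        weights = [math.exp((v - m) / tau) for v in action_values]
        return rng.choices(range(len(action_values)), weights=weights, k=1)[0]
    return decide


class TabularQPolicy:
    def __init__(self, qtable, decision_function):
        # Require a non-empty rectangular (num_states, num_actions) table.
        if (not isinstance(qtable, list) or not qtable
                or not all(isinstance(row, list) for row in qtable)
                or len({len(row) for row in qtable}) != 1):
            raise ValueError("qtable must be 2D: (num_states, num_actions)")
        self.table = qtable
        self.decision_function = decision_function


rng = random.Random(0)
policy = TabularQPolicy([[0.0, 1.0], [2.0, 0.5]], make_epsilon_greedy(0.0, rng))
choice = policy.decision_function(policy.table[0])  # epsilon = 0, so greedy: 1
```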
- act(obs: Tensor, deterministic: bool = False) → Tuple[Tensor, Dict[str, Tensor]]
Select an action based on the observation.
- Parameters:
obs – The current state/observation as a tensor of shape (B, S).
deterministic – If True, selects the action with highest Q-value. If False, uses the decision function for action selection.
- Returns:
A tuple (action, info):
action: The selected action as a tensor of shape (B, A).
info: A dictionary with a ‘q_value’ key containing the Q-value of the selected action.
- Return type:
Tuple[Tensor, Dict[str, Tensor]]
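The deterministic/stochastic split above can be sketched for a batch of integer state indices as follows. This plain-Python version uses a hypothetical function name `act`, lists in place of tensors, and one action index per state rather than a (B, A) tensor; it returns the actions together with an info dictionary carrying the Q-value of each chosen action, following the documented return structure.

```python
def act(qtable, states, decision_function, deterministic=False):
    actions, q_values = [], []
    for s in states:
        row = qtable[s]
        if deterministic:
            # Greedy: index of the highest Q-value (first one on ties).
            a = max(range(len(row)), key=row.__getitem__)
        else:
            # Stochastic: defer to the decision function (e.g. epsilon-greedy).
            a = decision_function(row)
        actions.append(a)
        q_values.append(row[a])
    return actions, {"q_value": q_values}


qtable = [[0.1, 0.5], [0.9, 0.2], [0.3, 0.3]]
actions, info = act(qtable, [0, 1, 2], decision_function=None, deterministic=True)
# actions == [1, 0, 0], info == {"q_value": [0.5, 0.9, 0.3]}
```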
- get_action_values(obs: Tensor) → Tensor
Get all action values (Q-values) for a given state.
- Parameters:
obs – The state/observation as a tensor.
- Returns:
A tensor containing Q-values for all possible actions in the given state.
- get_qvalue(obs: Tensor, action: Tensor) → Tensor
Get the Q-value for a specific state-action pair.
- Parameters:
obs – The state/observation as a tensor.
action – The action as a tensor.
- Returns:
The Q-value for the given state-action pair.
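Both lookups reduce to indexing the Q-table. The sketch below uses hypothetical function names and treats observations as integer state indices: `get_action_values` returns whole rows and `get_qvalue` picks one entry per (state, action) pair. With torch tensors and 1-D integer index tensors, the same lookups can be written with advanced indexing as `qtable[states]` and `qtable[states, actions]`; whether prt_rl indexes exactly this way is an assumption.

```python
def get_action_values(qtable, states):
    # One full row (Q-values for all actions) per state in the batch.
    return [list(qtable[s]) for s in states]


def get_qvalue(qtable, states, actions):
    # One entry per (state, action) pair in the batch.
    return [qtable[s][a] for s, a in zip(states, actions)]


qtable = [[0.1, 0.5], [0.9, 0.2]]
rows = get_action_values(qtable, [1, 0])   # [[0.9, 0.2], [0.1, 0.5]]
vals = get_qvalue(qtable, [0, 1], [1, 1])  # [0.5, 0.2]
```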