Source code for prt_rl.exact.qlearning

from dataclasses import dataclass, asdict
from pathlib import Path
import torch
from torch import Tensor
from typing import Optional, List, Tuple, Dict
from prt_rl.env.interface import EnvironmentInterface
from prt_rl.agent import Agent
from prt_rl.common.policies import TabularPolicy
from prt_rl.common.decision_functions import DecisionFunction
from prt_rl.common.loggers import Logger
from prt_rl.common.schedulers import ParameterScheduler
from prt_rl.common.progress_bar import ProgressBar
from prt_rl.common.evaluators import Evaluator
from prt_rl.common.collectors import Collector

@dataclass
class QLearningConfig:
    """
    Hyperparameters for tabular Q-learning.

    Attributes:
        gamma: Discount factor applied to the bootstrapped next-state value.
        alpha: Step size used to move Q(s, a) toward the temporal-difference target.
    """
    # Discount on max_a' Q(s', a') in the TD target.
    gamma: float = 0.99
    # Learning rate scaling the TD error.
    alpha: float = 0.1
class QLearningPolicy(TabularPolicy):
    """
    Q-Learning policy backed by a tabular Q-table.

    The table stores one value per (state, action) pair. During training
    (``deterministic=False``) actions are chosen by the decision function;
    during evaluation (``deterministic=True``) the greedy ``argmax`` action is taken.

    Args:
        qtable: A 2D tensor of shape (num_states, num_actions) containing the
            Q-value of each state-action pair.
        decision_function: Object whose ``select_action`` maps action values to an
            action (e.g. epsilon-greedy, softmax).

    Raises:
        ValueError: If ``qtable`` is not a 2D tensor.

    Attributes:
        decision_function: The decision function used for action selection.
        table: The Q-table, stored by :class:`TabularPolicy`.
    """

    def __init__(self, qtable: Tensor, decision_function: DecisionFunction) -> None:
        if qtable.dim() != 2:
            raise ValueError(f"Q-table must be a 2D tensor of shape (num_states, num_actions), but got shape {qtable.shape}.")
        super().__init__(table=qtable, decision_function=decision_function)

    def act(self, obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Dict[str, Tensor]]:
        """
        Select an action based on the observation.

        Args:
            obs: State-index tensor. A 2D tensor whose leading (batch) dimension is 1
                has that dimension squeezed away before the table lookup.
            deterministic: If True, selects the action with the highest Q-value.
                If False, uses the decision function for action selection.

        Returns:
            A tuple ``(action, info)``:
                - action: The selected action. In the deterministic path it comes from
                  ``argmax(..., keepdim=True)``, so it keeps a trailing action dimension.
                - info: Currently always an empty dict; no per-action diagnostics
                  (such as the selected Q-value) are reported.
        """
        # Drop a singleton batch dimension so the state index selects table rows directly.
        if obs.dim() == 2 and obs.shape[0] == 1:
            obs = obs.squeeze(0)

        action_values = self.get_action_values(obs)
        if deterministic:
            action = torch.argmax(action_values, dim=-1, keepdim=True)
        else:
            action = self.decision_function.select_action(action_values)
        return action, {}

    def get_qvalue(self, obs: Tensor, action: Tensor) -> Tensor:
        """
        Get the Q-value for a specific state-action pair.

        Args:
            obs: The state index as a tensor.
            action: The action index as a tensor.

        Returns:
            ``table[obs, action]``. Integer-tensor indexing returns a copy, not a view
            into the table.
        """
        return self.table[obs, action]

    def get_action_values(self, obs: Tensor) -> Tensor:
        """
        Get all action values (Q-values) for a given state.

        Args:
            obs: The state index as a tensor.

        Returns:
            The table row(s) selected by ``obs``; the last dimension indexes actions.
        """
        return self.table[obs]

    def set_qvalue(self, state: Tensor, action: Tensor, qval: Tensor) -> None:
        """
        Overwrite the Q-value for a specific state-action pair in place.

        Args:
            state: The state index as a tensor.
            action: The action index as a tensor.
            qval: The new Q-value to store at ``table[state, action]``.
        """
        self.table[state, action] = qval
class QLearningAgent(Agent):
    r"""
    Tabular Q-Learning agent.

    After every environment step the visited state-action value is moved toward
    the one-step temporal-difference target:

    .. math::
        Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma (1 - d) \max_{a'} Q(s', a') - Q(s, a) \right]

    where :math:`d` is the transition's ``done`` flag when the collector provides
    one (otherwise :math:`d = 0`).

    Args:
        policy: Q-learning policy whose Q-table is updated in place.
        config: Hyperparameters (discount ``gamma`` and step size ``alpha``).
        device: Device string forwarded to the :class:`Agent` base class.
    """

    def __init__(self,
                 policy: QLearningPolicy,
                 config: QLearningConfig,
                 *,
                 device: str = "cpu",
                 ) -> None:
        super().__init__(device=device)
        self.policy = policy
        self.config = config

    @torch.no_grad()
    def act(self, obs, deterministic=False):
        """
        Select an action for the given observation.

        Args:
            obs: The current state/observation.
            deterministic: If True, selects the action with highest Q-value.
                If False, uses the policy's decision function for action selection.

        Returns:
            The ``(action, info)`` tuple returned by the policy's ``act``.
        """
        return self.policy.act(obs, deterministic)

    @torch.no_grad()
    def _td_update(self, experience: Dict[str, Tensor]) -> Tensor:
        """Apply one Q-learning update from a single transition and return Q_new - Q_old."""
        state = experience['state']
        action = experience['action']
        reward = experience['reward']
        next_state = experience['next_state']

        qsa = self.policy.get_qvalue(state, action)
        qmax, _ = torch.max(self.policy.get_action_values(next_state), dim=-1, keepdim=True)

        # Do not bootstrap past the end of an episode. If the collector does not
        # report a 'done' flag, fall back to always bootstrapping.
        # NOTE(review): if 'done' also covers time-limit truncation, truncated
        # transitions are treated as terminal too — confirm against the collector.
        done = experience.get('done')
        if done is not None:
            qmax = qmax * (1.0 - done.to(qmax.dtype))

        td_target = reward + self.config.gamma * qmax
        qnew = qsa + self.config.alpha * (td_target - qsa)
        deltaq = qnew - qsa

        self.policy.set_qvalue(state, action, qnew)
        return deltaq

    def train(self,
              env: EnvironmentInterface,
              total_steps: int,
              schedulers: Optional[List[ParameterScheduler]] = None,
              logger: Optional[Logger] = None,
              evaluator: Optional[Evaluator] = None,
              show_progress: bool = True
              ) -> None:
        """
        Train the Q-Learning agent in the given environment.

        Args:
            env: The environment to train in.
            total_steps: Total number of training steps to perform.
            schedulers: Optional list of parameter schedulers updated before every step
                (e.g., for epsilon decay in epsilon-greedy policies).
            logger: Optional logger for recording training metrics. If None, a default
                ``Logger()`` is created.
            evaluator: Optional evaluator called after every step; it is closed when
                training finishes.
            show_progress: If True, displays a progress bar with training metrics.

        Returns:
            None. Updates the policy's Q-table in place.
        """
        if logger is None:
            logger = Logger()

        if show_progress:
            progress_bar = ProgressBar(total_steps=total_steps)

        # Experience is consumed immediately, so no replay buffer is needed.
        collector = Collector(env=env, logger=logger, flatten=True)

        num_steps = 0
        while num_steps < total_steps:
            if schedulers is not None:
                for scheduler in schedulers:
                    scheduler.update(current_step=num_steps)

            experience = collector.collect_experience(policy=self.policy, num_steps=1)
            # Drop the leading step dimension of the single collected transition.
            experience = {k: v.squeeze(0) for k, v in experience.items()}

            deltaq = self._td_update(experience)
            delta_q_value = deltaq.cpu().item()
            num_steps += 1

            if show_progress:
                tracker = collector.get_metric_tracker()
                progress_bar.update(current_step=num_steps,
                                    desc=f"Episode Reward: {tracker.last_episode_reward:.2f}, "
                                         f"Episode Length: {tracker.last_episode_length}, "
                                         f"Delta Q: {delta_q_value:.4f},")

            if logger.should_log(num_steps):
                logger.log_scalar("train/delta_q", delta_q_value, num_steps)

            if evaluator is not None:
                evaluator.evaluate(agent=self.policy, iteration=num_steps)

        if evaluator is not None:
            evaluator.close()

    def _save_impl(self, path: Path) -> None:
        """
        Write a self-contained checkpoint directory.

        Layout:
            path/
                agent.pt   -- algorithm tag, format version and config dict
                policy.pt  -- written by the policy's own ``save``
        """
        path.mkdir(parents=True, exist_ok=True)

        payload = {
            "algo": "QLearning",
            "agent_format_version": 1,
            "config": asdict(self.config),
        }
        torch.save(payload, path / "agent.pt")
        self.policy.save(path / "policy.pt")

    @classmethod
    def load(cls, path: str | Path, map_location: str | torch.device = "cpu") -> "QLearningAgent":
        """
        Rebuild an agent from a directory written by ``_save_impl``.

        Args:
            path: Checkpoint directory containing ``agent.pt`` and ``policy.pt``.
            map_location: Forwarded to ``torch.load`` and ``QLearningPolicy.load``.

        Returns:
            A new ``QLearningAgent`` with the restored config and policy.

        Raises:
            ValueError: If the checkpoint was not written by a Q-Learning agent.
        """
        p = Path(path)

        agent_meta = torch.load(p / "agent.pt", map_location=map_location)
        if agent_meta["algo"] != "QLearning":
            raise ValueError(f"Loaded agent type {agent_meta['algo']} is not QLearning.")

        config = QLearningConfig(**agent_meta["config"])
        policy = QLearningPolicy.load(p / "policy.pt", map_location=map_location)
        # NOTE(review): the agent uses the default device ("cpu") regardless of
        # map_location — confirm this is intended.
        return cls(policy=policy, config=config)