Source code for prt_rl.common.buffers

from abc import ABC, abstractmethod
from pathlib import Path
import numpy as np
import torch
from typing import Dict, Tuple, Optional, Union


class BaseBuffer(ABC):
    """
    Common base for the experience buffers in this module.

    Holds the bookkeeping every buffer shares (capacity, device, fill level
    and write cursor) and declares the interface subclasses must implement.

    Args:
        capacity (int): Maximum number of elements the buffer is meant to hold.
        device (str): Device specifier handed to ``torch.device``.
    """

    def __init__(self,
                 capacity: int,
                 device: str = 'cpu'
                 ) -> None:
        self.capacity = capacity
        self.device = torch.device(device)
        # Nothing is stored yet and the write cursor starts at slot 0.
        self.size, self.pos = 0, 0

    def get_size(self) -> int:
        """
        Report how many elements the buffer currently holds.

        Returns:
            int: The value of the ``size`` counter.
        """
        return self.size

    def __len__(self) -> int:
        """
        Support ``len(buffer)``; reports the ``size`` counter like ``get_size``.

        Returns:
            int: The value of the ``size`` counter.
        """
        return self.size

    @abstractmethod
    def add(self, experience: Dict[str, torch.Tensor]) -> None:
        """
        Store new experience data in the buffer.

        Args:
            experience (Dict[str, torch.Tensor]): Mapping from field name to tensor.
        """
        raise NotImplementedError

    @abstractmethod
    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Draw a batch of stored experience.

        Args:
            batch_size (int): How many elements to draw.

        Returns:
            Dict[str, torch.Tensor]: Mapping from field name to the drawn tensors.
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """
        Empty the buffer, resetting its bookkeeping.
        """
        raise NotImplementedError
class ReplayBuffer(BaseBuffer):
    """
    A circular replay buffer that overwrites old experiences when full.

    Storage is allocated lazily by the first :meth:`add`: each key of that
    experience dict gets a zero tensor of shape ``(capacity, *v.shape[1:])``
    with the value's dtype, placed on ``device``.

    Args:
        capacity (int): The maximum number of experiences to store.
        device (torch.device): The device to store the buffer on (default: CPU).
    """

    def __init__(self, capacity: int, device: torch.device = torch.device("cpu")):
        super().__init__(capacity, device)
        self.buffer: Dict[str, torch.Tensor] = {}
        self.initialized = False
        self.metadata: Optional[Dict] = None

    def _init_storage(self, experience: Dict[str, torch.Tensor]) -> None:
        """
        Initializes the storage for the replay buffer based on the first transition.

        Args:
            experience (Dict[str, torch.Tensor]): A dictionary containing the transition data.
        """
        for k, v in experience.items():
            shape = (self.capacity,) + v.shape[1:]  # Skip batch dim
            self.buffer[k] = torch.zeros(shape, dtype=v.dtype, device=self.device)
        self.initialized = True

    def add(self, experience: Dict[str, torch.Tensor]) -> None:
        """
        Adds a batch of transitions to the replay buffer, overwriting the oldest when full.

        All values are expected to share the same leading (batch) dimension; the
        batch size is read from the first value. A batch larger than ``capacity``
        keeps only its newest ``capacity`` rows.

        Args:
            experience (Dict[str, torch.Tensor]): A dictionary containing the transition data.

        Raises:
            ValueError: If ``experience`` is empty.
        """
        if not experience:
            raise ValueError("experience must contain at least one tensor.")
        if not self.initialized:
            self._init_storage(experience)

        batch_size = next(iter(experience.values())).shape[0]
        if batch_size > self.capacity:
            # Only the newest `capacity` rows can survive. Advance the cursor past
            # the slots the dropped rows would have occupied so the final `pos`
            # matches what inserting them one by one would produce.
            drop = batch_size - self.capacity
            experience = {k: v[drop:] for k, v in experience.items()}
            self.pos = (self.pos + drop) % self.capacity
            batch_size = self.capacity

        insert_end = self.pos + batch_size
        if insert_end <= self.capacity:
            # One contiguous block
            idx = slice(self.pos, insert_end)
            for k, v in experience.items():
                self.buffer[k][idx] = v.to(self.device)
        else:
            # Wrap-around: split into two writes
            first_len = self.capacity - self.pos
            second_len = batch_size - first_len
            for k, v in experience.items():
                self.buffer[k][self.pos:] = v[:first_len].to(self.device)
                self.buffer[k][:second_len] = v[first_len:].to(self.device)

        self.pos = (self.pos + batch_size) % self.capacity
        self.size = min(self.size + batch_size, self.capacity)

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Samples a batch of transitions uniformly (with replacement) from the replay buffer.

        Args:
            batch_size (int): The number of samples to draw from the buffer.

        Returns:
            Dict[str, torch.Tensor]: A dictionary containing the sampled transitions.

        Raises:
            ValueError: If fewer than ``batch_size`` transitions are stored.
        """
        if self.size < batch_size:
            raise ValueError("Not enough samples in buffer to sample.")
        indices = torch.randint(0, self.size, (batch_size,), device=self.device)
        return {k: v[indices] for k, v in self.buffer.items()}

    def resize(self, new_capacity: int):
        """
        Expands the buffer to a new capacity while preserving existing data.

        Stored rows are copied oldest-first into slots ``[0, size)`` of the new
        storage, and the write cursor is placed right after them.

        Args:
            new_capacity (int): The new buffer capacity.

        Raises:
            ValueError: If ``new_capacity`` is not larger than the current capacity.
        """
        if new_capacity <= self.capacity:
            raise ValueError("New capacity must be greater than current capacity.")

        new_buffer = {}
        for k, v in self.buffer.items():
            new_shape = (new_capacity,) + v.shape[1:]
            new_tensor = torch.zeros(new_shape, dtype=v.dtype, device=self.device)
            if self.pos >= self.size:
                # No wrap-around
                new_tensor[:self.size] = v[:self.size]
            else:
                # Wrap-around: the oldest rows start at `pos`
                new_tensor[:self.capacity - self.pos] = v[self.pos:]
                new_tensor[self.capacity - self.pos:self.size] = v[:self.pos]
            new_buffer[k] = new_tensor

        self.buffer = new_buffer
        self.capacity = new_capacity
        self.pos = self.size  # Next write after last element

    def get_batches(self, batch_size: int):
        """
        Yields shuffled mini-batches covering every stored transition once.

        The final batch may be smaller than ``batch_size``. The buffer is not modified.
        """
        if self.size == 0:
            return
        indices = torch.randperm(self.size, device=self.device)
        for i in range(0, self.size, batch_size):
            idx = indices[i:i + batch_size]
            yield {k: v[idx] for k, v in self.buffer.items()}

    def clear(self) -> None:
        """
        Clears the replay buffer, resetting its state.

        Storage is released and re-allocated on the next :meth:`add`; ``metadata`` is kept.
        """
        self.size = 0
        self.pos = 0
        self.buffer = {}
        self.initialized = False

    def save(self, path: Union[str, Path]) -> None:
        """
        Saves the replay buffer to a file.

        Args:
            path (str | Path): Path to the file where the buffer will be saved.
        """
        path = Path(path)
        payload = {
            'buffer': self.buffer,
            'size': self.size,
            'pos': self.pos,
            'capacity': self.capacity,
        }
        if self.metadata is not None:
            payload['metadata'] = self.metadata
        torch.save(payload, path)

    def get_metadata(self) -> Optional[Dict]:
        """
        Returns the metadata stored with the replay buffer, if any.
        """
        return self.metadata

    def set_metadata(self, metadata: Dict) -> None:
        """
        Sets the metadata for the replay buffer.

        Args:
            metadata (Dict): A dictionary containing metadata information.
        """
        self.metadata = metadata

    @classmethod
    def load(cls, path: Union[str, Path], device: str = "cpu") -> "ReplayBuffer":
        """
        Loads a replay buffer from a file.

        Args:
            path (str | Path): Path to the saved buffer file.
            device (str): Device to load the buffer to. Defaults to "cpu".

        Returns:
            ReplayBuffer: A ReplayBuffer instance with restored data.
        """
        dev = torch.device(device)
        # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
        # Python objects; only load files from trusted sources.
        data = torch.load(Path(path), map_location=dev, weights_only=False)
        obj = cls(capacity=data['capacity'], device=dev)
        obj.buffer = {k: v.to(dev) for k, v in data['buffer'].items()}
        obj.size = data['size']
        obj.pos = data['pos']
        # A buffer saved before its first add() (or after clear()) has no storage
        # tensors; leave it uninitialized so the next add() allocates them instead
        # of failing with a KeyError.
        obj.initialized = bool(obj.buffer)
        if 'metadata' in data:
            obj.metadata = data['metadata']
        return obj
class SumTree:
    """
    A binary sum tree for efficient sampling of elements proportional to their priority.

    The SumTree is a binary tree where each parent node stores the sum of its child nodes.
    It's particularly useful for:
      - Efficient sampling from a discrete probability distribution:
        you can sample an index i proportional to a weight p_i in O(log N) time.
      - Efficient dynamic updates of weights (e.g., priority values) while maintaining
        the cumulative structure.

    **Key Applications:**
      - Prioritized Experience Replay (PER): Sample transitions with probability ∝ priority.
      - Importance Sampling in any algorithm that requires sampling proportional to some
        non-uniform, changeable weights.
      - Event scheduling in simulations: Where some events happen more frequently than others.

    Leaves live at tree indices ``capacity - 1 .. 2 * capacity - 2``; leaf ``i``
    holds the priority of data index ``i``.

    Attributes:
        capacity (int): Maximum number of elements the tree can hold.
        tree (np.ndarray): Binary tree storing priorities (float64).
        data_pointer (int): Current position to insert new priority.
    """

    def __init__(self, capacity: int) -> None:
        if capacity < 1:
            raise ValueError("capacity must be a positive integer.")
        self.capacity = capacity
        # float64: every update propagates an incremental `+= change` to all
        # ancestors, and in float32 that rounding error can accumulate until the
        # internal sums no longer match the leaves.
        self.tree = np.zeros(2 * capacity - 1, dtype=np.float64)
        self.data_pointer = 0

    def add(self, priority: float) -> None:
        """
        Add a new priority to the sum tree, overwriting the oldest slot when full.

        Args:
            priority (float): Priority value to insert.
        """
        tree_idx = self.data_pointer + self.capacity - 1
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity

    def update(self, tree_idx: int, priority: float) -> None:
        """
        Update the priority at a specific tree index and propagate the change.

        Args:
            tree_idx (int): Index in the tree to update.
            priority (float): New priority value.
        """
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, value: float) -> Tuple[int, float, int]:
        """
        Traverse the tree to find the leaf node corresponding to the sample value.

        Each node's mass is split into half-open intervals: ``[0, left_sum)`` maps
        to the left child and ``[left_sum, total)`` to the right one. Zero-priority
        leaves therefore cannot be selected by an in-range value.

        Args:
            value (float): A sample value in [0, total_priority).

        Returns:
            Tuple[int, float, int]: (tree index, priority, data index)
        """
        parent_idx = 0
        while True:
            left_idx = 2 * parent_idx + 1
            right_idx = left_idx + 1
            if left_idx >= len(self.tree):
                leaf_idx = parent_idx
                break
            left_sum = self.tree[left_idx]
            # Also stay left when the right subtree carries no mass, so that
            # value == total (or rounding error) cannot land on an empty leaf.
            if value < left_sum or self.tree[right_idx] <= 0.0:
                parent_idx = left_idx
            else:
                value -= left_sum
                parent_idx = right_idx

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, float(self.tree[leaf_idx]), data_idx

    def total_priority(self) -> float:
        """
        Returns the sum of all priorities.

        Returns:
            float: Total priority.
        """
        return float(self.tree[0])
class PrioritizedReplayBuffer(BaseBuffer):
    """
    A Prioritized Experience Replay Buffer using a SumTree for efficient sampling.

    Attributes:
        alpha (float): How much prioritization is used (0 = uniform, 1 = full prioritization).
        beta (float): Importance sampling bias correction term.
        priorities (SumTree): Sum tree to manage priorities.
        max_priority (float): The largest priority written to the sum tree so far. It is
            kept in the same space as the tree's leaves (i.e. already raised to ``alpha``),
            and new transitions are inserted with exactly this value.
    """

    def __init__(self,
                 capacity: int,
                 alpha: float = 0.6,
                 beta: float = 0.4,
                 device: str = 'cpu'
                 ) -> None:
        super().__init__(capacity, device)
        self.alpha = alpha
        self.beta = beta
        self.beta0 = beta  # initial beta, restored by clear()
        self.priorities = SumTree(capacity)
        self.max_priority = 1.0
        self.buffer: Dict[str, torch.Tensor] = {}
        self.initialized = False

    def _init_storage(self, experience: Dict[str, torch.Tensor]) -> None:
        """
        Initializes the storage for the replay buffer based on the first transition.

        Args:
            experience (Dict[str, torch.Tensor]): A dictionary containing the transition data.
        """
        for k, v in experience.items():
            shape = (self.capacity,) + v.shape[1:]
            self.buffer[k] = torch.zeros(shape, dtype=v.dtype, device=self.device)
        self.initialized = True

    def add(self, experience: Dict[str, torch.Tensor]) -> None:
        """
        Adds a batch of transitions to the replay buffer with the current maximum priority.

        All values are expected to share the same leading (batch) dimension; the
        batch size is read from the first value. A batch larger than ``capacity``
        keeps only its newest ``capacity`` rows.

        Args:
            experience (Dict[str, torch.Tensor]): A dictionary containing the transition data.

        Raises:
            ValueError: If ``experience`` is empty.
        """
        if not experience:
            raise ValueError("experience must contain at least one tensor.")
        if not self.initialized:
            self._init_storage(experience)

        num_rows = next(iter(experience.values())).shape[0]
        batch_size = num_rows
        if batch_size > self.capacity:
            # Only the newest `capacity` rows can survive; skip the slots the
            # dropped rows would have occupied so `pos` ends where sequential
            # inserts would leave it.
            drop = batch_size - self.capacity
            experience = {k: v[drop:] for k, v in experience.items()}
            self.pos = (self.pos + drop) % self.capacity
            batch_size = self.capacity

        insert_end = self.pos + batch_size
        if insert_end <= self.capacity:
            idx = slice(self.pos, insert_end)
            for k, v in experience.items():
                self.buffer[k][idx] = v.to(self.device)
        else:
            first_len = self.capacity - self.pos
            second_len = batch_size - first_len
            for k, v in experience.items():
                self.buffer[k][self.pos:] = v[:first_len].to(self.device)
                self.buffer[k][:second_len] = v[first_len:].to(self.device)

        # New transitions get the current maximum priority. `max_priority` is
        # already in tree space (update_priorities stores alpha-exponentiated
        # values), so it must NOT be raised to alpha a second time here.
        # Advancing the tree by `num_rows` keeps its data_pointer aligned with `pos`.
        for _ in range(num_rows):
            self.priorities.add(self.max_priority)

        self.pos = (self.pos + batch_size) % self.capacity
        self.size = min(self.size + batch_size, self.capacity)

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Samples a batch of transitions from the replay buffer using prioritized sampling.

        The total priority mass is split into ``batch_size`` equal segments and one
        value is drawn uniformly from each (stratified sampling).

        Returns a dictionary containing:
            - 'weights': Importance sampling weights for each sample, normalized by their max.
            - 'indices': Indices of the sampled transitions in the buffer.
            - Other transition data (e.g., state, action, reward, etc.).

        Args:
            batch_size (int): The number of samples to draw from the buffer.

        Returns:
            Dict[str, torch.Tensor]: A dictionary containing the sampled transitions.

        Raises:
            ValueError: If fewer than ``batch_size`` transitions are stored.
        """
        if self.size < batch_size:
            raise ValueError("Not enough samples in buffer to sample.")

        segment = self.priorities.total_priority() / batch_size
        indices = []
        priorities = []
        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = np.random.uniform(a, b)
            leaf_idx, priority, data_idx = self.priorities.get_leaf(s)
            indices.append(data_idx)
            priorities.append(priority)

        indices_torch = torch.tensor(indices, dtype=torch.long, device=self.device)
        sampled = {k: v[indices_torch] for k, v in self.buffer.items()}

        priorities = torch.tensor(priorities, dtype=torch.float32, device=self.device)
        sampling_probabilities = priorities / self.priorities.total_priority()
        weights = (self.size * sampling_probabilities).pow(-self.beta)
        weights /= weights.max()

        sampled['weights'] = weights
        sampled['indices'] = indices_torch
        return sampled

    def update_priorities(self, indices: torch.Tensor, td_errors: torch.Tensor) -> None:
        """
        Update the priorities of the sampled transitions based on TD errors.

        The stored priority is ``(|td_error| + 1e-6) ** alpha``. Inputs are
        flattened, so ``(B,)`` and ``(B, 1)`` shapes are both accepted.

        Args:
            indices (torch.Tensor): Indices of the transitions to update.
            td_errors (torch.Tensor): TD errors for the transitions.

        Raises:
            ValueError: If ``indices`` and ``td_errors`` hold different numbers of elements.
        """
        flat_indices = indices.reshape(-1).tolist()
        # detach(): no autograd graph is needed for the priority computation.
        flat_errors = td_errors.detach().reshape(-1)
        if len(flat_indices) != flat_errors.numel():
            raise ValueError("indices and td_errors must contain the same number of elements.")

        priorities = (flat_errors.abs() + 1e-6).pow(self.alpha)
        for idx, priority in zip(flat_indices, priorities.tolist()):
            tree_idx = idx + self.capacity - 1
            self.priorities.update(tree_idx, priority)
            self.max_priority = max(self.max_priority, priority)

    def clear(self) -> None:
        """
        Clears the replay buffer, resetting its state, priorities and beta.
        """
        self.size = 0
        self.pos = 0
        self.buffer = {}
        self.initialized = False
        self.priorities = SumTree(self.capacity)
        self.max_priority = 1.0
        self.beta = self.beta0
class RolloutBuffer(BaseBuffer):
    """
    Append-only buffer with a hard capacity limit.

    :meth:`add` raises instead of overwriting once the capacity would be
    exceeded, and :meth:`sample` removes the transitions it returns.
    """

    def __init__(self,
                 capacity: int,
                 device: str = 'cpu'
                 ) -> None:
        """
        Args:
            capacity: Max number of transitions the buffer can store.
            device: Torch device.
        """
        super().__init__(capacity, device)
        self.buffer: Dict[str, torch.Tensor] = {}
        self.initialized = False

    def _init_storage(self, experience: Dict[str, torch.Tensor]) -> None:
        """Allocate one zero tensor of shape ``(capacity, *v.shape[1:])`` per key."""
        for k, v in experience.items():
            shape = (self.capacity,) + v.shape[1:]
            self.buffer[k] = torch.zeros(shape, dtype=v.dtype, device=self.device)
        self.initialized = True

    def add(self, experience: Dict[str, torch.Tensor]) -> None:
        """
        Append a batch of transitions after the ones already stored.

        All values are expected to share the same leading (batch) dimension; the
        batch size is read from the first value.

        Args:
            experience (Dict[str, torch.Tensor]): A dictionary containing the transition data.

        Raises:
            ValueError: If ``experience`` is empty, or if the batch does not fit in
                the remaining capacity.
        """
        if not experience:
            raise ValueError("experience must contain at least one tensor.")
        if not self.initialized:
            self._init_storage(experience)

        batch_size = next(iter(experience.values())).shape[0]
        if self.size + batch_size > self.capacity:
            raise ValueError("RolloutBuffer overflow: not enough capacity")

        idx = slice(self.size, self.size + batch_size)
        for k, v in experience.items():
            self.buffer[k][idx] = v.to(self.device)
        self.size += batch_size

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Draw ``batch_size`` distinct transitions at random and remove them from the buffer.

        The remaining transitions are compacted to the front of storage.

        Args:
            batch_size (int): Number of transitions to draw.

        Returns:
            Dict[str, torch.Tensor]: The drawn transitions.

        Raises:
            ValueError: If fewer than ``batch_size`` transitions are stored.
        """
        if self.size < batch_size:
            raise ValueError("Not enough samples to draw")

        indices = torch.randperm(self.size, device=self.device)[:batch_size]
        sampled = {k: v[indices] for k, v in self.buffer.items()}

        # Keep the remaining entries by copying them up. Advanced indexing on the
        # right-hand side makes a copy, so the overlapping write is safe.
        keep_mask = torch.ones(self.size, dtype=torch.bool, device=self.device)
        keep_mask[indices] = False
        keep_indices = keep_mask.nonzero(as_tuple=False).squeeze(-1)
        for k in self.buffer:
            self.buffer[k][:len(keep_indices)] = self.buffer[k][keep_indices]

        self.size -= batch_size
        return sampled

    def get_batches(self, batch_size: int):
        """
        Yields mini-batches covering every stored transition once, in random order.

        The final batch may be smaller than ``batch_size``. The buffer is not
        modified; call :meth:`clear` to empty it afterwards.
        """
        if self.size == 0:
            return
        indices = torch.randperm(self.size, device=self.device)
        for i in range(0, self.size, batch_size):
            idx = indices[i:i + batch_size]
            batch = {k: v[idx] for k, v in self.buffer.items()}
            yield batch

    def clear(self) -> None:
        """
        Empty the buffer; storage is re-allocated on the next :meth:`add`.
        """
        self.size = 0
        self.pos = 0
        self.buffer = {}
        self.initialized = False