buffers
Classes
PrioritizedReplayBuffer: A Prioritized Experience Replay Buffer using a SumTree for efficient sampling.
ReplayBuffer: A circular replay buffer that overwrites old experiences when full.
SumTree: A binary sum tree for efficient sampling of elements proportional to their priority.
- class prt_rl.common.buffers.BaseBuffer(capacity: int, device: str = 'cpu')
- abstractmethod add(experience: Dict[str, Tensor]) → None
Adds a new experience to the replay buffer.
- Parameters:
experience (Dict[str, torch.Tensor]) – A dictionary containing the experience data.
- get_size() → int
Returns the current number of elements in the replay buffer.
- Returns:
The current size of the replay buffer.
- Return type:
int
- abstractmethod sample(batch_size: int) → Dict[str, Tensor]
Samples a batch of experiences from the replay buffer.
- Parameters:
batch_size (int) – The number of samples to draw from the buffer.
- Returns:
A dictionary containing the sampled experiences.
- Return type:
Dict[str, torch.Tensor]
- class prt_rl.common.buffers.PrioritizedReplayBuffer(capacity: int, alpha: float = 0.6, beta: float = 0.4, device: str = 'cpu')
A Prioritized Experience Replay Buffer using a SumTree for efficient sampling.
- Variables:
- add(experience: Dict[str, Tensor]) → None
Adds a new transition to the replay buffer.
- Parameters:
experience (Dict[str, torch.Tensor]) – A dictionary containing the transition data.
- get_size() → int
Returns the current number of elements in the replay buffer.
- Returns:
The current size of the replay buffer.
- Return type:
int
- sample(batch_size: int) → Dict[str, Tensor]
Samples a batch of transitions from the replay buffer using prioritized sampling. Returns a dictionary containing:
- ‘weights’: Importance sampling weights for each sample.
- ‘indices’: Indices of the sampled transitions in the buffer.
- Other transition data (e.g., state, action, reward, etc.).
- Parameters:
batch_size (int) – The number of samples to draw from the buffer.
- Returns:
A dictionary containing the sampled transitions.
- Return type:
Dict[str, torch.Tensor]
- update_priorities(indices: Tensor, td_errors: Tensor) → None
Update the priorities of the sampled transitions based on TD errors.
- Parameters:
indices (torch.Tensor) – Indices of the transitions to update.
td_errors (torch.Tensor) – TD errors for the transitions.
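To illustrate how alpha, beta, and TD errors usually interact in proportional prioritized replay, the sketch below uses the standard formulation: priority p_i = (|td_i| + eps)^alpha, sampling probability P(i) = p_i / sum_k p_k, and importance weight w_i = (N * P(i))^(-beta), rescaled so the largest weight is 1. The function names, the eps constant, and the max-normalization are assumptions made for illustration. This page does not confirm how prt_rl computes them, and the sketch uses plain Python floats rather than torch tensors.

```python
from typing import List, Sequence


def sampling_probabilities(
    td_errors: Sequence[float], alpha: float = 0.6, eps: float = 1e-6
) -> List[float]:
    """Turn TD errors into sampling probabilities (hypothetical helper).

    alpha = 0 gives uniform sampling; larger alpha favours large errors.
    eps (an assumption) keeps zero-error transitions sampleable.
    """
    priorities = [(abs(e) + eps) ** alpha for e in td_errors]
    total = sum(priorities)
    return [p / total for p in priorities]


def importance_weights(probs: Sequence[float], beta: float = 0.4) -> List[float]:
    """Importance-sampling weights for every stored transition.

    probs must cover the whole buffer, so N = len(probs).
    Weights are divided by their maximum so they only scale updates down.
    """
    n = len(probs)
    raw = [(n * p) ** (-beta) for p in probs]
    max_w = max(raw)
    return [w / max_w for w in raw]
```

In a training loop, the ‘weights’ entry returned by sample would typically multiply the per-sample loss, and the new TD errors would then be passed back through update_priorities together with the ‘indices’ entry.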
- class prt_rl.common.buffers.ReplayBuffer(capacity: int, device: device = device(type='cpu'))
A circular replay buffer that overwrites old experiences when full.
- Parameters:
capacity (int) – The maximum number of experiences to store.
device (torch.device) – The device to store the buffer on (default: CPU).
- add(experience: Dict[str, Tensor]) → None
Adds a new transition to the replay buffer.
- Parameters:
experience (Dict[str, torch.Tensor]) – A dictionary containing the transition data.
- get_size() → int
Returns the current number of elements in the replay buffer.
- Returns:
The current size of the replay buffer.
- Return type:
int
- classmethod load(path: str | Path, device: str = 'cpu') → ReplayBuffer
Loads a replay buffer from a file.
- Parameters:
- Returns:
A ReplayBuffer instance with restored data.
- Return type:
- resize(new_capacity: int)
Expands the buffer to a new capacity while preserving existing data.
- Parameters:
new_capacity (int) – The new buffer capacity.
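Expanding a circular buffer has one subtlety: once the buffer has wrapped, the oldest item is no longer at index 0, so the data must be put back in age order before writing continues. The sketch below shows that step on a plain Python list. The function name and the list-based layout (a list that grows until full, plus a pos index for the next write) are assumptions for illustration, not the prt_rl implementation.

```python
from typing import List, Tuple, TypeVar

T = TypeVar("T")


def expand_circular(
    storage: List[T], pos: int, new_capacity: int
) -> Tuple[List[T], int]:
    """Return (storage in oldest-to-newest order, next write position).

    Hypothetical helper: `storage` holds the items currently in the ring
    and `pos` is the slot the next add would write to.
    """
    if new_capacity <= 0 or new_capacity < len(storage):
        raise ValueError("new_capacity must be positive and hold all current items")
    # In a full ring, `pos` points at the oldest item, so rotating by `pos`
    # restores age order. In a partially filled ring, pos == len(storage)
    # and the rotation leaves the list unchanged.
    ordered = storage[pos:] + storage[:pos]
    # New items go right after the existing ones (wrapping if nothing was gained).
    new_pos = len(ordered) % new_capacity
    return ordered, new_pos
```

Restoring age order means later overwrites in the larger buffer still evict the oldest experience first.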
- sample(batch_size: int) → Dict[str, Tensor]
Samples a batch of transitions from the replay buffer.
- Parameters:
batch_size (int) – The number of samples to draw from the buffer.
- Returns:
A dictionary containing the sampled transitions.
- Return type:
Dict[str, torch.Tensor]
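The circular overwrite behaviour can be sketched in a few lines of plain Python. This is a minimal illustration, not the prt_rl implementation: the class name is hypothetical, experiences are stored as a list of dicts rather than tensors on a device, and sampling uniformly with replacement is an assumption.

```python
import random
from typing import Any, Dict, List, Optional


class CircularBufferSketch:
    """Illustrative ring buffer that overwrites the oldest entry when full."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.storage: List[Dict[str, Any]] = []
        self.pos = 0  # slot the next add() writes to

    def add(self, experience: Dict[str, Any]) -> None:
        if len(self.storage) < self.capacity:
            self.storage.append(experience)
        else:
            # Buffer is full: replace the oldest experience.
            self.storage[self.pos] = experience
        self.pos = (self.pos + 1) % self.capacity

    def get_size(self) -> int:
        return len(self.storage)

    def sample(
        self, batch_size: int, rng: Optional[random.Random] = None
    ) -> Dict[str, List[Any]]:
        if batch_size <= 0 or not self.storage:
            raise ValueError("need a positive batch_size and a non-empty buffer")
        rng = rng or random.Random()
        # Uniform sampling with replacement (an assumption of this sketch).
        picks = [rng.choice(self.storage) for _ in range(batch_size)]
        # Regroup the list of dicts into one dict of per-key lists.
        return {key: [p[key] for p in picks] for key in picks[0]}
```

Adding five experiences to a sketch buffer of capacity 3 leaves only the three most recent ones available to sample.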
- class prt_rl.common.buffers.RolloutBuffer(capacity: int, device: str = 'cpu')
- Parameters:
capacity – Max number of transitions the buffer can store.
device – Torch device.
- __init__(capacity: int, device: str = 'cpu') → None
- Parameters:
capacity – Max number of transitions the buffer can store.
device – Torch device.
- add(experience: Dict[str, Tensor]) → None
Adds a new experience to the replay buffer.
- Parameters:
experience (Dict[str, torch.Tensor]) – A dictionary containing the experience data.
- get_batches(batch_size: int)
Yields mini-batches in random order. The final batch may be smaller.
After iteration, if drop_after_get=True, the buffer is cleared.
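The batching behaviour described above (random order, possibly smaller final batch) can be sketched as a generator over shuffled indices. The function name and the use of a plain sequence instead of stored tensors are assumptions for illustration. Clearing the buffer after iteration is not modelled here.

```python
import random
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def iter_minibatches(
    items: Sequence[T], batch_size: int, rng: Optional[random.Random] = None
) -> Iterator[List[T]]:
    """Yield mini-batches of `items` in random order (hypothetical helper)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    rng = rng or random.Random()
    # Shuffle indices once so every item appears in exactly one batch.
    order = list(range(len(items)))
    rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        # The last slice is shorter when len(items) is not a multiple of batch_size.
        yield [items[i] for i in order[start:start + batch_size]]
```

With 10 items and a batch size of 4, this yields batches of sizes 4, 4, and 2, and each item appears exactly once per pass.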
- class prt_rl.common.buffers.SumTree(capacity: int)
A binary sum tree for efficient sampling of elements proportional to their priority.
The SumTree is a binary tree where each parent node stores the sum of its child nodes. It’s particularly useful for:
- Efficient sampling from a discrete probability distribution: you can sample an index i with probability proportional to its weight p_i in O(log N) time.
- Efficient dynamic updates of weights (e.g., priority values) while maintaining the cumulative structure.
Key Applications:
- Prioritized Experience Replay (PER): sample transitions with probability ∝ priority.
- Importance sampling in any algorithm that requires sampling proportional to non-uniform, changeable weights.
- Event scheduling in simulations, where some events happen more frequently than others.
- Variables:
- add(priority: float) → None
Add a new priority to the sum tree.
- Parameters:
priority (float) – Priority value to insert.
- get_leaf(value: float) → Tuple[int, float, int]
Traverse the tree to find the leaf node corresponding to the sample value.
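A common way to build such a tree is a flat array in which node i has children 2i + 1 and 2i + 2, with the leaves stored last. The sketch below follows that layout. It is an illustration under stated assumptions, not the prt_rl source: the class name, the update method, the circular write pointer, and the meaning of the returned tuple (tree index, priority, data index) are all hypothetical.

```python
from typing import List, Tuple


class SumTreeSketch:
    """Illustrative array-backed sum tree."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        # capacity - 1 internal nodes come first, followed by capacity leaves.
        self.tree: List[float] = [0.0] * (2 * capacity - 1)
        self.write = 0  # next leaf to overwrite (assumed circular)

    def total(self) -> float:
        # The root stores the sum of all priorities.
        return self.tree[0]

    def update(self, data_index: int, priority: float) -> None:
        tree_index = data_index + self.capacity - 1
        change = priority - self.tree[tree_index]
        self.tree[tree_index] = priority
        # Propagate the change up to the root: O(log N).
        while tree_index > 0:
            tree_index = (tree_index - 1) // 2
            self.tree[tree_index] += change

    def add(self, priority: float) -> None:
        self.update(self.write, priority)
        self.write = (self.write + 1) % self.capacity

    def get_leaf(self, value: float) -> Tuple[int, float, int]:
        """Find the leaf whose cumulative range contains value in [0, total)."""
        index = 0
        while 2 * index + 1 < len(self.tree):
            left = 2 * index + 1
            if value < self.tree[left]:
                index = left
            else:
                # Skip the left subtree and search the right one.
                value -= self.tree[left]
                index = left + 1
        return index, self.tree[index], index - (self.capacity - 1)
```

To sample proportionally to priority, draw a value uniformly from [0, total()) and pass it to get_leaf. With priorities 1, 2, 3, 4, the four leaves own the ranges [0, 1), [1, 3), [3, 6), and [6, 10).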