buffers

Classes

BaseBuffer

Abstract base class that defines the common buffer interface (add, clear, sample, get_size).

PrioritizedReplayBuffer

A Prioritized Experience Replay Buffer using a SumTree for efficient sampling.

ReplayBuffer

A circular replay buffer that overwrites old experiences when full.

RolloutBuffer

SumTree

A binary sum tree for efficient sampling of elements proportional to their priority.

class prt_rl.common.buffers.BaseBuffer(capacity: int, device: str = 'cpu')
abstractmethod add(experience: Dict[str, Tensor]) → None

Adds a new experience to the buffer.

Parameters:

experience (Dict[str, torch.Tensor]) – A dictionary containing the experience data.

abstractmethod clear() → None

Clears the buffer, resetting its size and position.

get_size() → int

Returns the current number of elements in the buffer.

Returns:

The current size of the buffer.

Return type:

int

abstractmethod sample(batch_size: int) → Dict[str, Tensor]

Samples a batch of experiences from the buffer.

Parameters:

batch_size (int) – The number of samples to draw from the buffer.

Returns:

A dictionary containing the sampled experiences.

Return type:

Dict[str, torch.Tensor]
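The BaseBuffer contract above (abstract add, clear, and sample, with a concrete get_size) can be sketched with Python's abc module. This is a minimal illustration, not prt_rl's source: the class names SketchBuffer and ListBuffer are hypothetical, and plain Python lists stand in for torch.Tensor so the sketch runs without PyTorch.

```python
import random
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class SketchBuffer(ABC):
    """Hypothetical stand-in for BaseBuffer; plain lists replace torch.Tensor."""

    def __init__(self, capacity: int, device: str = "cpu") -> None:
        self.capacity = capacity
        self.device = device  # kept for signature parity; unused in this sketch
        self.size = 0

    @abstractmethod
    def add(self, experience: Dict[str, Any]) -> None:
        """Store one experience."""

    @abstractmethod
    def clear(self) -> None:
        """Drop all stored experiences."""

    @abstractmethod
    def sample(self, batch_size: int) -> Dict[str, List[Any]]:
        """Return batch_size experiences, one list per key."""

    def get_size(self) -> int:
        # Concrete in the base class, as in the documented interface.
        return self.size


class ListBuffer(SketchBuffer):
    """Minimal concrete subclass: appends until full, samples uniformly without replacement."""

    def __init__(self, capacity: int, device: str = "cpu") -> None:
        super().__init__(capacity, device)
        self.storage: Dict[str, List[Any]] = {}

    def add(self, experience: Dict[str, Any]) -> None:
        if self.size >= self.capacity:
            raise RuntimeError("buffer is full")
        for key, value in experience.items():
            self.storage.setdefault(key, []).append(value)
        self.size += 1

    def clear(self) -> None:
        self.storage = {}
        self.size = 0

    def sample(self, batch_size: int) -> Dict[str, List[Any]]:
        idx = random.sample(range(self.size), batch_size)
        return {key: [values[i] for i in idx] for key, values in self.storage.items()}
```

Because add, clear, and sample are abstract, instantiating the base class directly raises TypeError; only a subclass that implements all three can be created.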

class prt_rl.common.buffers.PrioritizedReplayBuffer(capacity: int, alpha: float = 0.6, beta: float = 0.4, device: str = 'cpu')

A Prioritized Experience Replay Buffer using a SumTree for efficient sampling.

Variables:
  • alpha (float) – How much prioritization is used (0 = uniform, 1 = full prioritization).

  • beta (float) – Importance-sampling exponent that corrects the bias introduced by prioritized sampling (0 = no correction, 1 = full correction).

  • priorities (SumTree) – Sum tree to manage priorities.

  • max_priority (float) – The maximum priority value observed.
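In the standard proportional formulation of prioritized experience replay, alpha and beta enter the sampling probability P(i) and the importance-sampling weight w_i as shown below, where p_i is the priority of transition i and N is the number of stored transitions. The docstrings here do not spell out this class's exact normalization, so treat this as the textbook form rather than a description of the implementation; the worked numbers use p = (1, 2, 3), alpha = 1, beta = 1.

```latex
P(i) = \frac{p_i^{\alpha}}{\sum_{k} p_k^{\alpha}},
\qquad
w_i = \frac{\left(N \cdot P(i)\right)^{-\beta}}{\max_j \left(N \cdot P(j)\right)^{-\beta}}

% Worked example: p = (1, 2, 3), alpha = 1, beta = 1, N = 3
P = \left(\tfrac{1}{6}, \tfrac{2}{6}, \tfrac{3}{6}\right),
\qquad
\left(N P\right)^{-1} = \left(2, 1, \tfrac{2}{3}\right),
\qquad
w = \left(1, \tfrac{1}{2}, \tfrac{1}{3}\right)
```

With alpha = 0 every P(i) equals 1/N, which is the uniform case; frequently sampled (high-priority) transitions receive the smallest weights, which is how beta offsets the sampling bias.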

add(experience: Dict[str, Tensor]) → None

Adds a new transition to the replay buffer.

Parameters:

experience (Dict[str, torch.Tensor]) – A dictionary containing the transition data.

clear() → None

Clears the replay buffer, resetting its state.

get_size() → int

Returns the current number of elements in the replay buffer.

Returns:

The current size of the replay buffer.

Return type:

int

sample(batch_size: int) → Dict[str, Tensor]

Samples a batch of transitions from the replay buffer using prioritized sampling. Returns a dictionary containing:

  • ‘weights’: Importance sampling weights for each sample.

  • ‘indices’: Indices of the sampled transitions in the buffer.

  • Other transition data (e.g., state, action, reward).

Parameters:

batch_size (int) – The number of samples to draw from the buffer.

Returns:

A dictionary containing the sampled transitions.

Return type:

Dict[str, torch.Tensor]
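A scheme commonly used with PER splits the total priority into batch_size equal segments and draws one value from each, which spreads samples across the priority range. The sketch below assumes the stored priorities already include the alpha exponent and uses a prefix-sum list with bisect in place of the SumTree, so it rebuilds the sums in O(N) per call rather than descending a tree in O(log N). The helper name prioritized_sample is hypothetical, and the docstring does not say whether PrioritizedReplayBuffer.sample stratifies this way.

```python
import bisect
import random
from itertools import accumulate
from typing import List, Tuple


def prioritized_sample(priorities: List[float], batch_size: int, beta: float,
                       rng: random.Random) -> Tuple[List[int], List[float]]:
    """Hypothetical helper: stratified proportional sampling plus IS weights.

    Assumes priorities[i] already equals p_i ** alpha. A prefix-sum list
    stands in for the SumTree.
    """
    n = len(priorities)
    prefix = list(accumulate(priorities))  # prefix[i] = priorities[0] + ... + priorities[i]
    total = prefix[-1]
    segment = total / batch_size
    indices = []
    for k in range(batch_size):
        # One uniform draw from the k-th equal slice of [0, total).
        value = rng.uniform(k * segment, (k + 1) * segment)
        # First position whose prefix sum exceeds value; zero-priority items are skipped.
        idx = bisect.bisect_right(prefix, value)
        indices.append(min(idx, n - 1))  # guard against round-off at the top end
    # P(i) = p_i / total, w_i = (N * P(i)) ** -beta, normalized by the batch maximum.
    weights = [(n * priorities[i] / total) ** (-beta) for i in indices]
    max_w = max(weights)
    return indices, [w / max_w for w in weights]
```

The returned indices correspond to the 'indices' entry and the normalized weights to the 'weights' entry described above. Normalizing by the largest weight in the batch is one common choice; normalizing by the maximum over the whole buffer is another.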

update_priorities(indices: Tensor, td_errors: Tensor) → None

Update the priorities of the sampled transitions based on TD errors.

Parameters:
  • indices (torch.Tensor) – Indices of the transitions to update.

  • td_errors (torch.Tensor) – TD errors for the transitions.
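A common way to turn TD errors into priorities is the proportional rule p_i = (|δ_i| + ε)^α, where a small constant ε keeps a transition with zero TD error sampleable. The sketch below assumes that rule and also tracks the largest stored priority, following the usual PER convention of giving new transitions max_priority so each is replayed at least once. The function name and the eps default are hypothetical, it takes plain Python sequences (tensors could be converted first with torch.Tensor.tolist()), and the actual update_priorities may compute priorities differently.

```python
from typing import List, Sequence


def update_priorities(priorities: List[float], indices: Sequence[int],
                      td_errors: Sequence[float], alpha: float,
                      max_priority: float, eps: float = 1e-6) -> float:
    """Hypothetical helper: overwrite priorities in place and return the new max priority.

    Assumes the proportional rule p_i = (|delta_i| + eps) ** alpha.
    """
    for i, delta in zip(indices, td_errors):
        # eps keeps a zero TD error from making the transition unsampleable.
        priorities[i] = (abs(delta) + eps) ** alpha
        # Track the largest stored priority so new transitions can start at it.
        max_priority = max(max_priority, priorities[i])
    return max_priority
```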

class prt_rl.common.buffers.ReplayBuffer(capacity: int, device: device = device(type='cpu'))

A circular replay buffer that overwrites old experiences when full.

Parameters:
  • capacity (int) – The maximum number of experiences to store.

  • device (torch.device) – The device to store the buffer on (default: CPU).
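The overwrite-when-full behaviour of a circular buffer comes down to a write pointer that wraps modulo capacity and a size counter that saturates at capacity. This is a hypothetical stand-alone sketch (CircularStore is not part of prt_rl) that stores one dict per slot, whereas ReplayBuffer works with torch.Tensor data on the configured device.

```python
from typing import Any, Dict, List, Optional


class CircularStore:
    """Hypothetical sketch of overwrite-when-full storage; not prt_rl's code."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.slots: List[Optional[Dict[str, Any]]] = [None] * capacity
        self.pos = 0   # slot the next add() writes to
        self.size = 0  # number of slots holding valid data

    def add(self, experience: Dict[str, Any]) -> None:
        self.slots[self.pos] = experience              # overwrites the oldest entry once full
        self.pos = (self.pos + 1) % self.capacity      # wrap the write pointer
        self.size = min(self.size + 1, self.capacity)  # size saturates at capacity
```

After five additions to a store of capacity 3, the slots hold steps 3, 4, and 2, and the write pointer sits on slot 2, the oldest remaining entry.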

add(experience: Dict[str, Tensor]) → None

Adds a new transition to the replay buffer.

Parameters:

experience (Dict[str, torch.Tensor]) – A dictionary containing the transition data.

clear() → None

Clears the replay buffer, resetting its state.

get_batches(batch_size: int)

Yields shuffled mini-batches from the buffer.

get_metadata() → Dict | None

Returns the metadata stored with the replay buffer, if any.

get_size() → int

Returns the current number of elements in the replay buffer.

Returns:

The current size of the replay buffer.

Return type:

int

classmethod load(path: str | Path, device: str = 'cpu') → ReplayBuffer

Loads a replay buffer from a file.

Parameters:
  • path (str | Path) – Path to the saved buffer file.

  • device (str) – Device to load the buffer to. Defaults to “cpu”.

Returns:

A ReplayBuffer instance with restored data.

Return type:

ReplayBuffer

resize(new_capacity: int)

Expands the buffer to a new capacity while preserving existing data.

Parameters:

new_capacity (int) – The new buffer capacity.
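Once a circular buffer has wrapped, its oldest entry sits at the write pointer rather than at slot 0. One way to grow such a buffer is to unroll the stored entries from oldest to newest into the larger array, so the write pointer simply points at the first free slot and the result looks like a buffer that was filled without wrapping. The helper below is a hypothetical sketch of that design choice; whether resize reorders entries this way, and whether it rejects a smaller capacity as this sketch does, is an assumption.

```python
from typing import Any, List, Optional, Tuple


def resize_circular(slots: List[Optional[Any]], pos: int, size: int,
                    new_capacity: int) -> Tuple[List[Optional[Any]], int, int]:
    """Hypothetical helper: grow circular storage, keeping entries oldest-to-newest.

    Returns (new_slots, new_pos, new_size).
    """
    capacity = len(slots)
    if new_capacity < capacity:
        # Assumption: only growth is supported, matching "Expands" in the docstring.
        raise ValueError("new_capacity must be >= current capacity")
    # Once the buffer has wrapped, the oldest entry sits at the write pointer.
    start = pos if size == capacity else 0
    ordered = [slots[(start + i) % capacity] for i in range(size)]
    new_slots = ordered + [None] * (new_capacity - size)
    # Next write goes to the first free slot (or slot 0 if the new array is already full).
    return new_slots, size % new_capacity, size
```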

sample(batch_size: int) → Dict[str, Tensor]

Samples a batch of transitions from the replay buffer.

Parameters:

batch_size (int) – The number of samples to draw from the buffer.

Returns:

A dictionary containing the sampled transitions.

Return type:

Dict[str, torch.Tensor]

save(path: str | Path) → None

Saves the replay buffer to a file.

Parameters:

path (str | Path) – Path to the file where the buffer will be saved.

set_metadata(metadata: Dict) → None

Sets the metadata for the replay buffer.

Parameters:

metadata (Dict) – A dictionary containing metadata information.

class prt_rl.common.buffers.RolloutBuffer(capacity: int, device: str = 'cpu')
Parameters:
  • capacity – Max number of transitions the buffer can store.

  • device – Torch device.

add(experience: Dict[str, Tensor]) → None

Adds a new experience to the buffer.

Parameters:

experience (Dict[str, torch.Tensor]) – A dictionary containing the experience data.

clear() → None

Clears the buffer, resetting its size and position.

get_batches(batch_size: int)

Yields mini-batches in random order. The final batch may be smaller.

After iteration, if drop_after_get=True, the buffer is cleared.
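Random-order mini-batching of this kind amounts to shuffling the indices once and slicing them into consecutive chunks, so every stored transition appears in exactly one batch and the last chunk holds the remainder (N mod batch_size items when that is nonzero). This hypothetical helper operates on a dict of Python lists rather than tensors and does not model the drop_after_get clearing step.

```python
import random
from typing import Dict, Iterator, List


def iterate_minibatches(data: Dict[str, List[float]], batch_size: int,
                        rng: random.Random) -> Iterator[Dict[str, List[float]]]:
    """Hypothetical helper: yield shuffled mini-batches covering every item once."""
    n = len(next(iter(data.values())))  # all keys are assumed to have equal length
    order = list(range(n))
    rng.shuffle(order)  # one shuffle per pass over the data
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]  # the final slice may be shorter
        yield {key: [values[i] for i in idx] for key, values in data.items()}
```

With 10 stored transitions and batch_size=4, one pass yields batches of 4, 4, and 2 items.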

get_size() → int

Returns the current number of elements in the buffer.

Returns:

The current size of the buffer.

Return type:

int

sample(batch_size: int) → Dict[str, Tensor]

Samples a batch of experiences from the buffer.

Parameters:

batch_size (int) – The number of samples to draw from the buffer.

Returns:

A dictionary containing the sampled experiences.

Return type:

Dict[str, torch.Tensor]

class prt_rl.common.buffers.SumTree(capacity: int)

A binary sum tree for efficient sampling of elements proportional to their priority.

The SumTree is a binary tree where each parent node stores the sum of its child nodes. It’s particularly useful for:

  • Efficient sampling from a discrete probability distribution:
    • You can sample an index i proportional to a weight p_i in O(log N) time.

  • Efficient dynamic updates of weights (e.g., priority values) while maintaining cumulative structure.

Key Applications:

  • Prioritized Experience Replay (PER): Sample transitions with probability ∝ priority.

  • Importance sampling in any algorithm that draws samples in proportion to non-uniform weights that change over time.

  • Event scheduling in simulations, where some events occur more frequently than others.

Variables:
  • capacity (int) – Maximum number of elements the tree can hold.

  • tree (np.ndarray) – Binary tree storing priorities.

  • data_pointer (int) – Current position at which the next priority is inserted.
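A sum tree is often laid out as a flat array of 2·capacity − 1 nodes: internal nodes occupy indices [0, capacity − 1), leaves occupy [capacity − 1, 2·capacity − 1), and node i has children 2i + 1 and 2i + 2. The sketch below assumes that layout, which is consistent with get_leaf returning both a tree index and a data index. SumTreeSketch is a hypothetical stand-in that uses a Python list instead of the documented np.ndarray and only mirrors the method names add, update, total_priority, and get_leaf.

```python
from typing import List, Tuple


class SumTreeSketch:
    """Hypothetical array-backed sum tree; mirrors the documented method names only."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        # Internal nodes at [0, capacity - 1), leaves at [capacity - 1, 2 * capacity - 1).
        self.tree: List[float] = [0.0] * (2 * capacity - 1)
        self.data_pointer = 0

    def add(self, priority: float) -> None:
        tree_idx = self.data_pointer + self.capacity - 1  # leaf for the current data slot
        self.update(tree_idx, priority)
        # Wrap around so the oldest leaf is overwritten once the tree is full.
        self.data_pointer = (self.data_pointer + 1) % self.capacity

    def update(self, tree_idx: int, priority: float) -> None:
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:  # propagate the difference up to the root
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def total_priority(self) -> float:
        return self.tree[0]  # the root holds the sum of all leaves

    def get_leaf(self, value: float) -> Tuple[int, float, int]:
        """Descend from the root; value should lie in [0, total_priority())."""
        idx = 0
        while True:
            left = 2 * idx + 1
            if left >= len(self.tree):  # no children: idx is a leaf
                break
            if value < self.tree[left]:
                idx = left
            else:
                value -= self.tree[left]  # skip the left subtree's mass
                idx = left + 1
        return idx, self.tree[idx], idx - self.capacity + 1
```

For example, with leaf priorities 1, 2, 3, 4 the root holds 10, and get_leaf(3.5) passes the left subtree's sum of 3 and returns the leaf holding priority 3 (tree index 5, data index 2). Each update and each lookup touches one node per level, which is where the O(log N) cost comes from.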

add(priority: float) → None

Add a new priority to the sum tree.

Parameters:

priority (float) – Priority value to insert.

get_leaf(value: float) → Tuple[int, float, int]

Traverse the tree to find the leaf node corresponding to the sample value.

Parameters:

value (float) – A sample value in [0, total_priority).

Returns:

(tree index, priority, data index)

Return type:

Tuple[int, float, int]

total_priority() → float

Returns the sum of all priorities.

Returns:

Total priority.

Return type:

float

update(tree_idx: int, priority: float) → None

Update the priority at a specific tree index and propagate the change.

Parameters:
  • tree_idx (int) – Index in the tree to update.

  • priority (float) – New priority value.