collectors
Collectors gather experience from environments using the provided policy/agent.
Classes
Collector – The Parallel Collector collects experience from multiple environments in parallel.
MetricsTracker – Tracks collection metrics and logs ONLY when episodes finish.
- class prt_rl.common.collectors.Collector(env: EnvironmentInterface, logger: Logger | None = None, flatten: bool = True)
The Parallel Collector collects experience from multiple environments in parallel.
The collector can return either a fixed number of environment steps (collect_experience) or a fixed number of complete trajectories (collect_trajectory). When collecting a fixed number of steps, any environment that finishes an episode before the step budget is reached is reset, and collection continues.
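A minimal sketch of the reset-on-done behaviour described above, in plain Python. `CountdownEnv` and `collect_steps` are hypothetical illustrations, not part of prt_rl, and the `(next_state, reward, done)` return of `step` is an assumption made for brevity.

```python
from dataclasses import dataclass


@dataclass
class CountdownEnv:
    """Hypothetical toy environment whose episodes last `episode_len` steps."""

    episode_len: int
    t: int = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int) -> tuple[int, float, bool]:
        self.t += 1
        return self.t, 1.0, self.t >= self.episode_len


def collect_steps(envs: list[CountdownEnv], steps_per_env: int):
    """Step every env `steps_per_env` times, resetting any env whose episode ends.

    Returns a time-major nested list: transitions[t][n] = (state, next_state, done).
    """
    states = [env.reset() for env in envs]
    transitions = []
    for _ in range(steps_per_env):
        row = []
        for n, env in enumerate(envs):
            next_state, _reward, done = env.step(action=0)  # dummy action
            row.append((states[n], next_state, done))
            # Auto-reset: a finished env starts a new episode instead of idling.
            states[n] = env.reset() if done else next_state
        transitions.append(row)
    return transitions


data = collect_steps([CountdownEnv(2), CountdownEnv(3)], steps_per_env=4)
print([row[0][2] for row in data])  # [False, True, False, True]
```

The shorter episode in environment 0 ends twice within the four steps, yet the call still returns exactly four rows per environment.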
Note
Do not use collect_trajectory with an environment that never terminates (i.e., done is never True): the call will never return. Use collect_experience instead.
- Parameters:
env (EnvironmentInterface) – The environment to collect experience from.
logger (Logger | None) – Optional logger for logging information. Defaults to a new Logger instance.
flatten (bool) – Whether to flatten the collected experience. If flattened, the output shape is (N*T, …); if not, it is (T, N, …). Defaults to True.
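A sketch of what flattening does to the leading axes, assuming the time-major (T, N, …) layout documented for collect_experience. `flatten_time_major` and `unflatten_time_major` are hypothetical helpers over nested lists; for a torch tensor `x`, `x.reshape(-1, *x.shape[2:])` performs the same merge of the first two axes.

```python
def flatten_time_major(batch):
    """Merge the leading (T, N) axes of a nested list into one N*T axis.

    Row t*N + n of the result holds the item for time step t and env n.
    """
    return [item for row in batch for item in row]


def unflatten_time_major(flat, num_envs):
    """Inverse of flatten_time_major: split N*T rows back into T rows of N."""
    return [flat[i:i + num_envs] for i in range(0, len(flat), num_envs)]


# T=3 time steps, N=2 environments; each entry is labelled (t, n).
batch = [[(t, n) for n in range(2)] for t in range(3)]
flat = flatten_time_major(batch)
print(len(flat), flat[3])  # 6 (1, 1)
```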
- collect_experience(policy: Policy, num_steps: int = 1, bootstrap: bool = True) → Dict[str, Tensor]
Collects the given number of experiences from the environment using the provided policy.
The experiences are collected across all environments, so each environment is stepped T = ceil(num_steps / N) times, where N is the number of environments. N*T transitions are returned, which can exceed num_steps when num_steps is not a multiple of N. The output shape is (T, N, …) if not flattened, or (N*T, …) if flattened.
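The per-environment step count can be checked with integer arithmetic. `steps_per_env` is a hypothetical helper; it shows that requesting 10 steps from 4 environments gives T = 3 and therefore 12 transitions, and that the default num_steps=1 still steps every environment once.

```python
def steps_per_env(num_steps: int, num_envs: int) -> int:
    """T = ceil(num_steps / N), computed with integer floor division."""
    return -(-num_steps // num_envs)


T = steps_per_env(10, 4)
print(T, T * 4)  # 3 12
```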
- Parameters:
policy (Policy) – A policy that implements the Policy interface.
num_steps (int) – The number of steps to collect experience for. Defaults to 1.
bootstrap (bool) – Whether to compute the last value estimate V(s_{T+1}) for bootstrapping if the last step is not done and the policy provides value estimates. Defaults to True.
- Returns:
- A dictionary containing the collected experience with keys:
‘state’: The states collected. Shape (T, N, …), or (N*T, …) if flattened.
‘action’: The actions taken. Shape (T, N, …), or (N*T, …) if flattened.
‘next_state’: The next states after taking the actions. Shape (T, N, …), or (N*T, …) if flattened.
‘reward’: The rewards received. Shape (T, N, 1), or (N*T, 1) if flattened.
‘done’: The done flags indicating if the episode has ended. Shape (T, N, 1), or (N*T, 1) if flattened.
All keys from the policy’s info dictionary (e.g., ‘value’, ‘log_prob’, etc.)
‘last_value_est’ (optional): The last value estimate for bootstrapping, if applicable. Shape (N, 1).
- Return type:
Dict[str, torch.Tensor]
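To illustrate what the bootstrap value is for, here is a minimal sketch of a discounted return computed backward over one environment's T steps and seeded with an estimate of V(s_{T+1}). `discounted_returns` is hypothetical and not part of prt_rl; it assumes one scalar reward and done flag per step, as in the ‘reward’ and ‘done’ entries above.

```python
def discounted_returns(rewards, dones, last_value_est, gamma=0.99):
    """Return G_t = r_t + gamma * G_{t+1} for one environment's T steps.

    The recursion is seeded with last_value_est (an estimate of V(s_{T+1})).
    A done flag cuts the recursion, so nothing is bootstrapped across an
    episode boundary; if the final step is done, last_value_est is unused.
    """
    returns = [0.0] * len(rewards)
    running = last_value_est
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running = 0.0  # episode ended at step t: no future value
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


print(discounted_returns([1.0, 1.0, 1.0], [False, False, False], 10.0, gamma=0.5))
# [3.0, 4.0, 6.0]
```

Without the bootstrap term, a rollout cut off mid-episode would be treated as if all future reward were zero.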
- collect_trajectory(policy: Policy, num_trajectories: int | None = None, min_num_steps: int | None = None) → Dict[str, Tensor]
Collects full trajectories in parallel from the environment using the provided policy.
If num_trajectories equals the number of environments N, one trajectory is collected from each environment. If it is smaller than N, one trajectory is collected from each of the first num_trajectories environments. If it is larger than N, num_trajectories // N trajectories are collected from every environment, and the remaining num_trajectories % N trajectories come from whichever environments finish first.
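The deterministic part of that allocation rule can be written as a small function. `trajectory_quotas` is a hypothetical sketch, not the library's implementation; which environments supply the remainder depends on episode timing and is not modeled.

```python
def trajectory_quotas(num_trajectories: int, num_envs: int):
    """Split a trajectory budget across environments.

    Returns (per_env, remainder): per_env[i] trajectories are guaranteed from
    env i, and `remainder` extra ones go to whichever envs finish first.
    """
    if num_trajectories <= num_envs:
        # One trajectory from each of the first `num_trajectories` envs.
        per_env = [1 if i < num_trajectories else 0 for i in range(num_envs)]
        return per_env, 0
    base, remainder = divmod(num_trajectories, num_envs)
    return [base] * num_envs, remainder


print(trajectory_quotas(10, 4))  # ([2, 2, 2, 2], 2)
```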
The output is a dictionary with keys (state, action, next_state, reward, done), plus any keys from the policy’s info dictionary. Each value is a tensor whose leading dimension is B, the total number of timesteps summed over all collected trajectories.
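A sketch of how variable-length trajectories end up on one leading axis of size B. `concat_trajectories` is hypothetical and works on plain lists; with torch tensors the per-key equivalent would be `torch.cat(..., dim=0)`. Assuming each collected trajectory ends with done = True, the done flags mark where one trajectory stops and the next begins.

```python
def concat_trajectories(trajectories):
    """Stack trajectories (dicts of equal-length lists) along one axis of size B."""
    keys = trajectories[0].keys()
    return {k: [x for traj in trajectories for x in traj[k]] for k in keys}


traj_a = {"reward": [1.0, 0.0, 2.0], "done": [False, False, True]}  # T = 3
traj_b = {"reward": [0.5, 1.5], "done": [False, True]}              # T = 2
batch = concat_trajectories([traj_a, traj_b])

B = len(batch["reward"])  # 3 + 2 = 5
ends = [i for i, d in enumerate(batch["done"]) if d]
print(B, ends)  # 5 [2, 4]
```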
- Parameters:
policy (Policy | None) – The policy or agent to use.
num_trajectories (int | None) – The total number of complete trajectories to collect.
min_num_steps (int | None) – The minimum number of steps to collect. If specified, collection continues until at least this many steps have been gathered, and then the last trajectory is run to completion.
- Returns:
- A dictionary containing the collected experience with keys:
state: The current state of the environment. Shape (B, state_dim)
action: The action taken by the policy. Shape (B, action_dim)
next_state: The next state after taking the action. Shape (B, state_dim)
reward: The reward received from the environment. Shape (B, 1)
done: The done flag indicating if the episode has ended. Shape (B, 1)
All keys from the policy’s info dictionary (e.g., ‘value’, ‘log_prob’, etc.)
- Return type:
Dict[str, torch.Tensor]
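A single-environment sketch of the min_num_steps stopping rule: keep collecting whole episodes until the step total reaches the minimum, so the final trajectory is completed rather than truncated. `trajectories_for_min_steps` is hypothetical, and the exact stopping behaviour with several parallel environments is not modeled here.

```python
def trajectories_for_min_steps(episode_lengths, min_num_steps):
    """Return (num_trajectories, total_steps) gathered under the rule above."""
    total = 0
    for count, length in enumerate(episode_lengths, start=1):
        total += length  # the whole episode is kept, never cut short
        if total >= min_num_steps:
            return count, total
    raise ValueError("ran out of episodes before reaching min_num_steps")


print(trajectories_for_min_steps([4, 5, 3, 6], min_num_steps=10))  # (3, 12)
```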
- get_metric_tracker() → MetricsTracker
Returns the internal MetricsTracker instance for accessing collection metrics.
- Returns:
The internal MetricsTracker instance.
- Return type:
MetricsTracker
- class prt_rl.common.collectors.MetricsTracker(num_envs: int, logger: Logger | None = None)
Tracks collection metrics and logs ONLY when episodes finish. Counts are in environment steps: one vectorized step across N environments adds N to the count.
Note
This class is designed to be used with single or vectorized environments. If several environments report done on the same step, one episode reward is logged per finished environment, all with the same environment-step value.
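A minimal sketch of the counting and logging behaviour described above. `EpisodeRewardTracker` is a hypothetical stand-in, not the MetricsTracker implementation; it records (env_step, episode_reward) pairs in a list instead of calling a Logger.

```python
class EpisodeRewardTracker:
    """Hypothetical tracker: counts env-steps and logs rewards on episode end."""

    def __init__(self, num_envs: int):
        self.num_envs = num_envs
        self.env_steps = 0
        self.running_reward = [0.0] * num_envs
        self.logged = []  # (env_step, episode_reward) pairs

    def update(self, rewards, dones):
        # One vectorized step advances the counter by N env-steps.
        self.env_steps += self.num_envs
        for n in range(self.num_envs):
            self.running_reward[n] += rewards[n]
            if dones[n]:
                # Log only when this env's episode ends, then start over.
                self.logged.append((self.env_steps, self.running_reward[n]))
                self.running_reward[n] = 0.0


tracker = EpisodeRewardTracker(num_envs=2)
tracker.update([1.0, 1.0], [False, False])
tracker.update([1.0, 2.0], [True, True])
print(tracker.logged)  # [(4, 2.0), (4, 3.0)]
```

Both environments finish on the same vectorized step, so both entries carry the same step value of 4.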
- Parameters: