interface
Classes
- class prt_rl.common.components.heads.interface.DistributionHead(*args: Any, **kwargs: Any)
Interface for distribution heads that output a distribution over actions.
- __init__(*args: Any, **kwargs: Any) -> None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- entropy(latent: Tensor) -> Tensor
Compute the entropy of the distribution defined by the latent representation.
- Parameters:
latent (Tensor) – The input to the head, typically the output of a backbone network.
- Returns:
The entropy of the distribution.
- Return type:
Tensor
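As a rough illustration of what an entropy implementation might compute, the sketch below assumes a discrete action space and a hypothetical linear layer (`logits_layer`) that maps the latent to action logits. Neither assumption comes from the interface itself, which leaves the distribution family to the concrete head. `torch.distributions.Categorical` is standard PyTorch.

```python
import math

import torch
from torch import nn
from torch.distributions import Categorical

# Hypothetical setup: an 8-dim latent from a backbone and 4 discrete actions.
latent_dim, num_actions = 8, 4
logits_layer = nn.Linear(latent_dim, num_actions)

# Zero the parameters so every action receives the same logit (uniform policy).
nn.init.zeros_(logits_layer.weight)
nn.init.zeros_(logits_layer.bias)

latent = torch.randn(3, latent_dim)              # batch of 3 latent vectors
dist = Categorical(logits=logits_layer(latent))
entropy = dist.entropy()                         # one entropy value per batch item

# A uniform distribution over n actions has entropy log(n).
uniform_entropy = math.log(num_actions)
```

The uniform case is the maximum-entropy policy for a discrete head, which makes it a convenient sanity check for an implementation.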
- log_prob(latent: Tensor, action: Tensor) -> Tensor
Compute the log probability of a given action under the distribution defined by the latent representation.
- Parameters:
latent (Tensor) – The input to the head, typically the output of a backbone network.
action (Tensor) – The action for which to compute the log probability.
- Returns:
The log probability of the given action.
- Return type:
Tensor
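For a continuous action space, one common way to realize log_prob is a diagonal Gaussian. The sketch below is an assumption for illustration only: `mean_layer`, the state-independent `log_std`, and the tensor shapes are hypothetical and are not taken from this library.

```python
import math

import torch
from torch import nn
from torch.distributions import Normal

# Hypothetical setup: 8-dim latent, 3 continuous action dimensions.
latent_dim, action_dim = 8, 3
mean_layer = nn.Linear(latent_dim, action_dim)
log_std = torch.zeros(action_dim)                # state-independent log std (assumed)

latent = torch.randn(2, latent_dim)              # batch of 2 latent vectors
action = torch.randn(2, action_dim)              # actions to score

mean = mean_layer(latent)
std = log_std.exp()
dist = Normal(mean, std)

# Normal.log_prob is evaluated per dimension; summing over the last axis gives
# the joint log probability of the action vector (dimensions are independent).
log_prob = dist.log_prob(action).sum(dim=-1)     # shape (2,)

# The same quantity written out from the Gaussian density, for comparison.
manual = (
    -0.5 * ((action - mean) / std) ** 2 - log_std - 0.5 * math.log(2 * math.pi)
).sum(dim=-1)
```

Returning one value per batch item, rather than one per action dimension, is an assumption here; check the concrete head for the shape it actually returns.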
- abstractmethod sample(latent: Tensor, deterministic: bool = False) -> Tuple[Tensor, Tensor, Tensor]
Sample an action from the distribution given the latent representation.
- Parameters:
latent (Tensor) – The input to the head, typically the output of a backbone network.
deterministic (bool) – If True, return a deterministic action (e.g., the distribution mean) instead of drawing a random sample. Defaults to False.
- Returns:
A tuple (action, log_prob, entropy): the sampled action, the log probability of the sampled action, and the entropy of the distribution.
- Return type:
Tuple[Tensor, Tensor, Tensor]
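To show how the three methods fit together, here is a minimal sketch of a head for discrete actions. It is not the library's implementation: `CategoricalHeadSketch`, its constructor arguments, and the output shapes are hypothetical, and it subclasses `nn.Module` directly so that it runs without prt_rl installed. A real head would presumably subclass `DistributionHead` instead.

```python
from typing import Tuple

import torch
from torch import Tensor, nn
from torch.distributions import Categorical


class CategoricalHeadSketch(nn.Module):
    """Hypothetical head mirroring the sample/log_prob/entropy contract."""

    def __init__(self, latent_dim: int, num_actions: int) -> None:
        super().__init__()
        # Assumed design: a single linear layer maps the latent to logits.
        self.logits_layer = nn.Linear(latent_dim, num_actions)

    def _dist(self, latent: Tensor) -> Categorical:
        return Categorical(logits=self.logits_layer(latent))

    def sample(
        self, latent: Tensor, deterministic: bool = False
    ) -> Tuple[Tensor, Tensor, Tensor]:
        dist = self._dist(latent)
        # Deterministic mode takes the most likely action (the mode); a
        # continuous head would typically take the mean instead.
        action = dist.probs.argmax(dim=-1) if deterministic else dist.sample()
        return action, dist.log_prob(action), dist.entropy()

    def log_prob(self, latent: Tensor, action: Tensor) -> Tensor:
        return self._dist(latent).log_prob(action)

    def entropy(self, latent: Tensor) -> Tensor:
        return self._dist(latent).entropy()


head = CategoricalHeadSketch(latent_dim=8, num_actions=4)
latent = torch.randn(5, 8)                       # batch of 5 latent vectors
action, log_prob, entropy = head.sample(latent, deterministic=True)
```

Building the distribution in one private helper keeps `sample`, `log_prob`, and `entropy` consistent with each other, so the log probability returned by `sample` matches what `log_prob` reports for the same latent and action.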