interface#

Classes#

class prt_rl.common.components.heads.interface.DistributionHead(*args: Any, **kwargs: Any)#

Interface for distribution heads that output a distribution over actions.

__init__(*args: Any, **kwargs: Any) → None#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

entropy(latent: Tensor) → Tensor#

Compute the entropy of the distribution defined by the latent representation.

Parameters:

latent (Tensor) – The input to the head, typically the output of a backbone network.

Returns:

The entropy of the distribution.

Return type:

Tensor
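As a rough illustration of what an implementation of `entropy` might compute, the sketch below assumes a categorical head whose linear layer maps the latent to action logits. The names `CategoricalEntropySketch`, `latent_dim`, `num_actions`, and `logits_layer` are hypothetical and not part of `prt_rl`; the class subclasses `nn.Module` directly rather than `DistributionHead` so that it runs with PyTorch alone.

```python
import math

import torch
from torch import nn
from torch.distributions import Categorical


class CategoricalEntropySketch(nn.Module):
    """Hypothetical head (not from prt_rl): latent -> logits -> entropy."""

    def __init__(self, latent_dim: int, num_actions: int) -> None:
        super().__init__()
        self.logits_layer = nn.Linear(latent_dim, num_actions)

    def entropy(self, latent: torch.Tensor) -> torch.Tensor:
        # One entropy value per batch row, shape (batch,).
        logits = self.logits_layer(latent)
        return Categorical(logits=logits).entropy()


head = CategoricalEntropySketch(latent_dim=8, num_actions=4)
# Zero the layer so every action gets the same logit (a uniform distribution).
nn.init.zeros_(head.logits_layer.weight)
nn.init.zeros_(head.logits_layer.bias)

latent = torch.randn(5, 8)  # stand-in for a backbone output
entropy = head.entropy(latent)
```

With all logits equal, the distribution over the four actions is uniform, so each row's entropy is `math.log(4)`, the maximum for four outcomes.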

log_prob(latent: Tensor, action: Tensor) → Tensor#

Compute the log probability of a given action under the distribution defined by the latent representation.

Parameters:
  • latent (Tensor) – The input to the head, typically the output of a backbone network.

  • action (Tensor) – The action for which to compute the log probability.

Returns:

The log probability of the given action.

Return type:

Tensor
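For a continuous action space, `log_prob` could be sketched as follows, assuming a diagonal Gaussian whose mean comes from the latent and whose standard deviation is a learned, state-independent parameter. The class name `GaussianLogProbSketch`, its attributes, and the `(batch, 1)` output shape are assumptions made for illustration; the interface above does not specify them.

```python
import math

import torch
from torch import nn
from torch.distributions import Normal


class GaussianLogProbSketch(nn.Module):
    """Hypothetical head (not from prt_rl): latent -> mean, learned log std."""

    def __init__(self, latent_dim: int, action_dim: int) -> None:
        super().__init__()
        self.mean_layer = nn.Linear(latent_dim, action_dim)
        # State-independent log standard deviation, initialised to 0 (std = 1).
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def log_prob(self, latent: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        dist = Normal(self.mean_layer(latent), self.log_std.exp())
        # Sum the per-dimension log densities to get one value per batch row.
        return dist.log_prob(action).sum(dim=-1, keepdim=True)


head = GaussianLogProbSketch(latent_dim=8, action_dim=2)
latent = torch.randn(5, 8)  # stand-in for a backbone output
with torch.no_grad():
    action = head.mean_layer(latent)  # evaluate the density at its own mean
    log_prob = head.log_prob(latent, action)
```

Evaluating at the mean with unit standard deviation gives `-0.5 * log(2π)` per dimension, so each row here equals `-math.log(2 * math.pi)` for the two action dimensions.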

abstractmethod sample(latent: Tensor, deterministic: bool = False) → Tuple[Tensor, Tensor, Tensor]#

Sample an action from the distribution given the latent representation.

Parameters:
  • latent (Tensor) – The input to the head, typically the output of a backbone network.

  • deterministic (bool) – If True, return a deterministic action (e.g., the distribution mean) instead of drawing a random sample. Defaults to False.

Returns:
  • action (Tensor) – The sampled action.

  • log_prob (Tensor) – The log probability of the sampled action.

  • entropy (Tensor) – The entropy of the distribution.

Return type:

Tuple[Tensor, Tensor, Tensor]
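Because `sample` is abstract, each concrete head has to supply it. The sketch below shows one way the three return values might be produced for a discrete action space. `CategoricalHeadSketch` is a hypothetical stand-in that subclasses `nn.Module` directly so the example runs without `prt_rl` installed; a real head would presumably subclass `DistributionHead`. Taking the most probable action in deterministic mode is an assumption, chosen as the categorical analogue of taking the mean.

```python
from typing import Tuple

import torch
from torch import nn
from torch.distributions import Categorical


class CategoricalHeadSketch(nn.Module):
    """Hypothetical stand-in for a DistributionHead subclass (not from prt_rl)."""

    def __init__(self, latent_dim: int, num_actions: int) -> None:
        super().__init__()
        self.logits_layer = nn.Linear(latent_dim, num_actions)

    def sample(
        self, latent: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        dist = Categorical(logits=self.logits_layer(latent))
        if deterministic:
            # Assumption: the most probable action plays the role of "the mean".
            action = dist.probs.argmax(dim=-1)
        else:
            action = dist.sample()
        return action, dist.log_prob(action), dist.entropy()


torch.manual_seed(0)
head = CategoricalHeadSketch(latent_dim=8, num_actions=4)
latent = torch.randn(5, 8)  # stand-in for a backbone output

action, log_prob, entropy = head.sample(latent)
greedy_a, _, _ = head.sample(latent, deterministic=True)
greedy_b, _, _ = head.sample(latent, deterministic=True)
```

In this sketch, repeated deterministic calls on the same latent return the same actions, while stochastic calls may differ from one call to the next.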