# Source code for prt_rl.common.components.heads.interface
from abc import ABC, abstractmethod
import torch
from torch import nn, Tensor
from typing import Tuple
class DistributionHead(ABC, nn.Module):
    """
    Interface for distribution heads that output a distribution over actions.

    Only ``sample`` is declared abstract, so a concrete subclass must override
    it before it can be instantiated. ``log_prob`` and ``entropy`` are regular
    methods whose base implementations raise ``NotImplementedError``; a
    subclass that does not override them can still be constructed, but calling
    them fails at runtime.
    """

    @abstractmethod
    def sample(self, latent: Tensor, deterministic: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Sample an action from the distribution given the latent representation.

        Args:
            latent (Tensor): The input to the head, typically the output of a backbone network.
            deterministic (bool): Whether to sample deterministically (e.g., take the mean) or stochastically.

        Returns:
            action (Tensor): The sampled action.
            log_prob (Tensor): The log probability of the sampled action.
            entropy (Tensor): The entropy of the distribution.

        Raises:
            NotImplementedError: If the base implementation is invoked directly
                (for example through ``super().sample(...)``).
        """
        raise NotImplementedError("The sample method must be implemented by subclasses.")

    def log_prob(self, latent: Tensor, action: Tensor) -> Tensor:
        """
        Compute the log probability of a given action under the distribution defined by the latent representation.

        Args:
            latent (Tensor): The input to the head, typically the output of a backbone network.
            action (Tensor): The action for which to compute the log probability.

        Returns:
            log_prob (Tensor): The log probability of the given action.

        Raises:
            NotImplementedError: Always, in this base implementation; subclasses
                are expected to override the method.
        """
        raise NotImplementedError("The log_prob method must be implemented by subclasses.")

    def entropy(self, latent: Tensor) -> Tensor:
        """
        Compute the entropy of the distribution defined by the latent representation.

        Args:
            latent (Tensor): The input to the head, typically the output of a backbone network.

        Returns:
            entropy (Tensor): The entropy of the distribution.

        Raises:
            NotImplementedError: Always, in this base implementation; subclasses
                are expected to override the method.
        """
        raise NotImplementedError("The entropy method must be implemented by subclasses.")