import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple, Dict, Optional
import prt_rl.common.decision_functions as dfcn
from prt_rl.common.components.heads.interface import DistributionHead
class CategoricalHead(DistributionHead):
    """
    Categorical actor head for discrete action spaces.

    A single linear layer maps the latent state to unnormalized logits over
    ``num_actions`` actions, which parameterize
    ``torch.distributions.Categorical(logits=...)``.

    Contract:
        - sample(latent, deterministic) -> (action, log_prob, entropy)
            action:   (B, 1) int64 action indices
            log_prob: (B, 1) float
            entropy:  (B, 1) float
        - log_prob(latent, action) -> (B, 1); action may be (B,) or (B, 1)
        - entropy(latent) -> (B, 1)
        - get_logits(latent) -> (B, num_actions)

    Notes:
        - The deterministic action is the argmax over the logits.

    Args:
        latent_dim (int): Size of the latent state representation.
        num_actions (int): Number of discrete actions (must be > 1).

    Raises:
        ValueError: If num_actions <= 1.
    """

    def __init__(self, latent_dim: int, num_actions: int) -> None:
        super().__init__()
        if num_actions <= 1:
            raise ValueError(f"num_actions must be > 1, got {num_actions}")
        self.n_actions = int(num_actions)
        self.logits = nn.Linear(latent_dim, num_actions)

    def _dist(self, latent: Tensor) -> torch.distributions.Categorical:
        """Build the Categorical distribution for a batch of latent states."""
        logits = self.logits(latent)  # (B, n_actions)
        return torch.distributions.Categorical(logits=logits)

    def sample(self, latent: Tensor, deterministic: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Pick one action per batch element, stochastically or greedily.

        Args:
            latent (Tensor): Latent state representation of shape (B, latent_dim).
            deterministic (bool): If True, take the argmax over the logits
                instead of sampling.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: (action (B, 1) int64,
            log_prob (B, 1), entropy (B, 1)).
        """
        dist = self._dist(latent)
        if deterministic:
            action = dist.logits.argmax(dim=-1)  # (B,)
        else:
            action = dist.sample()  # (B,)
        # Evaluate log_prob on the (B,) indices. Passing a (B, 1) tensor makes
        # Categorical.log_prob broadcast it against the (B,) batch -> (B, B).
        log_prob = dist.log_prob(action).unsqueeze(-1)  # (B, 1)
        entropy = dist.entropy().unsqueeze(-1)  # (B, 1)
        return action.unsqueeze(-1), log_prob, entropy

    def log_prob(self, latent: Tensor, action: Tensor) -> Tensor:
        """
        Log-probability of given actions.

        Args:
            latent (Tensor): Latent state representation of shape (B, latent_dim).
            action (Tensor): Action indices of shape (B,) (dtype long) or (B, 1);
                a trailing singleton dimension is squeezed away.

        Returns:
            Tensor: Log-probabilities of shape (B, 1).
        """
        if action.ndim == 2 and action.shape[-1] == 1:
            action = action.squeeze(-1)
        dist = self._dist(latent)
        return dist.log_prob(action).unsqueeze(-1)  # (B, 1)

    def entropy(self, latent: Tensor) -> Tensor:
        """
        Entropy of the action distribution.

        Args:
            latent (Tensor): Latent state representation of shape (B, latent_dim).

        Returns:
            Tensor: Entropy of shape (B, 1).
        """
        dist = self._dist(latent)
        return dist.entropy().unsqueeze(-1)  # (B, 1)

    def get_logits(self, latent: Tensor) -> Tensor:
        """
        Get the raw logits output by the head for inspection or auxiliary losses.

        Args:
            latent (Tensor): Latent state representation of shape (B, latent_dim).

        Returns:
            Tensor: Logits of shape (B, num_actions).
        """
        return self.logits(latent)
class ContinuousHead(nn.Module):
    """
    Head that emits raw continuous actions.

    Intended for deterministic-policy algorithms, where the policy network
    produces action values directly instead of parameterizing a distribution.

    Notes:
        - A single linear layer with no activation, so outputs are unbounded.
        - For an input of shape (B, latent_dim) the output is (B, action_dim),
          with B the batch size and action_dim the action-space dimensionality.

    Args:
        latent_dim (int): Size of the latent state representation.
        action_dim (int): Dimensionality of the action space.
    """

    def __init__(self, latent_dim: int, action_dim: int):
        super().__init__()
        self.action_layer = nn.Linear(latent_dim, action_dim)

    def forward(self, latent: Tensor) -> Tensor:
        """
        Map a latent state to a raw continuous action.

        Args:
            latent (Tensor): Latent state representation of shape (B, latent_dim).

        Returns:
            Tensor: Raw action values of shape (B, action_dim).
        """
        raw_action = self.action_layer(latent)
        return raw_action
class DecisionHead(nn.Module):
    """
    Decision head for discrete action spaces.

    A linear layer (no activation) maps the latent state to one score per
    action (``qvals``, shape (B, action_dim)). A decision function then turns
    those scores into actions. No distribution is sampled, so this head suits
    value-based / deterministic-policy algorithms.

    Args:
        latent_dim (int): Size of the latent state representation.
        action_dim (int): Number of discrete actions.
        decision_function (dfcn.DecisionFunction, optional): Used by ``sample``
            when ``deterministic`` is False. If omitted (or None), each head
            gets its own freshly constructed ``dfcn.Greedy()``.
    """

    def __init__(self,
                 latent_dim: int,
                 action_dim: int,
                 *,
                 decision_function: Optional[dfcn.DecisionFunction] = None,
                 ) -> None:
        super().__init__()
        self.qval_layer = nn.Linear(latent_dim, action_dim)
        # Build the default per instance: a `dfcn.Greedy()` default in the
        # signature is evaluated once and would be shared by every DecisionHead.
        if decision_function is None:
            decision_function = dfcn.Greedy()
        self.decision_function = decision_function

    def sample(self, latent: Tensor, deterministic: bool = False) -> Tensor:
        """
        Select actions from the per-action scores.

        Args:
            latent (Tensor): Latent state representation of shape (B, latent_dim).
            deterministic (bool): If True, bypass the decision function and take
                the argmax over the scores.

        Returns:
            Tensor: Selected actions. In deterministic mode this is the argmax
            with keepdim=True, i.e. shape (B, 1). Otherwise it is whatever
            ``decision_function.select_action`` returns.
            NOTE(review): confirm that shape matches the deterministic branch.
        """
        qvals = self.qval_layer(latent)
        if deterministic:
            # Deterministic action is argmax over the scores.
            return qvals.argmax(dim=-1, keepdim=True)
        return self.decision_function.select_action(qvals)
class GaussianHead(DistributionHead):
    """
    Diagonal Gaussian policy head for continuous action spaces.

    The mean is a linear function of the latent state. The log standard
    deviation is a single learned vector shared by all states
    (state-independent), clamped to [min_log_std, max_log_std] before use.

    Contract:
        - sample(latent, deterministic) -> (action, log_prob, entropy)
            action:   (B, act_dim)
            log_prob: (B, 1), summed over action dimensions
            entropy:  (B, 1), summed over action dimensions
        - log_prob(latent, action) -> (B, 1)
        - entropy(latent) -> (B, 1)

    Notes:
        - Built on torch.distributions.Normal(mean, std). Action dimensions are
          treated as independent, so per-dimension terms are summed inside the
          head and callers never have to.
        - With use_rsample=True, stochastic samples are reparameterized so
          gradients can flow through the action.
        - For a state-dependent std, swap the log_std Parameter for another
          Linear layer.

    Raises:
        ValueError: If action_dim <= 0.
    """

    def __init__(
        self,
        latent_dim: int,
        action_dim: int,
        *,
        log_std_init: float = -0.5,
        min_log_std: float = -20.0,
        max_log_std: float = 2.0,
        use_rsample: bool = False,
    ) -> None:
        super().__init__()
        if action_dim <= 0:
            raise ValueError(f"action_dim must be > 0, got {action_dim}")
        self.action_dim = int(action_dim)
        self.mean = nn.Linear(latent_dim, action_dim)
        # One learned log-std per action dimension, independent of the state.
        self.log_std = nn.Parameter(torch.full((action_dim,), float(log_std_init)))
        self.min_log_std = float(min_log_std)
        self.max_log_std = float(max_log_std)
        self.use_rsample = bool(use_rsample)

    @staticmethod
    def _sum_over_action_dims(per_dim: Tensor) -> Tensor:
        """Collapse per-dimension terms into a (B, 1) column."""
        return per_dim.sum(dim=-1, keepdim=True)

    def _dist(self, latent: Tensor) -> torch.distributions.Normal:
        """Build the Normal distribution for a batch of latent states."""
        mu = self.mean(latent)  # (B, act_dim)
        bounded_log_std = self.log_std.expand_as(mu).clamp(self.min_log_std, self.max_log_std)
        return torch.distributions.Normal(mu, bounded_log_std.exp())

    def sample(self, latent: Tensor, deterministic: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Produce an action together with its log-probability and the entropy.

        Args:
            latent (Tensor): Latent state representation of shape (B, latent_dim).
            deterministic (bool): If True, return the distribution mean.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: (action (B, act_dim),
            log_prob (B, 1), entropy (B, 1)).
        """
        dist = self._dist(latent)
        if deterministic:
            chosen = dist.mean
        elif self.use_rsample:
            chosen = dist.rsample()
        else:
            chosen = dist.sample()
        log_prob = self._sum_over_action_dims(dist.log_prob(chosen))
        entropy = self._sum_over_action_dims(dist.entropy())
        return chosen, log_prob, entropy

    def log_prob(self, latent: Tensor, action: Tensor) -> Tensor:
        """
        Joint log-probability of ``action`` (shape (B, act_dim)), as (B, 1).
        """
        return self._sum_over_action_dims(self._dist(latent).log_prob(action))

    def entropy(self, latent: Tensor) -> Tensor:
        """
        Joint entropy of the action distribution, as (B, 1).
        """
        return self._sum_over_action_dims(self._dist(latent).entropy())
class TanhGaussianHead(nn.Module):
    """
    Squashed (tanh) diagonal Gaussian actor head, commonly used in SAC.

    It parameterizes a Normal distribution in R^act_dim and samples with
    rsample(). It then squashes via tanh to (-1, 1), and can optionally
    scale/shift to env bounds.

    Key detail: log_prob must include the tanh "change of variables" correction:
        log pi(a) = log N(u; mu, std) - sum log(1 - tanh(u)^2)
        where a = tanh(u)
    Here the correction is evaluated as sum log(1 - a^2 + epsilon).

    Optional scaling to [low, high]:
        a_env = scale * a + bias,  scale = (high - low) / 2,  bias = (high + low) / 2
    adds another correction term:
        log pi(a_env) = log pi(a) - sum log(scale)

    API:
        - sample(latent, deterministic=False) -> (action, log_prob, info)
        - log_prob(latent, action) -> log_prob (B, 1)  [action is the final env-scaled action]
        - entropy_proxy(latent, action) -> -log_prob (B, 1)
        - Entropy is not analytic after tanh; SAC typically uses -log_prob as
          an entropy proxy.

    Notes:
        - Uses state-dependent mean and log_std (two linear layers). log_std is
          clamped to [log_std_min, log_std_max].
        - If you want state-independent log_std, replace log_std_layer with nn.Parameter.
        - Inputs may be batched (B, latent_dim) or unbatched (latent_dim,).
          Log-probs are then (B, 1) or (1,) respectively.

    Raises:
        ValueError: If action_dim <= 0, if only one of action_low/action_high is
            given, if their shapes are not (action_dim,), or if
            action_high <= action_low in any dimension.
    """

    def __init__(
        self,
        latent_dim: int,
        action_dim: int,
        *,
        log_std_min: float = -20.0,
        log_std_max: float = 2.0,
        epsilon: float = 1e-6,
        action_low: Optional[Tensor] = None,
        action_high: Optional[Tensor] = None,
    ) -> None:
        super().__init__()
        if action_dim <= 0:
            raise ValueError(f"action_dim must be > 0, got {action_dim}")
        self.action_dim = int(action_dim)
        self.log_std_min = float(log_std_min)
        self.log_std_max = float(log_std_max)
        self.epsilon = float(epsilon)
        self.mean_layer = nn.Linear(latent_dim, action_dim)
        self.log_std_layer = nn.Linear(latent_dim, action_dim)
        # Optional action rescaling to env bounds.
        # Register as buffers so .to(device) moves them, and they are saved in state_dict.
        if action_low is not None or action_high is not None:
            if action_low is None or action_high is None:
                raise ValueError("Provide both action_low and action_high or neither.")
            if action_low.shape != (action_dim,) or action_high.shape != (action_dim,):
                raise ValueError(
                    f"Expected action_low/high shape ({action_dim},), "
                    f"got {tuple(action_low.shape)} and {tuple(action_high.shape)}"
                )
            # scale = (high - low) / 2 must be positive, or log(scale) is NaN/-inf.
            if not bool(torch.all(action_high > action_low)):
                raise ValueError("action_high must be strictly greater than action_low in every dimension.")
            self.register_buffer("action_low", action_low.clone().detach())
            self.register_buffer("action_high", action_high.clone().detach())
        else:
            self.action_low = None
            self.action_high = None

    def _mean_log_std(self, latent: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (mean, clamped log_std), each shaped like (B, act_dim)."""
        mean = self.mean_layer(latent)
        log_std = torch.clamp(self.log_std_layer(latent), self.log_std_min, self.log_std_max)
        return mean, log_std

    def _base_dist(self, latent: Tensor) -> torch.distributions.Normal:
        """Pre-squash Normal distribution over u."""
        mean, log_std = self._mean_log_std(latent)
        return torch.distributions.Normal(mean, log_std.exp())

    def _squash(self, u: Tensor) -> Tensor:
        return torch.tanh(u)

    def _unsquash(self, a: Tensor) -> Tensor:
        """
        Inverse tanh (atanh). Input a should be in (-1,1).
        We clamp for numerical safety.
        """
        a = torch.clamp(a, -1.0 + self.epsilon, 1.0 - self.epsilon)
        return 0.5 * (torch.log1p(a) - torch.log1p(-a))  # atanh

    def _tanh_log_det(self, a: Tensor) -> Tensor:
        """sum log(1 - a^2 + epsilon) over action dims, keepdim -> (..., 1)."""
        return torch.log(1.0 - a.pow(2) + self.epsilon).sum(dim=-1, keepdim=True)

    @staticmethod
    def _log_scale_term(scale: Tensor, ref: Tensor) -> Tensor:
        """
        sum log(scale), broadcast to ref's leading (batch) dims plus a trailing 1.
        Gives (B, 1) for 2-D ref and (1,) for unbatched 1-D ref.
        """
        return torch.log(scale).sum().expand(*ref.shape[:-1], 1)

    def _apply_action_bounds(self, a: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Maps a in (-1,1) to env bounds, if provided.
        Returns (a_env, log_abs_det_jacobian_scale), where the latter is
        (B, 1) (or (1,) unbatched), or None without bounds.
        """
        if self.action_low is None:
            return a, None
        scale = (self.action_high - self.action_low) / 2.0
        bias = (self.action_high + self.action_low) / 2.0
        a_env = a * scale + bias
        # log|det d a_env / d a| = sum log(scale): constant, broadcast to batch.
        return a_env, self._log_scale_term(scale, a)

    def _remove_action_bounds(self, a_env: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Inverse of _apply_action_bounds. Maps an env-scaled action back to (-1,1).
        Returns (a, log_scale), where log_scale matches the forward mapping.
        """
        if self.action_low is None:
            return a_env, None
        scale = (self.action_high - self.action_low) / 2.0
        bias = (self.action_high + self.action_low) / 2.0
        a = (a_env - bias) / scale
        return a, self._log_scale_term(scale, a_env)

    def sample(
        self,
        latent: Tensor,
        deterministic: bool = False,
    ) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]:
        """
        Returns:
            action: (B, act_dim) final action (env-scaled if bounds provided)
            log_prob: (B, 1) log pi(action)
            info: dict containing:
                - "pre_tanh": u (B, act_dim)
                - "tanh": a in (-1,1) (B, act_dim)
                - "mean": mean (B, act_dim)
                - "log_std": clamped log_std (B, act_dim)
        """
        mean, log_std = self._mean_log_std(latent)
        dist = torch.distributions.Normal(mean, log_std.exp())
        # Deterministic: squash the mean; otherwise reparameterized sample.
        u = mean if deterministic else dist.rsample()
        a = self._squash(u)  # (-1,1)
        # log N(u) - sum log(1 - tanh(u)^2), in squashed space before scaling.
        log_prob = dist.log_prob(u).sum(dim=-1, keepdim=True) - self._tanh_log_det(a)
        action, log_scale = self._apply_action_bounds(a)
        if log_scale is not None:
            log_prob = log_prob - log_scale
        info = {"pre_tanh": u, "tanh": a, "mean": mean, "log_std": log_std}
        return action, log_prob, info

    def log_prob(self, latent: Tensor, action: Tensor) -> Tensor:
        """
        Computes log pi(action) for a given *final* action (env-scaled if bounds were provided).

        action:
            - if bounds are None: expected in (-1,1)
            - if bounds are provided: expected in [low, high]
        Actions at or slightly beyond the bounds are clamped just inside
        (-1, 1), so the result stays finite.
        """
        a, log_scale = self._remove_action_bounds(action)
        # Clamp once and use the same value for atanh AND the Jacobian term.
        # With an unclamped `a`, |a| > 1 gives log(negative) = NaN.
        a = torch.clamp(a, -1.0 + self.epsilon, 1.0 - self.epsilon)
        u = self._unsquash(a)
        dist = self._base_dist(latent)
        log_prob = dist.log_prob(u).sum(dim=-1, keepdim=True) - self._tanh_log_det(a)
        if log_scale is not None:
            log_prob = log_prob - log_scale
        return log_prob

    def entropy_proxy(self, latent: Tensor, action: Tensor) -> Tensor:
        """
        SAC typically uses -log_prob as an entropy term (up to expectations).
        This returns (B,1) = -log pi(a).
        You might prefer calling this `neg_log_prob` depending on your style.
        """
        return -self.log_prob(latent, action)
class BetaHead(DistributionHead):
    """
    Beta actor head using torch.distributions.Beta.

    Produces independent Beta(alpha_i, beta_i) per action dimension. It samples
    in (0, 1) and optionally maps the sample affinely onto
    [action_low, action_high].

    API:
        - sample(latent, deterministic=False) -> (action, log_prob, entropy)
        - log_prob(latent, action) -> (B, 1)  [action is the final env-scaled action if bounds provided]
        - entropy(latent) -> (B, 1)

    Notes:
        - Beta is defined on (0, 1). Actions are clamped to
          [epsilon, 1 - epsilon] in that space for numerical stability.
        - alpha/beta = softplus(linear) + concentration_offset, floored at
          min_concentration.
        - If bounds are provided, log_prob includes the change-of-variables
          term -sum(log(action_high - action_low)).
        - entropy() (and the entropy returned by sample()) is the entropy of
          the (0, 1) Beta distribution. It does NOT include the sum(log(scale))
          term that the bounded action distribution would add.
        - Inputs may be batched (B, latent_dim) or unbatched (latent_dim,).

    Raises:
        ValueError: If action_dim <= 0, if only one of action_low/action_high is
            given, if their shapes are not (action_dim,), or if
            action_high <= action_low in any dimension.
    """

    def __init__(
        self,
        latent_dim: int,
        action_dim: int,
        *,
        concentration_offset: float = 1.0,  # encourages alpha,beta >= 1 initially
        min_concentration: float = 1e-3,  # numerical floor
        epsilon: float = 1e-6,  # clamp for actions near {0,1}
        action_low: Optional[Tensor] = None,
        action_high: Optional[Tensor] = None,
    ) -> None:
        super().__init__()
        if action_dim <= 0:
            raise ValueError(f"action_dim must be > 0, got {action_dim}")
        self.action_dim = int(action_dim)
        self.concentration_offset = float(concentration_offset)
        self.min_concentration = float(min_concentration)
        self.epsilon = float(epsilon)
        self.alpha_layer = nn.Linear(latent_dim, action_dim)
        self.beta_layer = nn.Linear(latent_dim, action_dim)
        self.softplus = nn.Softplus()
        # Optional bounds, registered as buffers so they follow .to(device) and state_dict.
        if action_low is not None or action_high is not None:
            if action_low is None or action_high is None:
                raise ValueError("Provide both action_low and action_high or neither.")
            if action_low.shape != (action_dim,) or action_high.shape != (action_dim,):
                raise ValueError(
                    f"Expected action_low/high shape ({action_dim},), "
                    f"got {tuple(action_low.shape)} and {tuple(action_high.shape)}"
                )
            # scale = high - low must be positive, or log(scale) is NaN/-inf.
            if not bool(torch.all(action_high > action_low)):
                raise ValueError("action_high must be strictly greater than action_low in every dimension.")
            self.register_buffer("action_low", action_low.clone().detach())
            self.register_buffer("action_high", action_high.clone().detach())
        else:
            self.action_low = None
            self.action_high = None

    def _concentrations(self, latent: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (alpha, beta), each (B, act_dim) and >= min_concentration."""
        # softplus -> (0, inf). offset helps avoid extreme U-shapes early in training.
        alpha = self.softplus(self.alpha_layer(latent)) + self.concentration_offset
        beta = self.softplus(self.beta_layer(latent)) + self.concentration_offset
        # hard floor for numerical safety
        alpha = torch.clamp(alpha, min=self.min_concentration)
        beta = torch.clamp(beta, min=self.min_concentration)
        return alpha, beta

    def _dist(self, latent: Tensor) -> torch.distributions.Beta:
        alpha, beta = self._concentrations(latent)
        # torch.distributions.Beta uses concentration1 (alpha) and concentration0 (beta)
        return torch.distributions.Beta(concentration1=alpha, concentration0=beta)

    @staticmethod
    def _log_scale_term(scale: Tensor, ref: Tensor) -> Tensor:
        """
        sum log(scale), broadcast to ref's leading (batch) dims plus a trailing 1.
        Gives (B, 1) for 2-D ref and (1,) for unbatched 1-D ref.
        """
        return torch.log(scale).sum().expand(*ref.shape[:-1], 1)

    def _apply_bounds(self, a01: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Map from (0,1) to [low, high] if bounds provided.
        Returns (action_env, log_scale), where log_scale = sum log(scale) with
        shape (B, 1) (or (1,) unbatched), or None without bounds.
        """
        if self.action_low is None:
            return a01, None
        scale = self.action_high - self.action_low  # (act_dim,)
        action = self.action_low + a01 * scale
        return action, self._log_scale_term(scale, a01)

    def _remove_bounds(self, action_env: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Inverse of _apply_bounds.
        Returns (a01, log_scale), where log_scale matches the forward mapping.
        """
        if self.action_low is None:
            return action_env, None
        scale = self.action_high - self.action_low
        a01 = (action_env - self.action_low) / scale
        return a01, self._log_scale_term(scale, action_env)

    def sample(self, latent: Tensor, deterministic: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Returns:
            (action, log_prob, entropy). action is (B, act_dim), env-scaled if
            bounds were provided. log_prob (B, 1) includes the bounds
            correction. entropy (B, 1) is the (0, 1)-space Beta entropy.
        """
        dist = self._dist(latent)
        if deterministic:
            # Mean of Beta(alpha,beta) = alpha/(alpha+beta)
            alpha = dist.concentration1
            beta = dist.concentration0
            a01 = alpha / (alpha + beta)
        else:
            a01 = dist.sample()
        # Keep away from boundaries for stable log_prob
        a01 = torch.clamp(a01, self.epsilon, 1.0 - self.epsilon)
        log_prob_01 = dist.log_prob(a01).sum(dim=-1, keepdim=True)  # (B,1)
        entropy_01 = dist.entropy().sum(dim=-1, keepdim=True)  # (B,1)
        action, log_scale = self._apply_bounds(a01)
        log_prob = log_prob_01 - log_scale if log_scale is not None else log_prob_01
        return action, log_prob, entropy_01

    def log_prob(self, latent: Tensor, action: Tensor) -> Tensor:
        """
        Log-probability (B, 1) of a final action (env-scaled if bounds were provided).
        """
        a01, log_scale = self._remove_bounds(action)
        a01 = torch.clamp(a01, self.epsilon, 1.0 - self.epsilon)
        dist = self._dist(latent)
        log_prob_01 = dist.log_prob(a01).sum(dim=-1, keepdim=True)
        return log_prob_01 - log_scale if log_scale is not None else log_prob_01

    def entropy(self, latent: Tensor) -> Tensor:
        """
        Entropy (B, 1) of the (0, 1)-space Beta distribution, summed over action dims.
        """
        dist = self._dist(latent)
        return dist.entropy().sum(dim=-1, keepdim=True)
class ValueHead(nn.Module):
    """
    State-value head V(s).

    A single linear layer maps a latent state representation (typically the
    output of an encoder and/or backbone) to one scalar value estimate per
    state. It is used by actor-critic methods such as PPO, A2C and A3C.

    Notes:
        - No output activation: value functions are generally unbounded.
        - For an input of shape (B, latent_dim) the output is (B, 1).

    Args:
        latent_dim (int): Size of the latent state representation.
    """

    def __init__(self, latent_dim: int):
        super().__init__()
        self.v = nn.Linear(latent_dim, 1)

    def forward(self, latent: Tensor) -> Tensor:
        """
        Estimate the value of each state in the batch.

        Args:
            latent (Tensor): Latent state representation of shape (B, latent_dim).

        Returns:
            Tensor: State values of shape (B, 1).
        """
        state_value = self.v(latent)
        return state_value
class QValueHead(nn.Module):
    """
    State-action value function head.

    Implements a scalar Q-function Q(s, a). It maps a latent state
    representation and a continuous action vector to a single value estimate
    per state-action pair.

    This head is typically used by off-policy algorithms such as SAC and TD3,
    where the critic evaluates specific actions rather than all actions at once.

    Notes:
        - The latent state and action are concatenated along the last (feature)
          dimension. Batched (B, ·) inputs, unbatched 1-D inputs and extra
          leading dimensions are all supported.
        - The output layer is linear (no activation), as Q-values are unbounded.
        - This implementation assumes continuous action spaces. For discrete
          action spaces (e.g., DQN), a different head that outputs Q(s, ·) is used.

    Args:
        latent_dim (int): Dimension of the latent state representation.
        action_dim (int): Dimension of the action vector.
    """

    def __init__(self, latent_dim: int, action_dim: int):
        super().__init__()
        self.q = nn.Linear(latent_dim + action_dim, 1)

    def forward(self, latent: Tensor, action: Tensor) -> Tensor:
        """
        Compute the state-action value estimate.

        Args:
            latent (Tensor): Latent state representation of shape (B, latent_dim).
            action (Tensor): Action tensor of shape (B, action_dim).

        Returns:
            Tensor: Q-value estimates of shape (B, 1).
        """
        # Concatenate on the feature (last) axis. This is identical to dim=1 for
        # 2-D inputs and also works for unbatched or higher-rank inputs.
        q_input = torch.cat([latent, action], dim=-1)
        return self.q(q_input)
class DuelingHead(nn.Module):
    """
    Dueling network head for discrete action spaces.

    Implements the dueling architecture, where separate linear streams
    estimate the state-value function V(s) and the advantage function A(s, a).
    The Q-values are computed as:
        Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')

    This head is commonly used in DQN variants to improve value estimation.
    The advantage mean is taken over the last (action) dimension, so both
    batched (B, latent_dim) and unbatched (latent_dim,) inputs work.

    Args:
        latent_dim (int): Dimension of the latent state representation.
        num_actions (int): Number of discrete actions (must be > 1).

    Raises:
        ValueError: If num_actions <= 1.
    """

    def __init__(self, latent_dim: int, num_actions: int):
        super().__init__()
        if num_actions <= 1:
            raise ValueError(f"num_actions must be > 1, got {num_actions}")
        self.num_actions = int(num_actions)
        self.value_stream = nn.Linear(latent_dim, 1)
        self.advantage_stream = nn.Linear(latent_dim, num_actions)

    def forward(self, latent: Tensor) -> Tensor:
        """
        Compute the Q-value estimates for all actions.

        Args:
            latent (Tensor): Latent state representation of shape (B, latent_dim).

        Returns:
            Tensor: Q-value estimates of shape (B, num_actions).
        """
        V = self.value_stream(latent)  # (B, 1)
        A = self.advantage_stream(latent)  # (B, num_actions)
        # Mean over the action axis (last dim); identical to dim=1 for 2-D input.
        A_mean = A.mean(dim=-1, keepdim=True)  # (B, 1)
        return V + A - A_mean  # (B, num_actions)