heads
Classes
BetaHead – Beta actor head using torch.distributions.Beta.
CategoricalHead – Categorical actor head for discrete action spaces.
ContinuousHead – Continuous action head that outputs raw actions for continuous action spaces.
DecisionHead – Decision head for discrete action spaces that outputs raw logits and applies a decision function.
DuelingHead – Dueling network head for discrete action spaces.
GaussianHead – Diagonal Gaussian actor head for continuous action spaces.
QValueHead – State–action value function head.
TanhGaussianHead – Squashed (tanh) diagonal Gaussian actor head, commonly used in SAC.
ValueHead – State-value function head.
- class prt_rl.common.components.heads.heads.BetaHead(latent_dim: int, action_dim: int, *, concentration_offset: float = 1.0, min_concentration: float = 0.001, epsilon: float = 1e-06, action_low: Tensor | None = None, action_high: Tensor | None = None)
Beta actor head using torch.distributions.Beta.
Produces independent Beta(alpha_i, beta_i) per action dimension. Samples in (0,1) and optionally maps to [action_low, action_high].
- API:
sample(latent, deterministic=False) -> (action, log_prob, entropy)
log_prob(latent, action) -> (B,1) [action is final env-scaled action if bounds provided]
entropy(latent) -> (B,1)
Notes
Beta is defined on (0,1). We clamp actions for numerical stability.
If bounds are provided, we apply a change-of-variables correction to log_prob.
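The change-of-variables correction mentioned in the notes can be illustrated with plain torch.distributions calls. This is a minimal sketch, not the BetaHead implementation: the concentration values, the bounds, and the names x, scale, and x_back are assumptions chosen for illustration.

```python
import torch
from torch.distributions import Beta

torch.manual_seed(0)

B, action_dim = 4, 2
low = torch.tensor([-2.0, 0.0])    # hypothetical action_low
high = torch.tensor([2.0, 10.0])   # hypothetical action_high
eps = 1e-6

# Hypothetical positive concentration parameters, shape (B, action_dim).
alpha = torch.rand(B, action_dim) + 1.0
beta = torch.rand(B, action_dim) + 1.0
dist = Beta(alpha, beta)

# Sample in (0, 1) and clamp away from the endpoints for numerical stability.
x = dist.sample().clamp(eps, 1.0 - eps)

# Affine map to [low, high].
scale = high - low
action = low + scale * x

# Density in the unit interval, summed over action dimensions: (B, 1).
log_prob_unit = dist.log_prob(x).sum(dim=-1, keepdim=True)

# Change of variables: p(a) = p(x) / prod(scale), so subtract sum(log(scale)).
log_prob_env = log_prob_unit - torch.log(scale).sum()

# Evaluating a stored env-scaled action requires mapping it back to (0, 1) first.
x_back = ((action - low) / scale).clamp(eps, 1.0 - eps)
```

Because the map is affine, the correction is a constant that depends only on the bounds, which is why it is only needed when bounds are provided.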
- entropy(latent: Tensor) -> Tensor
Compute the entropy of the distribution defined by the latent representation.
- Parameters:
latent (Tensor) – The input to the head, typically the output of a backbone network.
- Returns:
The entropy of the distribution.
- Return type:
Tensor
- log_prob(latent: Tensor, action: Tensor) -> Tensor
Compute the log probability of a given action under the distribution defined by the latent representation.
- Parameters:
latent (Tensor) – The input to the head, typically the output of a backbone network.
action (Tensor) – The action for which to compute the log probability.
- Returns:
The log probability of the given action.
- Return type:
Tensor
- sample(latent: Tensor, deterministic: bool = False) -> Tuple[Tensor, Tensor, Tensor]
Sample an action from the distribution given the latent representation.
- Parameters:
latent (Tensor) – The input to the head, typically the output of a backbone network.
deterministic (bool) – Whether to sample deterministically (e.g., take the mean) or stochastically.
- Returns:
action (Tensor) – The sampled action.
log_prob (Tensor) – The log probability of the sampled action.
entropy (Tensor) – The entropy of the distribution.
- Return type:
Tuple[Tensor, Tensor, Tensor]
- class prt_rl.common.components.heads.heads.CategoricalHead(latent_dim: int, num_actions: int)
Categorical actor head for discrete action spaces.
- Contract:
- sample(latent, deterministic) -> (action, log_prob, entropy)
action: (B,) int64
log_prob: (B,1) float
entropy: (B,1) float
log_prob(latent, action) -> (B,1)
entropy(latent) -> (B,1)
Notes
Uses torch.distributions.Categorical(logits=…)
Deterministic action is argmax over logits.
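The contract above can be sketched directly with torch.distributions.Categorical. This is an illustrative sketch under the shapes stated in the contract, not the CategoricalHead source; the random logits stand in for the head's linear output and the free function sample is a hypothetical helper.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

B, num_actions = 5, 3
# Stand-in for the logits a linear layer would produce from the latent.
logits = torch.randn(B, num_actions)
dist = Categorical(logits=logits)


def sample(deterministic: bool = False):
    """Return (action, log_prob, entropy) with shapes (B,), (B, 1), (B, 1)."""
    # Deterministic mode takes the argmax over logits instead of sampling.
    action = logits.argmax(dim=-1) if deterministic else dist.sample()
    log_prob = dist.log_prob(action).unsqueeze(-1)
    entropy = dist.entropy().unsqueeze(-1)
    return action, log_prob, entropy


action, log_prob, entropy = sample(deterministic=True)
stochastic_action, _, _ = sample(deterministic=False)
```

Keeping log_prob and entropy at (B, 1) lets them broadcast against value estimates of the same shape without extra reshaping in the caller.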
- entropy(latent: Tensor) -> Tensor
Compute the entropy of the distribution defined by the latent representation.
- Parameters:
latent (Tensor) – The input to the head, typically the output of a backbone network.
- Returns:
The entropy of the distribution.
- Return type:
Tensor
- get_logits(latent: Tensor) -> Tensor
Get the raw logits output by the head for inspection or auxiliary losses.
- Parameters:
latent (Tensor) – Latent state representation of shape (B, latent_dim).
- Returns:
Logits of shape (B, num_actions).
- Return type:
Tensor
- log_prob(latent: Tensor, action: Tensor) -> Tensor
Compute the log probability of the given action. The action is expected to have shape (B,) with dtype long, or shape (B,1), which is squeezed to (B,).
- sample(latent: Tensor, deterministic: bool = False) -> Tuple[Tensor, Tensor, Tensor]
Sample an action from the distribution given the latent representation.
- Parameters:
latent (Tensor) – The input to the head, typically the output of a backbone network.
deterministic (bool) – Whether to act deterministically (take the argmax over logits) or sample stochastically.
- Returns:
action (Tensor) – The sampled action.
log_prob (Tensor) – The log probability of the sampled action.
entropy (Tensor) – The entropy of the distribution.
- Return type:
Tuple[Tensor, Tensor, Tensor]
- class prt_rl.common.components.heads.heads.ContinuousHead(latent_dim: int, action_dim: int)
Continuous action head that outputs raw actions for continuous action spaces.
This head is typically used in deterministic policy algorithms where the policy directly outputs continuous action values without sampling from a distribution.
Notes
The output layer is linear (no activation), producing raw action values.
The expected output shape is (B, action_dim), where B is the batch size and action_dim is the dimensionality of the action space.
- class prt_rl.common.components.heads.heads.DecisionHead(latent_dim: int, action_dim: int, *, decision_function: DecisionFunction = <prt_rl.common.decision_functions.Greedy object>)
Decision head for discrete action spaces that outputs raw logits and applies a decision function.
This head is typically used in deterministic policy algorithms where the policy directly outputs logits for discrete actions without sampling from a distribution.
Notes
The output layer is linear (no activation), producing raw logits.
The expected output shape is (B, action_dim), where B is the batch size and action_dim is the number of discrete actions.
- class prt_rl.common.components.heads.heads.DuelingHead(latent_dim: int, num_actions: int)
Dueling network head for discrete action spaces.
Implements the dueling architecture where separate streams estimate the state-value function V(s) and the advantage function A(s, a). The Q-values are computed as:
Q(s, a) = V(s) + A(s, a) - mean_{a'} A(s, a')
This head is commonly used in DQN variants to improve value estimation.
- Parameters:
latent_dim (int) – Dimension of the latent state representation.
num_actions (int) – Number of discrete actions.
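The aggregation formula above can be sketched as a small module. This is a minimal illustration, assuming one linear layer per stream; the class name DuelingSketch and its layer names are hypothetical and the actual DuelingHead may use a different stream architecture.

```python
import torch
import torch.nn as nn


class DuelingSketch(nn.Module):
    """Illustrative dueling head: Q = V + A - mean(A)."""

    def __init__(self, latent_dim: int, num_actions: int):
        super().__init__()
        self.value = nn.Linear(latent_dim, 1)                # V(s) stream
        self.advantage = nn.Linear(latent_dim, num_actions)  # A(s, a) stream

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        v = self.value(latent)       # (B, 1)
        a = self.advantage(latent)   # (B, num_actions)
        # Subtracting the mean advantage makes V and A identifiable.
        return v + a - a.mean(dim=-1, keepdim=True)


torch.manual_seed(0)
head = DuelingSketch(latent_dim=8, num_actions=4)
latent = torch.randn(6, 8)
q_values = head(latent)  # (6, 4)
```

A consequence of the mean subtraction is that the average of Q(s, ·) over actions equals V(s), which is a convenient sanity check.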
- class prt_rl.common.components.heads.heads.GaussianHead(latent_dim: int, action_dim: int, *, log_std_init: float = -0.5, min_log_std: float = -20.0, max_log_std: float = 2.0, use_rsample: bool = False)
Diagonal Gaussian actor head for continuous action spaces.
- Parameterization:
mean = Linear(latent)
log_std is a learned parameter vector (state-independent)
- Contract:
- sample(latent, deterministic) -> (action, log_prob, entropy)
action: (B, act_dim)
log_prob: (B,1) summed over action dims
entropy: (B,1) summed over action dims
log_prob(latent, action) -> (B,1)
entropy(latent) -> (B,1)
Notes
Uses torch.distributions.Normal(mean, std) with diagonal independence.
We sum log_prob/entropy over action dims inside the head to keep callers DRY.
If you want state-dependent std, replace the Parameter with another Linear head.
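The parameterization and contract above can be sketched as follows. This is an illustrative sketch, not the GaussianHead source: the class name GaussianSketch is hypothetical, and clamping log_std to [min_log_std, max_log_std] is an assumption inferred from the constructor arguments.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal


class GaussianSketch(nn.Module):
    """Illustrative diagonal Gaussian head with a state-independent log_std."""

    def __init__(self, latent_dim: int, action_dim: int, log_std_init: float = -0.5,
                 min_log_std: float = -20.0, max_log_std: float = 2.0):
        super().__init__()
        self.mean_layer = nn.Linear(latent_dim, action_dim)
        # One learned log-std per action dimension, shared across all states.
        self.log_std = nn.Parameter(torch.full((action_dim,), log_std_init))
        self.min_log_std = min_log_std
        self.max_log_std = max_log_std

    def _dist(self, latent: torch.Tensor) -> Normal:
        mean = self.mean_layer(latent)
        # Assumption: log_std is clamped to the configured range before exp().
        log_std = self.log_std.clamp(self.min_log_std, self.max_log_std)
        return Normal(mean, log_std.exp().expand_as(mean))

    def sample(self, latent: torch.Tensor, deterministic: bool = False):
        dist = self._dist(latent)
        action = dist.mean if deterministic else dist.sample()
        # Sum over action dims so callers receive (B, 1) tensors.
        log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)
        entropy = dist.entropy().sum(dim=-1, keepdim=True)
        return action, log_prob, entropy


torch.manual_seed(0)
head = GaussianSketch(latent_dim=8, action_dim=3)
latent = torch.randn(5, 8)
action, log_prob, entropy = head.sample(latent, deterministic=True)
```

Since the standard deviation does not depend on the state here, the entropy is the same for every row of the batch.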
- entropy(latent: Tensor) -> Tensor
Compute the entropy of the distribution defined by the latent representation.
- Parameters:
latent (Tensor) – The input to the head, typically the output of a backbone network.
- Returns:
The entropy of the distribution.
- Return type:
Tensor
- sample(latent: Tensor, deterministic: bool = False) -> Tuple[Tensor, Tensor, Tensor]
Sample an action from the distribution given the latent representation.
- Parameters:
latent (Tensor) – The input to the head, typically the output of a backbone network.
deterministic (bool) – Whether to sample deterministically (e.g., take the mean) or stochastically.
- Returns:
action (Tensor) – The sampled action.
log_prob (Tensor) – The log probability of the sampled action.
entropy (Tensor) – The entropy of the distribution.
- Return type:
Tuple[Tensor, Tensor, Tensor]
- class prt_rl.common.components.heads.heads.QValueHead(latent_dim: int, action_dim: int)
State–action value function head.
Implements a scalar Q-function Q(s, a) that maps a latent state representation and a continuous action vector to a single value estimate per state–action pair.
This head is typically used by off-policy algorithms such as SAC and TD3, where the critic evaluates specific actions rather than all actions at once.
Notes
The latent state and action are concatenated along the feature dimension.
The output layer is linear (no activation), as Q-values are unbounded.
This implementation assumes continuous action spaces. For discrete action spaces (e.g., DQN), a different head that outputs Q(s, ·) is used.
- Parameters:
latent_dim (int) – Dimension of the latent state representation.
action_dim (int) – Dimension of the continuous action vector.
- forward(latent: Tensor, action: Tensor) -> Tensor
Compute the state–action value estimate.
- Parameters:
latent (Tensor) – Latent state representation of shape (B, latent_dim).
action (Tensor) – Action tensor of shape (B, action_dim).
- Returns:
Q-value estimates of shape (B, 1).
- Return type:
Tensor
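The concatenate-then-project behaviour described in the notes can be sketched as below. This is a minimal illustration assuming a single linear output layer; the class name QValueSketch is hypothetical and the real head may differ in depth.

```python
import torch
import torch.nn as nn


class QValueSketch(nn.Module):
    """Illustrative Q(s, a) head for continuous actions."""

    def __init__(self, latent_dim: int, action_dim: int):
        super().__init__()
        # Linear output with no activation, since Q-values are unbounded.
        self.q = nn.Linear(latent_dim + action_dim, 1)

    def forward(self, latent: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        # Concatenate state features and action along the feature dimension.
        return self.q(torch.cat([latent, action], dim=-1))


torch.manual_seed(0)
critic = QValueSketch(latent_dim=8, action_dim=2)
latent = torch.randn(4, 8)
action = torch.randn(4, 2)
q = critic(latent, action)  # (4, 1)
```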
- class prt_rl.common.components.heads.heads.TanhGaussianHead(latent_dim: int, action_dim: int, *, log_std_min: float = -20.0, log_std_max: float = 2.0, epsilon: float = 1e-06, action_low: Tensor | None = None, action_high: Tensor | None = None)
Squashed (tanh) diagonal Gaussian actor head, commonly used in SAC.
It parameterizes a Normal distribution in R^act_dim, samples with rsample(), then squashes via tanh to (-1, 1). Optionally scales/shifts to env bounds.
- Key detail: log_prob must include the tanh “change of variables” correction:
log pi(a) = log N(u; mu, std) - sum log(1 - tanh(u)^2) where a = tanh(u)
- Optional scaling to [low, high]:
a_env = scale * a + bias, scale=(high-low)/2, bias=(high+low)/2
- adds another correction term:
log pi(a_env) = log pi(a) - sum log(scale)
- API:
sample(latent, deterministic=False) -> (action, log_prob, info)
log_prob(latent, action) -> log_prob (B,1) [action is final env-scaled action]
entropy(…) is not analytic after tanh; SAC typically uses -log_prob as entropy term proxy.
Notes
Uses state-dependent mean and log_std by default (two linear layers).
If you want state-independent log_std, replace log_std_layer with nn.Parameter.
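The two log_prob corrections above (tanh squashing, then scaling to the environment bounds) can be worked through with plain tensors. This is a sketch under stated assumptions, not the TanhGaussianHead source: the mean, log_std, and bounds are made-up values, and adding eps inside the logarithm is an assumed use of the epsilon argument.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

B, act_dim = 4, 2
low = torch.tensor([-2.0, 0.0])    # hypothetical action_low
high = torch.tensor([2.0, 10.0])   # hypothetical action_high
scale = (high - low) / 2
bias = (high + low) / 2
eps = 1e-6

# Stand-ins for the outputs of the mean and log_std layers.
mean = torch.randn(B, act_dim)
log_std = torch.full((B, act_dim), -0.5)
dist = Normal(mean, log_std.exp())

u = dist.rsample()          # reparameterized pre-tanh sample
a = torch.tanh(u)           # squashed to (-1, 1)
a_env = scale * a + bias    # scaled and shifted to [low, high]

# log N(u; mu, std), summed over action dims: (B, 1).
log_prob = dist.log_prob(u).sum(dim=-1, keepdim=True)
# Tanh correction: subtract sum log(1 - tanh(u)^2).
log_prob = log_prob - torch.log(1 - a.pow(2) + eps).sum(dim=-1, keepdim=True)
# Scaling correction: subtract sum log(scale).
log_prob = log_prob - torch.log(scale).sum()

# SAC-style entropy term proxy.
entropy_proxy = -log_prob
```

Using rsample() keeps the sample differentiable with respect to mean and log_std, which is what allows the actor loss in SAC to backpropagate through the sampled action.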
- entropy_proxy(latent: Tensor, action: Tensor) -> Tensor
SAC typically uses -log_prob as an entropy term (up to expectations). This returns (B,1) = -log pi(a).
You might prefer calling this neg_log_prob depending on your style.
- log_prob(latent: Tensor, action: Tensor) -> Tensor
Computes log pi(action) for a given final action (env-scaled if bounds were provided).
- action:
if bounds are None: expected in (-1,1)
if bounds are provided: expected in [low, high]
- sample(latent: Tensor, deterministic: bool = False) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]
- Returns:
action: (B, act_dim) final action (env-scaled if bounds provided)
log_prob: (B, 1) log pi(action)
info: dict containing:
"pre_tanh": u (B, act_dim)
"tanh": a in (-1,1) (B, act_dim)
"mean": mean (B, act_dim)
"log_std": log_std (B, act_dim)
- Return type:
Tuple[Tensor, Tensor, Dict[str, Tensor]]
- class prt_rl.common.components.heads.heads.ValueHead(latent_dim: int)
State-value function head.
Implements a scalar value function V(s) that maps a latent state representation (typically produced by an encoder and/or backbone network) to a single value estimate per state.
This head is commonly used by actor–critic algorithms such as PPO, A2C, and A3C.
Notes
The output layer is linear (no activation), as value functions are generally unbounded.
The expected output shape is (B, 1), where B is the batch size.
- Parameters:
latent_dim (int) – Dimension of the latent state representation.