heads

Classes

BetaHead

Beta actor head using torch.distributions.Beta.

CategoricalHead

Categorical actor head for discrete action spaces.

ContinuousHead

Continuous action head that outputs raw actions for continuous action spaces.

DecisionHead

Decision head for discrete action spaces that outputs raw logits and applies a decision function.

DuelingHead

Dueling network head for discrete action spaces.

GaussianHead

Diagonal Gaussian actor head for continuous action spaces.

QValueHead

State–action value function head.

TanhGaussianHead

Squashed (tanh) diagonal Gaussian actor head, commonly used in SAC.

ValueHead

State-value function head.

class prt_rl.common.components.heads.heads.BetaHead(latent_dim: int, action_dim: int, *, concentration_offset: float = 1.0, min_concentration: float = 0.001, epsilon: float = 1e-06, action_low: Tensor | None = None, action_high: Tensor | None = None)

Beta actor head using torch.distributions.Beta.

Produces independent Beta(alpha_i, beta_i) per action dimension. Samples in (0,1) and optionally maps to [action_low, action_high].

API:
  • sample(latent, deterministic=False) -> (action, log_prob, entropy)

  • log_prob(latent, action) -> (B,1) [action is final env-scaled action if bounds provided]

  • entropy(latent) -> (B,1)

Notes

  • Beta is defined on (0,1). We clamp actions for numerical stability.

  • If bounds are provided, we apply a change-of-variables correction to log_prob.
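
The change-of-variables correction in the last note can be illustrated with plain Python arithmetic. This is a minimal sketch, not the head's torch implementation: it assumes the map to the bounds is affine (a_env = low + (high - low) * x) and that clamping uses epsilon, and the function names are hypothetical.

```python
import math


def beta_log_pdf(x: float, alpha: float, beta: float) -> float:
    """Log-density of Beta(alpha, beta) at x in (0, 1)."""
    log_norm = math.lgamma(alpha) + math.lgamma(beta) - math.lgamma(alpha + beta)
    return (alpha - 1.0) * math.log(x) + (beta - 1.0) * math.log1p(-x) - log_norm


def scaled_beta_log_prob(action, alpha, beta, low, high, epsilon=1e-6):
    """Log-probability of an env-scaled action, summed over action dimensions.

    Assumes a_env = low + (high - low) * x with x ~ Beta(alpha, beta).
    """
    total = 0.0
    for a, al, be, lo, hi in zip(action, alpha, beta, low, high):
        x = (a - lo) / (hi - lo)                 # undo the affine map
        x = min(max(x, epsilon), 1.0 - epsilon)  # clamp away from 0 and 1
        # Affine map stretches the density by (hi - lo), so subtract its log.
        total += beta_log_pdf(x, al, be) - math.log(hi - lo)
    return total


# Beta(1, 1) is uniform on (0, 1); on [-2, 2] its density is 1/4 everywhere.
print(scaled_beta_log_prob([0.0], [1.0], [1.0], [-2.0], [2.0]))
```

Without the correction term, log-probabilities of the same policy would differ by a constant depending on the width of the bounds.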

Initialize internal Module state, shared by both nn.Module and ScriptModule.

entropy(latent: Tensor) → Tensor

Compute the entropy of the distribution defined by the latent representation.

Parameters:

latent (Tensor) – The input to the head, typically the output of a backbone network.

Returns:

The entropy of the distribution, of shape (B, 1).

Return type:

Tensor

log_prob(latent: Tensor, action: Tensor) → Tensor

Compute the log probability of a given action under the distribution defined by the latent representation.

Parameters:
  • latent (Tensor) – The input to the head, typically the output of a backbone network.

  • action (Tensor) – The action for which to compute the log probability. If bounds were provided, this is the final env-scaled action.

Returns:

The log probability of the given action, of shape (B, 1).

Return type:

Tensor

sample(latent: Tensor, deterministic: bool = False) → Tuple[Tensor, Tensor, Tensor]

Sample an action from the distribution given the latent representation.

Parameters:
  • latent (Tensor) – The input to the head, typically the output of a backbone network.

  • deterministic (bool) – Whether to sample deterministically (e.g., take the mean) or stochastically.

Returns:
  • action (Tensor) – The sampled action.

  • log_prob (Tensor) – The log probability of the sampled action.

  • entropy (Tensor) – The entropy of the distribution.

Return type:

Tuple[Tensor, Tensor, Tensor]

class prt_rl.common.components.heads.heads.CategoricalHead(latent_dim: int, num_actions: int)

Categorical actor head for discrete action spaces.

Contract:
  • sample(latent, deterministic) -> (action, log_prob, entropy)

    action: (B,) int64
    log_prob: (B,1) float
    entropy: (B,1) float

  • log_prob(latent, action) -> (B,1)

  • entropy(latent) -> (B,1)

Notes

  • Uses torch.distributions.Categorical(logits=…)

  • Deterministic action is argmax over logits.
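
The quantities in this contract can be worked through by hand for a single row of logits. The following is a plain-Python sketch of the arithmetic behind a logits-parameterized categorical distribution, not the head's own code; all function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]


def categorical_log_prob(logits, action):
    """log pi(action) for an integer action index."""
    return log_softmax(logits)[action]


def categorical_entropy(logits):
    """Entropy -sum_a pi(a) log pi(a)."""
    log_p = log_softmax(logits)
    return -sum(math.exp(lp) * lp for lp in log_p)


def greedy_action(logits):
    """Deterministic action: index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [0.1, 2.0, -1.0]
print(greedy_action(logits), categorical_log_prob(logits, 1), categorical_entropy(logits))
```

Because softmax is monotonic, the argmax over logits equals the argmax over probabilities, so the deterministic action needs no normalization.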

entropy(latent: Tensor) → Tensor

Compute the entropy of the distribution defined by the latent representation.

Parameters:

latent (Tensor) – The input to the head, typically the output of a backbone network.

Returns:

The entropy of the distribution.

Return type:

Tensor

get_logits(latent: Tensor) → Tensor

Get the raw logits output by the head for inspection or auxiliary losses.

Parameters:

latent (Tensor) – Latent state representation of shape (B, latent_dim).

Returns:

Logits of shape (B, num_actions).

Return type:

Tensor

log_prob(latent: Tensor, action: Tensor) → Tensor

Compute the log probability of the given action. The action is expected to have shape (B,) with dtype long, or shape (B,1), which is squeezed to (B,).

sample(latent: Tensor, deterministic: bool = False) → Tuple[Tensor, Tensor, Tensor]

Sample an action from the distribution given the latent representation.

Parameters:
  • latent (Tensor) – The input to the head, typically the output of a backbone network.

  • deterministic (bool) – Whether to select the action deterministically (argmax over logits) or sample it stochastically.

Returns:
  • action (Tensor) – The sampled action.

  • log_prob (Tensor) – The log probability of the sampled action.

  • entropy (Tensor) – The entropy of the distribution.

Return type:

Tuple[Tensor, Tensor, Tensor]

class prt_rl.common.components.heads.heads.ContinuousHead(latent_dim: int, action_dim: int)

Continuous action head that outputs raw actions for continuous action spaces.

This head is typically used in deterministic policy algorithms where the policy directly outputs continuous action values without sampling from a distribution.

Notes

  • The output layer is linear (no activation), producing raw action values.

  • The expected output shape is (B, action_dim), where B is the batch size and action_dim is the dimensionality of the action space.

forward(latent: Tensor) → Tensor

Compute the continuous action output.

Parameters:

latent (Tensor) – Latent state representation of shape (B, latent_dim).

Returns:

Continuous action output of shape (B, action_dim).

Return type:

Tensor

class prt_rl.common.components.heads.heads.DecisionHead(latent_dim: int, action_dim: int, *, decision_function: DecisionFunction = <prt_rl.common.decision_functions.Greedy object>)

Decision head for discrete action spaces that outputs raw logits and applies a decision function.

This head is typically used in deterministic policy algorithms where the policy directly outputs logits for discrete actions without sampling from a distribution.

Notes

  • The output layer is linear (no activation), producing raw logits.

  • The expected output shape is (B, action_dim), where B is the batch size and action_dim is the number of discrete actions.

sample(latent: Tensor, deterministic: bool = False) → Tensor

Compute the logits and apply the decision function to select actions.

Parameters:

latent (Tensor) – Latent state representation of shape (B, latent_dim).

Returns:

Actions of shape (B,).

Return type:

Tensor

class prt_rl.common.components.heads.heads.DuelingHead(latent_dim: int, num_actions: int)

Dueling network head for discrete action spaces.

Implements the dueling architecture where separate streams estimate the state-value function V(s) and the advantage function A(s, a). The Q-values are computed as:

Q(s, a) = V(s) + A(s, a) - mean_{a'} A(s, a')

This head is commonly used in DQN variants to improve value estimation.
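
The aggregation formula above can be checked on a single state with plain Python. This sketch only illustrates the arithmetic; the function name is hypothetical and the real head operates on batched tensors.

```python
def dueling_q_values(value, advantages):
    """Combine V(s) and A(s, a) as Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]


q = dueling_q_values(1.0, [1.0, 2.0, 3.0])
print(q)  # [0.0, 1.0, 2.0]

# Adding a constant to every advantage leaves Q unchanged.
print(dueling_q_values(1.0, [11.0, 12.0, 13.0]))  # [0.0, 1.0, 2.0]
```

Subtracting the mean advantage makes the decomposition identifiable: the Q-values average to V(s), and a constant shift in the advantage stream has no effect on the output.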

Parameters:
  • latent_dim (int) – Dimension of the latent state representation.

  • num_actions (int) – Number of discrete actions.

forward(latent: Tensor) → Tensor

Compute the Q-value estimates for all actions.

Parameters:

latent (Tensor) – Latent state representation of shape (B, latent_dim).

Returns:

Q-value estimates of shape (B, num_actions).

Return type:

Tensor

class prt_rl.common.components.heads.heads.GaussianHead(latent_dim: int, action_dim: int, *, log_std_init: float = -0.5, min_log_std: float = -20.0, max_log_std: float = 2.0, use_rsample: bool = False)

Diagonal Gaussian actor head for continuous action spaces.

Parameterization:
  • mean = Linear(latent)

  • log_std is a learned parameter vector (state-independent)

Contract:
  • sample(latent, deterministic) -> (action, log_prob, entropy)

    action: (B, act_dim)
    log_prob: (B,1) summed over action dims
    entropy: (B,1) summed over action dims

  • log_prob(latent, action) -> (B,1)

  • entropy(latent) -> (B,1)

Notes

  • Uses torch.distributions.Normal(mean, std) with diagonal independence.

  • We sum log_prob/entropy over action dims inside the head to keep callers DRY.

  • If you want state-dependent std, replace the Parameter with another Linear head.
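
Summing over action dimensions, as described in the notes, can be sketched in plain Python for one sample. The code assumes that log_std is clamped to [min_log_std, max_log_std] before use, which the constructor arguments suggest but the text does not state; the function names are hypothetical.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def clamp(x, lo, hi):
    return min(max(x, lo), hi)


def diag_gaussian_log_prob(action, mean, log_std, min_log_std=-20.0, max_log_std=2.0):
    """Sum of per-dimension Normal log-densities (diagonal independence)."""
    total = 0.0
    for a, mu, ls in zip(action, mean, log_std):
        ls = clamp(ls, min_log_std, max_log_std)  # assumed clamping step
        z = (a - mu) / math.exp(ls)
        total += -0.5 * z * z - ls - 0.5 * LOG_2PI
    return total


def diag_gaussian_entropy(log_std, min_log_std=-20.0, max_log_std=2.0):
    """Sum of per-dimension Normal entropies: 0.5 + 0.5*log(2*pi) + log_std."""
    return sum(0.5 + 0.5 * LOG_2PI + clamp(ls, min_log_std, max_log_std) for ls in log_std)


# Two action dimensions, unit std, action at the mean.
print(diag_gaussian_log_prob([0.0, 0.0], [0.0, 0.0], [0.0, 0.0]))
print(diag_gaussian_entropy([0.0, 0.0]))
```

The entropy depends only on log_std, so with a state-independent log_std it is identical for every state in the batch.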

entropy(latent: Tensor) → Tensor

Compute the entropy of the distribution defined by the latent representation.

Parameters:

latent (Tensor) – The input to the head, typically the output of a backbone network.

Returns:

The entropy of the distribution.

Return type:

Tensor

log_prob(latent: Tensor, action: Tensor) → Tensor

Compute the log probability of the given action, summed over action dimensions. The action is expected to have shape (B, act_dim).

sample(latent: Tensor, deterministic: bool = False) → Tuple[Tensor, Tensor, Tensor]

Sample an action from the distribution given the latent representation.

Parameters:
  • latent (Tensor) – The input to the head, typically the output of a backbone network.

  • deterministic (bool) – Whether to sample deterministically (e.g., take the mean) or stochastically.

Returns:
  • action (Tensor) – The sampled action.

  • log_prob (Tensor) – The log probability of the sampled action.

  • entropy (Tensor) – The entropy of the distribution.

Return type:

Tuple[Tensor, Tensor, Tensor]

class prt_rl.common.components.heads.heads.QValueHead(latent_dim: int, action_dim: int)

State–action value function head.

Implements a scalar Q-function Q(s, a) that maps a latent state representation and a continuous action vector to a single value estimate per state–action pair.

This head is typically used by off-policy algorithms such as SAC and TD3, where the critic evaluates specific actions rather than all actions at once.

Notes

  • The latent state and action are concatenated along the feature dimension.

  • The output layer is linear (no activation), as Q-values are unbounded.

  • This implementation assumes continuous action spaces. For discrete action spaces (e.g., DQN), a different head that outputs Q(s, ·) is used.
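
The concatenation described in the first note can be shown with a shape-level sketch in plain Python. It assumes a single linear layer with made-up weights; the actual head may use more layers, and the function name is hypothetical.

```python
def q_value_forward(latent, action, weights, bias):
    """Map (B, latent_dim) and (B, action_dim) inputs to (B, 1) Q-values.

    Each row is concatenated along the feature dimension and passed through
    one linear layer (an assumption made for illustration only).
    """
    outputs = []
    for s, a in zip(latent, action):
        features = list(s) + list(a)  # shape (latent_dim + action_dim,)
        q = sum(w * x for w, x in zip(weights, features)) + bias
        outputs.append([q])           # keep the trailing dimension of size 1
    return outputs


latent = [[1.0, 2.0], [3.0, 4.0]]   # B=2, latent_dim=2
action = [[0.5], [-1.0]]            # B=2, action_dim=1
print(q_value_forward(latent, action, weights=[1.0, 1.0, 2.0], bias=0.5))
# [[4.5], [5.5]]
```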

Parameters:
  • latent_dim (int) – Dimension of the latent state representation.

  • action_dim (int) – Dimension of the action vector.

forward(latent: Tensor, action: Tensor) → Tensor

Compute the state–action value estimate.

Parameters:
  • latent (Tensor) – Latent state representation of shape (B, latent_dim).

  • action (Tensor) – Action tensor of shape (B, action_dim).

Returns:

Q-value estimates of shape (B, 1).

Return type:

Tensor

class prt_rl.common.components.heads.heads.TanhGaussianHead(latent_dim: int, action_dim: int, *, log_std_min: float = -20.0, log_std_max: float = 2.0, epsilon: float = 1e-06, action_low: Tensor | None = None, action_high: Tensor | None = None)

Squashed (tanh) diagonal Gaussian actor head, commonly used in SAC.

It parameterizes a Normal distribution in R^act_dim, samples with rsample(), then squashes via tanh to (-1, 1). Optionally scales/shifts to env bounds.

Key detail: log_prob must include the tanh “change of variables” correction:

log pi(a) = log N(u; mu, std) - sum log(1 - tanh(u)^2) where a = tanh(u)

Optional scaling to [low, high]:

a_env = scale * a + bias, scale=(high-low)/2, bias=(high+low)/2

adds another correction term:

log pi(a_env) = log pi(a) - sum log(scale)

API:
  • sample(latent, deterministic=False) -> (action, log_prob, info)

  • log_prob(latent, action) -> log_prob (B,1) [action is final env-scaled action]

  • entropy(…) is not analytic after tanh; SAC typically uses -log_prob as entropy term proxy.

Notes

  • Uses state-dependent mean and log_std by default (two linear layers).

  • If you want state-independent log_std, replace log_std_layer with nn.Parameter.
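
The two change-of-variables corrections described above can be traced for one sample in plain Python. This is a sketch of the formulas, not the head's torch code: placing epsilon inside the logarithm is an assumption based on common SAC practice, and the function name is hypothetical.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def tanh_gaussian_log_prob(u, mean, log_std, low=None, high=None, epsilon=1e-6):
    """Return (action, log_prob) for a pre-tanh sample u.

    log pi(a)     = log N(u; mu, std) - sum log(1 - tanh(u)^2)
    log pi(a_env) = log pi(a) - sum log(scale), with scale = (high - low) / 2
    """
    actions = []
    log_prob = 0.0
    for i, (u_i, mu, ls) in enumerate(zip(u, mean, log_std)):
        z = (u_i - mu) / math.exp(ls)
        log_prob += -0.5 * z * z - ls - 0.5 * LOG_2PI  # log N(u; mu, std)
        a = math.tanh(u_i)
        log_prob -= math.log(1.0 - a * a + epsilon)      # tanh correction
        if low is not None and high is not None:
            scale = (high[i] - low[i]) / 2.0
            bias = (high[i] + low[i]) / 2.0
            log_prob -= math.log(scale)                  # scaling correction
            a = scale * a + bias
        actions.append(a)
    return actions, log_prob


action, lp = tanh_gaussian_log_prob([1.0], [0.0], [0.0], low=[0.0], high=[10.0])
print(action, lp)
```

Both corrections are log-determinants of the Jacobian of an elementwise map, which is why they appear as sums over action dimensions.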

entropy_proxy(latent: Tensor, action: Tensor) → Tensor

SAC typically uses -log_prob as an entropy term (up to expectations). This returns (B,1) = -log pi(a).

You might prefer calling this neg_log_prob depending on your style.
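
The phrase "up to expectations" means that the average of -log pi(a) over sampled actions estimates the entropy. The sketch below checks this on an unsquashed one-dimensional Gaussian, where the analytic entropy is known; the tanh-squashed case has no closed form, which is why the proxy is used. Function names, sample count, and seed are illustrative choices.

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)


def gaussian_log_prob(x, mean, std):
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * LOG_2PI


def entropy_estimate(mean, std, n_samples=20000, seed=0):
    """Monte Carlo estimate of entropy as the mean of -log pi(a)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        x = rng.gauss(mean, std)
        total += -gaussian_log_prob(x, mean, std)
    return total / n_samples


analytic = 0.5 + 0.5 * LOG_2PI  # entropy of a unit-variance Normal
print(entropy_estimate(0.0, 1.0), analytic)
```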

log_prob(latent: Tensor, action: Tensor) → Tensor

Computes log pi(action) for a given final action (env-scaled if bounds were provided).

action:
  • if bounds are None: expected in (-1,1)

  • if bounds are provided: expected in [low, high]

sample(latent: Tensor, deterministic: bool = False) → Tuple[Tensor, Tensor, Dict[str, Tensor]]

Returns:
  • action – (B, act_dim) final action (env-scaled if bounds provided)

  • log_prob – (B, 1) log pi(action)

  • info – dict containing:

      • "pre_tanh": u (B, act_dim)

      • "tanh": a in (-1,1) (B, act_dim)

      • "mean": mean (B, act_dim)

      • "log_std": log_std (B, act_dim)

Return type:

Tuple[Tensor, Tensor, Dict[str, Tensor]]

class prt_rl.common.components.heads.heads.ValueHead(latent_dim: int)

State-value function head.

Implements a scalar value function V(s) that maps a latent state representation (typically produced by an encoder and/or backbone network) to a single value estimate per state.

This head is commonly used by actor–critic algorithms such as PPO, A2C, and A3C.

Notes

  • The output layer is linear (no activation), as value functions are generally unbounded.

  • The expected output shape is (B, 1), where B is the batch size.

Parameters:

latent_dim (int) – Dimension of the latent state representation.

forward(latent: Tensor) → Tensor

Compute the state-value estimate.

Parameters:

latent (Tensor) – Latent state representation of shape (B, latent_dim).

Returns:

Value estimates of shape (B, 1).

Return type:

Tensor