Source code for prt_rl.common.components.networks.qcritic

from torch import nn

class QCritic(nn.Module):
    """Q-value critic built from a feature extractor and a critic head.

    The observation is first encoded by ``network``. The resulting features,
    together with the action, are then handed to ``critic_head``, whose
    output is returned unchanged as the Q-value estimate. Intended for
    actor-critic or value-based reinforcement-learning algorithms.

    Args:
        network (nn.Module): Feature extractor applied to observations.
        critic_head (nn.Module): Module called as ``critic_head(features, action)``
            that produces the Q-value estimate.
    """

    def __init__(self, network: nn.Module, critic_head: nn.Module):
        """Store the feature extractor and critic head as submodules.

        Args:
            network (nn.Module): Feature extractor for observations.
            critic_head (nn.Module): Module mapping ``(features, action)`` to Q-values.
        """
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as child
        # modules, so their parameters are tracked by this critic.
        self.network = network
        self.critic_head = critic_head

    def forward(self, obs, action):
        """Estimate Q-values for the given observations and actions.

        Args:
            obs: Observations in whatever form ``network`` accepts.
            action: Actions in whatever form ``critic_head`` accepts as its
                second positional argument.

        Returns:
            Whatever ``critic_head`` returns; presumably a tensor of Q-values
            (shape depends on the head — not fixed by this class).
        """
        # The head receives the encoded features positionally, followed by the
        # raw action; a head with a different calling convention needs an adapter.
        return self.critic_head(self.network(obs), action)