interface

Policy interfaces used across algorithms.

Classes

NeuralPolicy

Base class for torch-backed policies.

Policy

Runtime acting interface consumed by collectors.

TabularPolicy

Base class for tabular policies (not a torch.nn.Module subclass).

class prt_rl.common.policies.interface.NeuralPolicy(*args: Any, **kwargs: Any)

Base class for torch-backed policies.

Implements the Policy protocol and adds utility methods for saving/loading and device management.


__init__(*args: Any, **kwargs: Any) -> None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

abstractmethod act(obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Dict[str, Tensor]]

Return an action tensor and auxiliary policy outputs.
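A minimal sketch of a concrete subclass, assuming act() is the only abstract method a subclass must implement. Because prt_rl may not be importable here, NeuralPolicyBase is a hypothetical local stand-in that mirrors the documented interface; CategoricalMLPPolicy and the auxiliary keys "logits" and "log_prob" are likewise illustrative choices, not part of the library.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

import torch
from torch import Tensor, nn
from torch.distributions import Categorical


class NeuralPolicyBase(nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented NeuralPolicy interface."""

    @abstractmethod
    def act(self, obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Return an action tensor and auxiliary policy outputs."""

    def metadata(self) -> Dict[str, Any]:
        # No-op default, as described for the base class.
        return {}


class CategoricalMLPPolicy(NeuralPolicyBase):
    """Hypothetical discrete-action policy: obs -> logits -> action."""

    def __init__(self, obs_dim: int, num_actions: int, hidden: int = 32) -> None:
        super().__init__()  # must run before any submodules are assigned
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, num_actions),
        )

    def act(self, obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Dict[str, Tensor]]:
        logits = self.net(obs)  # shape (batch, num_actions)
        dist = Categorical(logits=logits)
        # Greedy action when deterministic, otherwise sample from the distribution.
        action = logits.argmax(dim=-1) if deterministic else dist.sample()
        return action, {"logits": logits, "log_prob": dist.log_prob(action)}


policy = CategoricalMLPPolicy(obs_dim=4, num_actions=2)
with torch.no_grad():
    action, info = policy.act(torch.zeros(3, 4), deterministic=True)
print(action.shape, sorted(info))  # torch.Size([3]) ['log_prob', 'logits']
```

super().__init__() has to come first because nn.Module sets up its parameter and submodule registries there; assigning self.net before that call raises an AttributeError.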

metadata() -> Dict[str, Any]

Return optional metadata to be saved alongside the policy. The base implementation is a no-op; subclasses can override it to supply their own entries.
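One way metadata() can be used, sketched under assumptions: the saving/loading utilities mentioned above are not documented on this page, so save_with_metadata and load_with_metadata are hypothetical helpers built on torch.save and torch.load, and TinyPolicy is a hypothetical stand-in for a NeuralPolicy subclass. The override returns the constructor arguments so that a checkpoint can rebuild the module before restoring its weights.

```python
import os
import tempfile
from typing import Any, Dict

import torch
from torch import nn


class TinyPolicy(nn.Module):
    """Hypothetical stand-in for a NeuralPolicy subclass."""

    def __init__(self, obs_dim: int, num_actions: int) -> None:
        super().__init__()
        self.obs_dim = obs_dim
        self.num_actions = num_actions
        self.head = nn.Linear(obs_dim, num_actions)

    def metadata(self) -> Dict[str, Any]:
        # Plain, serializable values needed to rebuild the module before loading weights.
        return {"obs_dim": self.obs_dim, "num_actions": self.num_actions}


def save_with_metadata(policy: TinyPolicy, path: str) -> None:
    """Hypothetical helper: store weights and metadata in one checkpoint file."""
    torch.save({"state_dict": policy.state_dict(), "metadata": policy.metadata()}, path)


def load_with_metadata(path: str) -> TinyPolicy:
    """Hypothetical helper: rebuild the policy from metadata, then restore weights."""
    checkpoint = torch.load(path)
    policy = TinyPolicy(**checkpoint["metadata"])
    policy.load_state_dict(checkpoint["state_dict"])
    return policy


policy = TinyPolicy(obs_dim=4, num_actions=2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "policy.pt")
    save_with_metadata(policy, path)
    restored = load_with_metadata(path)
print(restored.metadata())  # {'obs_dim': 4, 'num_actions': 2}
```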

class prt_rl.common.policies.interface.Policy(*args, **kwargs)

Runtime acting interface consumed by collectors.

act(obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Dict[str, Tensor]]

Return an action tensor and auxiliary policy outputs.
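The Policy(*args, **kwargs) signature and the reference to "the Policy protocol" suggest that Policy is a typing.Protocol; the sketch below assumes so. PolicyLike is a hypothetical local mirror (marked runtime_checkable only so isinstance works in the demo), UniformRandomPolicy satisfies it without inheriting from anything, and collect is a hypothetical collector loop that depends on nothing but act().

```python
from typing import Dict, List, Protocol, Tuple, runtime_checkable

import torch
from torch import Tensor


@runtime_checkable
class PolicyLike(Protocol):
    """Hypothetical local mirror of the documented Policy protocol."""

    def act(self, obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Dict[str, Tensor]]:
        ...


class UniformRandomPolicy:
    """Satisfies the protocol structurally: no base class, just a matching act()."""

    def __init__(self, num_actions: int) -> None:
        self.num_actions = num_actions

    def act(self, obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Dict[str, Tensor]]:
        batch = obs.shape[0]
        if deterministic:
            action = torch.zeros(batch, dtype=torch.long)  # arbitrary fixed choice
        else:
            action = torch.randint(0, self.num_actions, (batch,))
        return action, {}  # this policy has no auxiliary outputs


def collect(policy: PolicyLike, obs_batches: List[Tensor]) -> List[Tensor]:
    """Hypothetical collector loop: it only ever calls policy.act()."""
    actions = []
    for obs in obs_batches:
        action, _aux = policy.act(obs)
        actions.append(action)
    return actions


policy = UniformRandomPolicy(num_actions=3)
print(isinstance(policy, PolicyLike))  # True
actions = collect(policy, [torch.zeros(2, 4), torch.zeros(5, 4)])
print([a.shape[0] for a in actions])  # [2, 5]
```

Structural typing keeps collectors decoupled from the class hierarchy: both torch-backed and tabular policies can be passed in as long as they expose act() with this signature.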

class prt_rl.common.policies.interface.TabularPolicy(table: Tensor, decision_function: DecisionFunction)

Base class for tabular policies (not a torch.nn.Module subclass).

abstractmethod act(obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Dict[str, Tensor]]

Return an action tensor and auxiliary policy outputs.
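A minimal sketch of a TabularPolicy-style subclass, assuming the table has shape (num_states, num_actions), observations are integer state indices, and the decision function maps a batch of action values to actions. DecisionFn, epsilon_greedy and QTablePolicy are hypothetical names; the library's actual DecisionFunction type may expose a different interface. Here deterministic=True bypasses the decision function and acts greedily, which is one plausible reading of the flag.

```python
from typing import Callable, Dict, Tuple

import torch
from torch import Tensor

# Hypothetical stand-in for DecisionFunction: action values -> actions.
DecisionFn = Callable[[Tensor], Tensor]


def epsilon_greedy(epsilon: float) -> DecisionFn:
    """Hypothetical decision function: random action with probability epsilon."""

    def decide(action_values: Tensor) -> Tensor:
        greedy = action_values.argmax(dim=-1)
        random_actions = torch.randint(0, action_values.shape[-1], greedy.shape)
        explore = torch.rand(greedy.shape) < epsilon
        return torch.where(explore, random_actions, greedy)

    return decide


class QTablePolicy:
    """Hypothetical tabular policy: a plain object, not a torch.nn.Module."""

    def __init__(self, table: Tensor, decision_function: DecisionFn) -> None:
        self.table = table  # shape (num_states, num_actions)
        self.decision_function = decision_function

    def act(self, obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Dict[str, Tensor]]:
        states = obs.long().view(-1)  # assumes obs holds integer state indices
        values = self.table[states]  # shape (batch, num_actions)
        if deterministic:
            action = values.argmax(dim=-1)
        else:
            action = self.decision_function(values)
        return action, {"action_values": values}


table = torch.tensor([[0.0, 1.0], [2.0, 0.5], [0.0, 0.0]])
policy = QTablePolicy(table, epsilon_greedy(0.1))
action, aux = policy.act(torch.tensor([0, 1]), deterministic=True)
print(action.tolist())  # [1, 0]
```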

snapshot() -> Dict[str, Any]

Return a serializable snapshot.

Subclasses can extend it by calling snap = super().snapshot(), adding their own entries with snap[…] = …, and returning snap.
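The extension pattern, sketched with hypothetical classes: TabularBase stands in for TabularPolicy, and the "table" and "epsilon" keys are assumptions rather than the library's actual snapshot contents. Cloning the table keeps the snapshot independent of later in-place updates.

```python
from typing import Any, Dict

import torch
from torch import Tensor


class TabularBase:
    """Hypothetical stand-in for TabularPolicy's snapshot() behaviour."""

    def __init__(self, table: Tensor) -> None:
        self.table = table

    def snapshot(self) -> Dict[str, Any]:
        # Detached copy so later table updates do not mutate the snapshot.
        return {"table": self.table.detach().clone()}


class EpsilonGreedyTabular(TabularBase):
    """Hypothetical subclass that adds its own field to the snapshot."""

    def __init__(self, table: Tensor, epsilon: float) -> None:
        super().__init__(table)
        self.epsilon = epsilon

    def snapshot(self) -> Dict[str, Any]:
        snap = super().snapshot()  # start from the parent's entries
        snap["epsilon"] = self.epsilon  # then add subclass-specific state
        return snap


policy = EpsilonGreedyTabular(torch.zeros(3, 2), epsilon=0.1)
snap = policy.snapshot()
policy.table[0, 0] = 5.0  # mutate the live table after snapshotting
print(sorted(snap), snap["table"][0, 0].item())  # ['epsilon', 'table'] 0.0
```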