# Source code for prt_rl.common.components.networks.mlp

from torch import nn, Tensor
from typing import Optional, Tuple
from prt_rl.common.utils import to_activation


def build_mlp(
    input_dim: int,
    hidden_sizes: Tuple[int, ...],
    activation: str,
    output_dim: Optional[int] = None,
    output_activation: Optional[str] = None,
) -> nn.Sequential:
    """Assemble a feed-forward stack of ``nn.Linear`` layers.

    Every entry of ``hidden_sizes`` contributes one ``nn.Linear`` followed by
    a freshly constructed activation obtained from ``to_activation(activation)``.
    An optional final projection and an optional trailing activation are then
    appended.

    Args:
        input_dim: Number of input features fed to the first layer.
        hidden_sizes: Widths of the hidden layers, in order. May be empty.
        activation: Name passed to ``to_activation`` for every hidden layer.
        output_dim: If given, a final ``nn.Linear`` (with no hidden activation
            after it) maps the last width to this many features.
        output_activation: If given, ``to_activation(output_activation)`` is
            appended as the very last module. Note this happens independently
            of ``output_dim``: with ``output_dim=None`` it is stacked directly
            after the last hidden activation.

    Returns:
        An ``nn.Sequential`` of the modules above. With no hidden sizes, no
        ``output_dim`` and no ``output_activation`` it is empty, i.e. it
        passes its input through unchanged.
    """
    # Consecutive widths give each Linear its (fan_in, fan_out) pair.
    widths = (input_dim, *hidden_sizes)

    modules = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        # Linear is constructed before its activation so weight initialisation
        # consumes the RNG in the same order as a sequential build.
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(to_activation(activation))

    if output_dim is not None:
        # widths[-1] is the last hidden width, or input_dim if there were none.
        modules.append(nn.Linear(widths[-1], output_dim))

    if output_activation is not None:
        modules.append(to_activation(output_activation))

    return nn.Sequential(*modules)