Source code for prt_rl.model_based.models.dynamics.cartpole
from dataclasses import dataclass, asdict
from typing import Dict, Tuple
import torch
# ---- per-(config, device, dtype) tensor cache ----
# Maps (config key, device string, dtype string) -> {field name: tensor}. It is filled lazily by
# _tensorize_cfg_cached so that a config's scalars are not converted to tensors on every call.
_TENSOR_CFG_CACHE: Dict[tuple, Dict[str, torch.Tensor]] = dict()
[docs]
def _tensorize_cfg_cached(
    config, device: torch.device, dtype: torch.dtype, *, max_entries: int = 1024
) -> Dict[str, torch.Tensor]:
    """
    Convert a config dataclass into a dict of tensors, caching the result per device and dtype.

    The cache key is built from the config's type and its field values rather than ``id(config)``.
    An id-based key can return stale tensors when the (mutable) config is modified after its first
    use. It can also return another config's tensors when a new config is allocated at the address
    of a garbage-collected one.

    Args:
        config: Dataclass instance whose field values can be converted with ``torch.tensor``.
        device (torch.device): Device on which the tensors are created.
        dtype (torch.dtype): Dtype of the created tensors.
        max_entries (int): Maximum number of cached entries. When the cache is full it is cleared,
            so creating many distinct configs (e.g. randomized parameters) cannot grow memory
            without bound.

    Returns:
        Dict[str, torch.Tensor]: Mapping from field name to a tensor holding that field's value
        (0-dim for scalar fields). The dict is shared with the cache and must not be mutated.
    """
    values = asdict(config)
    key = (type(config), tuple(values.items()), str(device), str(dtype))
    try:
        entry = _TENSOR_CFG_CACHE.get(key)
    except TypeError:
        # Some field value is unhashable (e.g. a list): it cannot be part of a key, so skip caching.
        key = None
        entry = None

    if entry is None:
        entry = {name: torch.tensor(value, device=device, dtype=dtype) for name, value in values.items()}
        if key is not None:
            if len(_TENSOR_CFG_CACHE) >= max_entries:
                # Simple bound: drop everything rather than tracking recency.
                _TENSOR_CFG_CACHE.clear()
            _TENSOR_CFG_CACHE[key] = entry
    return entry
[docs]
@dataclass
class CartPoleConfig:
    """
    Configuration parameters for the cart-pole (inverted pendulum on a cart) model. The default
    parameters are modeled after the InvertedPendulum-v5 environment from Gymnasium.

    Attributes:
        M (float): Mass of the cart (kg).
        m (float): Mass of the pendulum (kg).
        l (float): Length parameter of the pendulum (m). NOTE(review): ``cartpole_step`` uses the
            classic ``l * (4/3 - m*cos^2(theta)/(M + m))`` form, in which ``l`` is the distance from
            the pivot to the pole's center of mass (half the length of a uniform rod). Confirm
            that the default matches this interpretation.
        I (float): Moment of inertia of the pendulum (kg*m^2). Not used by ``cartpole_step``.
        b_cart (float): Viscous damping coefficient of the cart (N*s/m). The damping force
            applied to the cart is ``-b_cart * x_dot``.
        b_pole (float): Viscous damping coefficient of the pendulum joint (N*m*s/rad). The
            damping torque is ``-b_pole * theta_dot``.
        g (float): Acceleration due to gravity (m/s^2).
        dt (float): Time step for simulation (s).
        F_scale (float): Scaling factor that maps an action value to the force applied to the
            cart (presumably N per unit of action).
    """
    M: float = 10.472
    m: float = 5.019
    l: float = 1.0
    I: float = 0.153
    b_cart: float = 1.0
    b_pole: float = 1.0
    g: float = 9.81
    dt: float = 0.02
    F_scale: float = 100.0
[docs]
def cartpole_step(config: CartPoleConfig, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    """
    Compute the next cart-pole state from the current state and action with one explicit-Euler step.

    The accelerations use the classic cart-pole equations of motion (the same form as Gymnasium's
    CartPole), extended with viscous damping on the cart (``b_cart``) and on the pole joint
    (``b_pole``).

    Args:
        config (CartPoleConfig): Configuration parameters for the cart-pole system.
        state (torch.Tensor): Current state of shape (..., 4), each state being
            [x, theta, x_dot, theta_dot]. The usual case is (batch_size, 4). An unbatched (4,)
            state, or extra leading dimensions (e.g. (ensemble, batch_size, 4)), are also accepted.
        action (torch.Tensor): Action of shape (..., 1) matching the leading dimensions of
            ``state`` (e.g. (batch_size, 1)). Only the first entry of the last dimension is used.
            An action without the trailing dimension (e.g. (batch_size,)) is also accepted. The
            action is cast to the device and dtype of ``state``. It is then multiplied by
            ``config.F_scale`` to give the force applied to the cart.

    Returns:
        torch.Tensor: Next state of shape (..., 4), with the same leading dimensions, device and
        dtype as ``state``.
    """
    device, dtype = state.device, state.dtype
    C = _tensorize_cfg_cached(config, device, dtype)
    dt = C["dt"]

    # Index the last dimension so any leading batch shape works. For a (batch, 4) state this is
    # exactly state[:, i].
    x, theta, x_dot, theta_dot = state[..., 0], state[..., 1], state[..., 2], state[..., 3]

    if action.dim() == state.dim() - 1:
        # The action has no trailing singleton dimension, e.g. (batch,) for a (batch, 4) state.
        action = action.unsqueeze(-1)
    # Cast first so a mismatched action dtype/device cannot change the dtype of the returned state.
    F = action[..., 0].to(device=device, dtype=dtype) * C["F_scale"]

    sin_theta = torch.sin(theta)
    cos_theta = torch.cos(theta)
    total_mass = C["M"] + C["m"]
    pole_mass_length = C["m"] * C["l"]

    # Cart-pole equations of motion. `temp` is the (damped) cart-side force term divided by the
    # total mass. The pole's angular acceleration is solved from it, and the cart acceleration
    # follows from the pole's reaction.
    temp = (F + pole_mass_length * theta_dot**2 * sin_theta - C["b_cart"] * x_dot) / total_mass
    theta_acc = (C["g"] * sin_theta - cos_theta * temp - C["b_pole"] * theta_dot / pole_mass_length) / (
        C["l"] * (4.0 / 3.0 - C["m"] * cos_theta**2 / total_mass)
    )
    x_acc = temp - pole_mass_length * theta_acc * cos_theta / total_mass

    # Explicit Euler: positions advance with the current (pre-update) velocities.
    x_next = x + x_dot * dt
    theta_next = theta + theta_dot * dt
    x_dot_next = x_dot + x_acc * dt
    theta_dot_next = theta_dot + theta_acc * dt
    return torch.stack([x_next, theta_next, x_dot_next, theta_dot_next], dim=-1)