Source code for prt_rl.model_based.planners.cross_entropy

import torch
from torch.distributions import Normal
from typing import Any, Callable, Optional
from prt_rl.model_based.planners.rollout import rollout_action_sequence

def temporal_smooth(
    x: torch.Tensor,
    method: str = "none",
    rho: float = 0.9,
    kernel_size: int = 0,
) -> torch.Tensor:
    """
    Apply simple temporal smoothing along the horizon dimension.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of shape (N, H, dA) where N is the number of sequences/samples,
        H the planning horizon and dA the action dimension.
    method : {"none", "ou", "conv"}, default: "none"
        - "none" (or ``None``): return ``x`` unchanged (same object, no copy).
        - "ou": exponential moving average (Ornstein-Uhlenbeck-like):
          ``out[:, 0] = x[:, 0]`` and
          ``out[:, t] = rho * out[:, t-1] + (1 - rho) * x[:, t]``.
        - "conv": 1D convolution with a normalized Gaussian kernel of length
          ``kernel_size`` (std ``0.25 * kernel_size``), applied independently
          to every (sequence, action-dim) channel.
    rho : float, default: 0.9
        Smoothing factor for "ou". Higher means smoother (more inertia).
    kernel_size : int, default: 0
        Kernel length for "conv". Values below 3 leave ``x`` unchanged. Even
        sizes are supported; the kernel is then centred at ``t + 0.5``.

    Returns
    -------
    torch.Tensor
        Smoothed tensor with the same shape as ``x`` (N, H, dA).

    Raises
    ------
    ValueError
        If ``method`` is not one of "none", "ou" or "conv".

    Notes
    -----
    - Smoothing should be applied in U-space for tanh bound mode (preferred)
      and in A-space for clip mode.
    - For "conv", edges are handled with 'replicate' padding.
    """
    if method is None or method == "none":
        return x
    if method not in ("ou", "conv"):
        raise ValueError(
            f"Unknown smoothing method {method!r}; expected 'none', 'ou' or 'conv'."
        )

    num_seq, horizon, act_dim = x.shape

    if method == "ou":
        # Recursive EMA: each step mixes the previous *smoothed* value with the raw sample.
        smooth_x = x.clone()
        for t in range(1, horizon):
            smooth_x[:, t] = rho * smooth_x[:, t - 1] + (1 - rho) * x[:, t]
        return smooth_x

    # method == "conv"
    if kernel_size < 3:
        # Too short to smooth anything; a zero-length kernel would also divide by zero.
        return x

    offsets = torch.arange(kernel_size, device=x.device, dtype=x.dtype) - (kernel_size - 1) / 2
    kernel = torch.exp(-0.5 * (offsets / (0.25 * kernel_size)) ** 2)
    kernel = (kernel / kernel.sum()).view(1, 1, -1)  # (1, 1, k), sums to 1

    # Treat every (sequence, action-dim) pair as an independent single-channel signal.
    xt = x.permute(0, 2, 1).reshape(num_seq * act_dim, 1, horizon)  # (N*dA, 1, H)
    # Total padding of k-1 keeps the output length at H for both odd and even k.
    pad = ((kernel_size - 1) // 2, kernel_size // 2)
    xt = torch.nn.functional.pad(xt, pad, mode="replicate")
    yt = torch.nn.functional.conv1d(xt, kernel)  # (N*dA, 1, H)
    return yt.reshape(num_seq, act_dim, horizon).permute(0, 2, 1).contiguous()
[docs] class CrossEntropyMethodPlanner: """ Cross-Entropy Method (CEM) planner for continuous control with support for tanh-squash (U-space) and clip (A-space) bounding strategies. Workflow per planning call -------------------------- 1) Initialize or warm-start the sequence distribution (shape (H, dA)). 2) Repeat for K iterations: a) Sample N sequences (N,H,dA) using the bound strategy. b) Roll out through the (known or learned) dynamics model. c) Compute reward for each sequence, pick top M (elites). d) Refit the distribution from elites (in the proper space). 3) Return the first action of the best-scoring elite. Parameters ---------- action_mins, action_maxs : torch.Tensor Action bounds with shape (dA, 1) (or (dA,)). Broadcasted internally. num_action_sequences : int, default: 100 N, number of sequences sampled per iteration. planning_horizon : int, default: 10 H, number of steps in each sequence. num_elites : int, default: 10 M, number of top sequences used for refit. num_iterations : int, default: 5 K, number of CEM refinement iterations per plan call. use_smoothing : bool, default: False If True, apply temporal smoothing (OU) inside the bound strategy (U-space for tanh, A-space for clip). use_clipping : bool, default: False If True, use ClipBound; otherwise use TanhSquashBound. tau : float or None, default: H/3 Time constant for std decay schedule. beta : float, default: 0.2 Long-horizon std floor fraction for the decay schedule. device : {"cpu","cuda",...}, default: "cpu" Device for internal tensors. Notes ----- - This implementation assumes **higher reward is better**. If you use costs, either flip the sign or use `largest=False` in `topk`. - `rollout_action_sequence(model_config, model_fcn, state, actions)` must return a dict with `'state'`, `'action'`, and `'next_state'` batches consistent with shapes (N, H, ·). 
""" def __init__(self, action_mins: torch.Tensor, action_maxs: torch.Tensor, num_action_sequences: int = 100, planning_horizon: int = 10, num_elites: int = 10, num_iterations: int = 5, use_smoothing: bool = False, use_clipping: bool = False, tau: float | None = None, beta: float = 0.2, device: str = 'cpu' ) -> None: assert action_mins.shape == action_maxs.shape, "Action mins and maxs must have the same shape." assert num_elites <= num_action_sequences, "Number of elites must be less than or equal to number of action sequences." assert num_iterations > 0, "Number of iterations must be greater than 0." self.planning_horizon = planning_horizon self.num_action_sequences = num_action_sequences self.num_elites = num_elites self.num_iterations = num_iterations self.use_smoothing = use_smoothing self.use_clipping = use_clipping self.tau = tau if tau is not None else planning_horizon / 3 self.beta = beta self.device = torch.device(device) # Move action bound tensors to the correct device and compute the scale and bias for rescaling self.action_mins = action_mins.to(self.device) self.action_maxs = action_maxs.to(self.device) self.action_scale = (self.action_maxs - self.action_mins) / 2.0 self.action_bias = (self.action_maxs + self.action_mins) / 2 if self.use_clipping: self.bound_strategy = ClipBound() else: self.bound_strategy = TanhSquashBound() self.distribution = None self.elites = None
[docs] def plan(self, model_fcn: Callable, model_config: Any, reward_fcn: Callable, state: torch.Tensor ) -> torch.Tensor: """ Run one CEM planning call and return the first action to execute. Parameters ---------- model_fcn : Callable One-step dynamics function (batched) used by the rollout utility. model_config : Any Additional config passed to your rollout helper. reward_fcn : Callable Function computing rewards from rollout dict; returns (N,) reward per sequence. state : torch.Tensor Current state (batching left to caller/rollout helper). Returns ------- torch.Tensor First action of the best elite sequence, shape (1, dA). Notes ----- - Sampling returns **A-space** actions in both strategies. - Refit is done in U or A space depending on the bound strategy. """ # Initialize the prior distribution if self.distribution is None or self.elites is None: self.distribution = self.bound_strategy.cold_start(H=self.planning_horizon, a_mins=self.action_mins, a_maxs=self.action_maxs, beta=self.beta, tau=self.tau ) else: self.distribution = self.bound_strategy.warm_start(elites=self.elites, a_mins=self.action_mins, a_maxs=self.action_maxs, widening_factor=1.3, std_min=0.5 ) for _ in range(self.num_iterations): # Sample new action sequences - (N, H, da) action_sequences = self.bound_strategy.sample(self.distribution, torch.Size((self.num_action_sequences,)), self.action_mins, self.action_maxs) # Evaluate action sequences using the model and reward function rollout = rollout_action_sequence(model_config, model_fcn, state, action_sequences) rewards = reward_fcn(rollout['state'], rollout['action'], rollout['next_state']) # Pick the top M elites _, elite_indices = torch.topk(rewards, self.num_elites, largest=True) self.elites = action_sequences[elite_indices] # Refit the distribution to the elites self.distribution = self.bound_strategy.refit(self.elites, self.action_mins, self.action_maxs) # Return the first action from the best action sequence return self.elites[0, 0, 
:].unsqueeze(0)
class TanhSquashBound:
    """
    Tanh squashing strategy (recommended default).

    The underlying distribution lives in U-space (unbounded). Sampling:

        U ~ Normal(mu_u, sigma_u)                          (loc/scale (H, dA))
        A = (tanh(U) + 1) / 2 * (a_max - a_min) + a_min    ((N, H, dA), in bounds)

    Refit is done in U-space: A-space elites are mapped back through an
    atanh-based transform before the new Normal is fitted.

    All methods are stateless utilities; the planner owns the current
    distribution.

    Shapes
    ------
    - a_mins, a_maxs : (dA, 1) or (dA,); (dA, 1) is flattened internally
    - distribution.loc/scale : (H, dA)
    - samples / elites : (N, H, dA)
    """

    @staticmethod
    def _bounds_1d(bound: torch.Tensor) -> torch.Tensor:
        """Return (dA, 1) bounds as (dA,) so they broadcast against (..., H, dA); other shapes pass through."""
        if bound.dim() == 2 and bound.shape[-1] == 1:
            return bound.squeeze(-1)
        return bound

    @staticmethod
    def sample(
        distribution: Normal,
        shape: torch.Size,
        a_mins: torch.Tensor,
        a_maxs: torch.Tensor,
        smoothing: str = "ou",
        rho: float = 0.5,
        kernel_size: int = 0,
    ) -> torch.Tensor:
        """
        Sample actions from a U-space Normal, optionally smooth in U, and squash to A-space.

        Parameters
        ----------
        distribution : Normal
            U-space Normal with loc/scale shape (H, dA).
        shape : torch.Size
            Leading sample shape, e.g. (N,) to get (N, H, dA).
        a_mins, a_maxs : torch.Tensor
            Bounds, shape (dA, 1) or (dA,).
        smoothing : {"none", "ou", "conv"}, default: "ou"
            Smoothing applied in U-space before squashing.
        rho : float, default: 0.5
            OU smoothing factor.
        kernel_size : int, default: 0
            Convolution kernel length (only used if smoothing="conv").

        Returns
        -------
        torch.Tensor
            Actions in A-space with shape (N, H, dA).
        """
        # Sample in U-space with shape (N, H, dA)
        u_actions = distribution.rsample(shape)
        # Smooth before squashing so the bounded actions stay inside [a_min, a_max]
        u_smooth = temporal_smooth(u_actions, method=smoothing, rho=rho, kernel_size=kernel_size)
        return TanhSquashBound._from_u_space(u_smooth, a_mins, a_maxs)

    @staticmethod
    def refit(
        elites: torch.Tensor,
        a_mins: torch.Tensor,
        a_maxs: torch.Tensor,
        std_min: float = 1e-6,
    ) -> Normal:
        """
        Refit the U-space Normal from A-space elites.

        Parameters
        ----------
        elites : torch.Tensor
            Elite action sequences in A-space, shape (N_e, H, dA).
        a_mins, a_maxs : torch.Tensor
            Bounds with shape (dA, 1) or (dA,).
        std_min : float, default: 1e-6
            Std floor.

        Returns
        -------
        Normal
            Updated U-space Normal with loc/scale (H, dA).
        """
        u_elites = TanhSquashBound._to_u_space(elites, a_mins, a_maxs)
        mean_u = u_elites.mean(dim=0)
        std_u = u_elites.std(dim=0, unbiased=False).clamp_min(std_min)
        return Normal(loc=mean_u, scale=std_u)

    @staticmethod
    def cold_start(
        H: int,
        a_mins: torch.Tensor,
        a_maxs: torch.Tensor,
        beta: float,
        tau: float,
        sigma_u0: float = 0.6,
        std_min: float = 1e-6,
    ) -> Normal:
        """
        Initialize the U-space Normal used for tanh squashing.

        Parameters
        ----------
        H : int
            Planning horizon.
        a_mins, a_maxs : torch.Tensor
            Bounds with shape (dA, 1) or (dA,). Only ``a_mins`` is used, for
            dA/device/dtype; the U-space mean is zero (the box centre after
            squashing) regardless of the bounds.
        beta : float
            Long-horizon std decay floor in [0, 1].
            std(t) = (beta + (1 - beta) * exp(-t / tau)) * sigma_u0.
        tau : float
            Decay time constant (in steps).
        sigma_u0 : float, default: 0.6
            U-space std per dimension at t=0 (before decay).
        std_min : float, default: 1e-6
            Absolute std floor for numerical stability.

        Returns
        -------
        torch.distributions.Normal
            U-space Normal with loc/scale shape (H, dA).
        """
        a_mins = TanhSquashBound._bounds_1d(a_mins)
        da = a_mins.shape[0]
        device = a_mins.device
        # Integer bounds fall back to the default float dtype
        dtype = a_mins.dtype if a_mins.is_floating_point() else torch.get_default_dtype()

        t = torch.arange(H, device=device, dtype=dtype)            # (H,)
        decay = beta + (1 - beta) * torch.exp(-t / tau)            # (H,)

        mean_u = torch.zeros(H, da, device=device, dtype=dtype)    # (H, dA)
        std_u = (decay * sigma_u0).unsqueeze(1).expand(H, da)      # (H, dA)
        return Normal(loc=mean_u, scale=std_u.clamp_min(std_min))

    @staticmethod
    def warm_start(
        elites: torch.Tensor,
        a_mins: torch.Tensor,
        a_maxs: torch.Tensor,
        widening_factor: float = 1.3,
        std_min: float = 1e-6,
    ) -> Normal:
        """
        Warm-start the U-space Normal from previous elites in A-space.

        Steps
        -----
        1) Convert elites A -> U; compute mean/std across the elite batch.
        2) Shift mu_u, sigma_u forward by one time step.
        3) Tail: set the last mu_u to 0 (box centre in U), keep the last sigma_u.
        4) Multiply sigma_u[0] by ``widening_factor`` and floor it at ``std_min``.

        Parameters
        ----------
        elites : (N_e, H, dA)
            Elite A-space action sequences from the previous planning call.
        a_mins, a_maxs : (dA, 1) or (dA,)
            Bounds.
        widening_factor : float, default: 1.3
            Multiplier for sigma_u at t=0 after the shift.
        std_min : float, default: 1e-6
            Floor for sigma_u at t=0 only; the remaining steps are floored at 1e-6.

        Returns
        -------
        Normal
            Warm-started U-space Normal (H, dA).
        """
        u_elites = TanhSquashBound._to_u_space(elites, a_mins, a_maxs)
        mean_u = u_elites.mean(dim=0)
        std_u = u_elites.std(dim=0, unbiased=False)

        # Shift forward one step; the zero-initialised tail mean is the U-space centre
        shifted_mean = torch.zeros_like(mean_u)
        shifted_mean[:-1] = mean_u[1:]

        shifted_std = torch.zeros_like(std_u)
        shifted_std[:-1] = std_u[1:]
        shifted_std[-1] = std_u[-1]

        # Widen the first step to keep the controller able to react
        shifted_std[0] = (shifted_std[0] * widening_factor).clamp_min(std_min)

        return Normal(loc=shifted_mean, scale=shifted_std.clamp_min(1e-6))

    @staticmethod
    def _to_u_space(
        a_actions: torch.Tensor,
        a_mins: torch.Tensor,
        a_maxs: torch.Tensor,
        epsilon: float = 1e-6,
    ) -> torch.Tensor:
        """
        Convert bounded actions A in [a_min, a_max] to U-space via atanh.

        Parameters
        ----------
        a_actions : (N, H, dA)
        a_mins, a_maxs : (dA, 1) or (dA,)
        epsilon : float
            Safety margin that keeps the normalised value away from +/-1,
            where atanh is infinite.

        Returns
        -------
        torch.Tensor
            U-space tensor (N, H, dA).
        """
        a_mins = TanhSquashBound._bounds_1d(a_mins)
        a_maxs = TanhSquashBound._bounds_1d(a_maxs)
        y = (2 * (a_actions - a_mins) / (a_maxs - a_mins) - 1).clamp(-1 + epsilon, 1 - epsilon)
        # atanh(y) = 0.5 * log((1 + y) / (1 - y))
        return 0.5 * torch.log1p(y) - 0.5 * torch.log1p(-y)

    @staticmethod
    def _from_u_space(u: torch.Tensor, a_mins: torch.Tensor, a_maxs: torch.Tensor) -> torch.Tensor:
        """
        Map U-space to A-space via tanh and an affine rescale to the bounds.

        Parameters
        ----------
        u : (N, H, dA)
        a_mins, a_maxs : (dA, 1) or (dA,)

        Returns
        -------
        torch.Tensor
            Bounded A-space actions (N, H, dA).
        """
        a_mins = TanhSquashBound._bounds_1d(a_mins)
        a_maxs = TanhSquashBound._bounds_1d(a_maxs)
        y = torch.tanh(u)
        return (y + 1) / 2 * (a_maxs - a_mins) + a_mins
class ClipBound:
    """
    Hard clipping strategy (simple & useful for bang-bang optima).

    The distribution lives and is refit in A-space. Sampling:

        A ~ Normal(mu_a, sigma_a)        (loc/scale (H, dA))
        A := clamp(A, a_min, a_max)

    Shapes
    ------
    - a_mins, a_maxs : (dA, 1) or (dA,); (dA, 1) is flattened internally
    - distribution.loc/scale : (H, dA)
    - samples / elites : (N, H, dA)
    """

    @staticmethod
    def _bounds_1d(bound: torch.Tensor) -> torch.Tensor:
        """Return (dA, 1) bounds as (dA,) so they broadcast against (..., H, dA); other shapes pass through."""
        if bound.dim() == 2 and bound.shape[-1] == 1:
            return bound.squeeze(-1)
        return bound

    @staticmethod
    def sample(
        distribution: Normal,
        shape: torch.Size,
        a_mins: torch.Tensor,
        a_maxs: torch.Tensor,
        smoothing: str = "ou",
        rho: float = 0.5,
        kernel_size: int = 0,
    ) -> torch.Tensor:
        """
        Sample actions from an A-space Normal, optionally smooth in A, then clamp to bounds.

        Parameters
        ----------
        distribution : Normal
            A-space Normal with loc/scale shape (H, dA).
        shape : torch.Size
            Leading sample shape, e.g. (N,) to get (N, H, dA).
        a_mins, a_maxs : torch.Tensor
            Bounds, shape (dA, 1) or (dA,).
        smoothing : {"none", "ou", "conv"}, default: "ou"
            Smoothing applied in A-space before clamping.
        rho : float, default: 0.5
            OU smoothing factor.
        kernel_size : int, default: 0
            Convolution kernel length (only used if smoothing="conv").

        Returns
        -------
        torch.Tensor
            A-space actions with shape (N, H, dA), always clamped into bounds.
        """
        a_mins = ClipBound._bounds_1d(a_mins)
        a_maxs = ClipBound._bounds_1d(a_maxs)
        # Sample in A-space with shape (N, H, dA)
        a_actions = distribution.rsample(shape)
        # Smooth before clamping so the result is guaranteed in bounds
        a_smooth = temporal_smooth(a_actions, method=smoothing, rho=rho, kernel_size=kernel_size)
        return torch.clamp(a_smooth, a_mins, a_maxs)

    @staticmethod
    def refit(elites: torch.Tensor,
              a_mins: torch.Tensor,
              a_maxs: torch.Tensor,
              std_min: float = 1e-6) -> Normal:
        """
        Refit the A-space Normal to the elites.

        Parameters
        ----------
        elites : (N_e, H, dA)
            Elite A-space action sequences.
        a_mins, a_maxs : (dA, 1) or (dA,)
            Unused; kept for interface parity with TanhSquashBound.
        std_min : float, default: 1e-6
            Std floor.

        Returns
        -------
        Normal
            A-space Normal with loc/scale (H, dA): per-step elite mean and
            population (biased) std.
        """
        mean = elites.mean(dim=0)
        standard_dev = elites.std(dim=0, unbiased=False).clamp_min(std_min)
        return Normal(loc=mean, scale=standard_dev)

    @staticmethod
    def cold_start(
        H: int,
        a_mins: torch.Tensor,
        a_maxs: torch.Tensor,
        beta: float,
        tau: float,
        std_min: float = 1e-6,
    ) -> Normal:
        """
        Initialize an A-space Normal with box-centre mean and decayed half-span std.

            sigma_a(t) = (beta + (1 - beta) * exp(-t / tau)) * (a_max - a_min) / 2

        Parameters
        ----------
        H : int
            Planning horizon.
        a_mins, a_maxs : (dA, 1) or (dA,)
            Bounds.
        beta : float
            Long-horizon std floor fraction.
        tau : float
            Decay time constant (in steps).
        std_min : float, default: 1e-6
            Absolute std floor.

        Returns
        -------
        Normal
            A-space Normal (H, dA).
        """
        a_mins = ClipBound._bounds_1d(a_mins)
        a_maxs = ClipBound._bounds_1d(a_maxs)
        da = a_mins.shape[0]

        center = (a_mins + a_maxs) / 2.0        # (dA,)
        sigma_0 = (a_maxs - a_mins) / 2.0       # (dA,)

        t = torch.arange(H, device=a_mins.device, dtype=center.dtype)   # (H,)
        decay = beta + (1 - beta) * torch.exp(-t / tau)                 # (H,)

        mean = center.unsqueeze(0).expand(H, da)                         # (H, dA)
        standard_dev = decay.unsqueeze(1) * sigma_0.unsqueeze(0)         # (H, dA)

        # Time-varying diagonal Gaussian
        return Normal(loc=mean, scale=standard_dev.clamp_min(std_min))

    @staticmethod
    def warm_start(
        elites: torch.Tensor,
        a_mins: torch.Tensor,
        a_maxs: torch.Tensor,
        widening_factor: float = 1.5,
        std_min: float = 1e-6,
    ) -> Normal:
        """
        Warm-start the A-space Normal from previous elites.

        Steps
        -----
        1) Compute mean/std across the elites.
        2) Shift mu_a, sigma_a forward by one time step.
        3) Tail: repeat the last mean and std.
        4) Multiply sigma_a[0] by ``widening_factor`` and floor it at ``std_min``.

        Parameters
        ----------
        elites : (N_e, H, dA)
            Elite A-space action sequences from the previous planning call.
        a_mins, a_maxs : (dA, 1) or (dA,)
            Unused; kept for interface parity with TanhSquashBound.
        widening_factor : float, default: 1.5
            Multiplier for sigma_a at t=0 after the shift.
        std_min : float, default: 1e-6
            Floor for sigma_a at t=0 only; the remaining steps are floored at 1e-6.

        Returns
        -------
        Normal
            Warm-started A-space Normal (H, dA).
        """
        mean = torch.mean(elites, dim=0)
        standard_dev = torch.std(elites, dim=0, unbiased=False)

        # Shift the mean and std forward one time step
        shifted_mean = torch.zeros_like(mean)
        shifted_mean[:-1] = mean[1:]
        shifted_std = torch.zeros_like(standard_dev)
        shifted_std[:-1] = standard_dev[1:]

        # Tail repeats the previous last time step
        shifted_mean[-1] = mean[-1]
        shifted_std[-1] = standard_dev[-1]

        # Widen the first step to encourage exploration
        shifted_std[0] = (shifted_std[0] * widening_factor).clamp_min(std_min)

        return Normal(loc=shifted_mean, scale=shifted_std.clamp_min(1e-6))