# Source code for prt_rl.model_based.planners.rollout

import torch
from typing import Any

def rollout_action_sequence(
    model_config: Any,
    model_fcn: callable,
    initial_state: torch.Tensor,
    action_sequence: torch.Tensor,
) -> dict:
    """
    Roll out a batch of action sequences through a model from a shared initial state.

    Every one of the B action sequences starts from the same ``initial_state``.
    At each time step the model is queried once for the whole batch, and the
    predicted next state becomes the current state for the following step.

    Args:
        model_config (Any): Opaque configuration/parameters object. It is passed
            unchanged as the first argument of ``model_fcn`` on every step.
        model_fcn (callable): Transition function called as
            ``model_fcn(model_config, state, action)``; its return value is used
            as the next state.
        initial_state (torch.Tensor): The initial state of the environment. It is
            tiled with ``initial_state.repeat(B, 1)``, so it is expected to be a
            single state of shape (state_dim,) or (1, state_dim).
        action_sequence (torch.Tensor): Actions of shape (B, T, action_dim).

    Returns:
        Dict[str, torch.Tensor]: A dictionary containing the states, actions, and
        next_states encountered during the rollout.
            - 'state': Tensor of shape (B, T, state_dim); entry ``[:, t]`` is the
              state the model was given at time step t (``[:, 0]`` is the tiled
              initial state).
            - 'action': Tensor of shape (B, T, action_dim) containing the actions
              taken at each time step.
            - 'next_state': Tensor of shape (B, T, state_dim); entry ``[:, t]`` is
              the model's prediction after applying the action at time step t.

    Note:
        T is expected to be at least 1; with T == 0 the per-step lists are empty
        and ``torch.stack`` cannot stack them.
    """
    B, T, _ = action_sequence.shape

    # Repeat the initial state to match the number of batch action sequences - (B, state_dim)
    state = initial_state.repeat(B, 1)

    states = []
    actions = []
    next_states = []

    for t in range(T):
        # Record the state *before* the step so 'state', 'action' and
        # 'next_state' stay aligned index-for-index along the time axis.
        states.append(state)

        # Get the action at time step t - (B, action_dim)
        action = action_sequence[:, t, :]
        actions.append(action)

        # Predict the next state using the model
        next_state = model_fcn(model_config, state, action)
        next_states.append(next_state)

        # Update the current state
        state = next_state

    return {
        'state': torch.stack(states, dim=1),  # (B, T, state_dim)
        'action': torch.stack(actions, dim=1),  # (B, T, action_dim)
        'next_state': torch.stack(next_states, dim=1),  # (B, T, state_dim)
    }