Source code for prt_rl.common.runners
from typing import Optional, List
from prt_rl.env.interface import EnvironmentInterface
from prt_rl.common.recorders import Recorder
from prt_rl.common.visualizers import Visualizer
from prt_rl.agent import Agent
from prt_rl.common.policies.interface import Policy
[docs]
def watch(env: EnvironmentInterface, policy: Policy, num_episodes: int = 1) -> None:
"""
Watch a trained RL agent in a gym environment.
Args:
env: The environment to run the agent in.
policy: The RL policy to use for acting in the environment.
"""
episode_rewards = []
for i in range(num_episodes):
obs, info = env.reset()
done = False
total_reward = 0.0
while not done:
action, _ = policy.act(obs, deterministic=True)
obs, reward, done, info = env.step(action)
total_reward += reward
print(f"Episode {i} Total reward: {total_reward.cpu().item()}")
episode_rewards.append(total_reward.cpu().item())
avg_reward = sum(episode_rewards) / num_episodes
print(f"Average reward over {num_episodes} episodes: {avg_reward}")
[docs]
class Runner:
"""
A runner executes an agent in an environment. It simplifies the process of evaluating agents that have been trained.
The runner assumes the rgb_array is in the info dictioanary and has shape (num_envs, channel, height, width).
.. note::
To use the visualizer, the environment wrapper render mode must be set to 'rgb_array'.
Args:
env (EnvironmentInterface): the environment to run the agent in
agent (BaseAgent): Agent to be executed in the environment
recorders (Optional[List[Recorder]]): List of recorders to record the experience and info during the run
visualizer (Optional[Visualizer]): Visualizer to show the environment frames during the run
"""
def __init__(self,
env: EnvironmentInterface,
agent: Agent,
recorders: Optional[List[Recorder]] = None,
visualizer: Optional[Visualizer] = None,
) -> None:
self.env = env
self.agent = agent
self.recorders = recorders or []
self.visualizer = visualizer
def run(self):
# Reset the environment and recorder
for r in self.recorders:
r.reset()
state, info = self.env.reset()
done = False
# Start visualizer and show initial frame
if self.visualizer is not None:
self.visualizer.start()
self.visualizer.show(info['rgb_array'][0])
for r in self.recorders:
r.record_info(info)
# Loop until the episode is done
while not done:
action = self.agent.act(state, deterministic=True)
next_state, reward, done, info = self.env.step(action)
# Record the environment frame
if self.visualizer is not None:
self.visualizer.show(info['rgb_array'][0])
for r in self.recorders:
r.record_info(info)
r.record_experience({
'state': state,
'action': action,
'reward': reward,
'next_state': next_state,
'done': done
})
state = next_state
if self.visualizer is not None:
self.visualizer.stop()
# Save the recording
for r in self.recorders:
r.close()
self.env.close()