Source code for prt_sim.jhu.plot
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import t
from typing import List
from prt_sim.jhu.bandits import KArmBandits
[docs]
def plot_bandit_probabilities(env: KArmBandits) -> None:
"""
Plots the mean and variance of the bandit probabilities.
Args:
env (KArmBandits): bandits environment
"""
probs = env.bandit_probs
plt.errorbar(
np.arange(len(probs)),
probs,
yerr=np.ones(len(probs)),
fmt='o',
linewidth=2,
capsize=6,
)
plt.xlabel("Action")
plt.ylabel("Reward Distribution")
plt.title(f"{len(probs)}-armed Testbed")
[docs]
def plot_bandit_rewards(rewards: np.ndarray) -> None:
"""
Plots the rewards received by the agent(s) playing the bandits game.
Args:
rewards (np.ndarray): rewards received by the agent(s) with shape (# agents, # episodes)
"""
if rewards.shape[0] == 1:
plt.plot(rewards[0])
else:
# Compute mean of rewards
means = np.mean(rewards, axis=0)
stds = np.std(rewards, axis=0)
# Compute confidence interval of rewards
t_critical = t.ppf(0.975, df=rewards.shape[0] - 1)
ci_margin = t_critical * (stds / np.sqrt(rewards.shape[0]))
ci_upper = means + ci_margin
ci_lower = means - ci_margin
plt.plot(np.arange(rewards.shape[-1]), means)
plt.fill_between(np.arange(rewards.shape[-1]), ci_lower, ci_upper, alpha=0.20)
plt.xlabel('Steps')
plt.ylabel('Average Rewards')
plt.title("Average Agent Rewards")
[docs]
def plot_bandit_percent_optimal_action(optimal_bandits: np.ndarray, actions: np.ndarray) -> None:
"""
Creates a plot of the percentage of optimal actions over the training episodes.
Args:
optimal_bandits (np.ndarray): array of optimal bandit indexes
actions (np.ndarray): actions chosen by the agent(s) with shape (# agents, # episodes)
"""
# Sum the number of times the optimal action was chosen
optimal_actions = np.sum((actions == optimal_bandits[:, np.newaxis]), axis=0).astype(float)
# Divide the count by the number of runs in the step
optimal_action_percent = optimal_actions / float(actions.shape[0]) * 100.0
plt.plot(optimal_action_percent)
plt.xlabel('Steps')
plt.ylabel('% Optimal action')
plt.title("Average Agent Optimal Action Selection")