Source code for prt_rl.common.plot

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import t
from typing import Union, Dict


[docs] def plot_scalar_metric( metric_name: str, metric_values: Union[Dict, np.ndarray], max_value: float=None ) -> None: """ Plots a confidence plot of the scalar metric. metric = {'metric_name': Support confidence max and min chopping Support color setting Support xlim and ylim setting Args: metric_name (str): name of the metric metric_values (dict): metric values max_value (float, optional): maximum value of the metric. Plots a horizontal line if maximum is provided. Defaults to None. Examples: """ def _plot_metric(ax, label, data): if len(data.shape) == 1: ax.plot(data, label=label) elif len(data.shape) == 2: if data.shape[0] == 1: ax.plot(data[0], label=label) else: # Compute metric mean and standard deviation means = np.mean(data, axis=0) stds = np.std(data, axis=0) # Compute confidence interval t_critical = t.ppf(0.975, df=data.shape[0] - 1) ci_margin = t_critical * (stds / np.sqrt(data.shape[0])) ci_upper = means + ci_margin ci_lower = means - ci_margin ax.plot(np.arange(data.shape[-1]), means, label=label) ax.fill_between(np.arange(data.shape[-1]), ci_lower, ci_upper, alpha=0.20) else: raise ValueError("Data must be 1D or 2D array") fig, ax = plt.subplots() # Handle if the metric values are provided as a dictionary of numpy arrays or just a single numpy array. if isinstance(metric_values, np.ndarray): _plot_metric(ax, "", metric_values) elif isinstance(metric_values, dict): for label, data in metric_values.items(): _plot_metric(ax, label, data) plt.legend() else: raise ValueError("Metric values must be a dict of numpy arrays or a numpy array") # Plot horizontal line for maximum value if max_value is not None: ax.axhline(y=max_value, color='k', linestyle='--') # Add plot labels ax.set_title(metric_name) ax.set_xlabel('Episodes') ax.set_ylabel(metric_name)