Source code for prt_rl.common.progress_bar
from tqdm import tqdm
[docs]
class ProgressBar:
    """
    Training Progress Bar backed by ``tqdm``.

    The bar tracks an absolute step count. Each call to :meth:`update` passes
    the total number of steps collected so far, and the bar advances by the
    difference from the value passed on the previous call.

    The bar can be used as a context manager so the underlying tqdm bar is
    closed when the ``with`` block exits, including when it exits by raising::

        with ProgressBar(total_steps=10_000) as progress:
            ...
            progress.update(current_step, desc="training")

    Args:
        total_steps (int): Total number of environment steps that will be collected.
    """

    def __init__(self, total_steps: int) -> None:
        self.pbar = tqdm(total=total_steps)
        # Last absolute step reported; used to turn absolute step counts into
        # the relative increments that tqdm.update() expects.
        self.prev_iteration = 0

    def update(self, current_step: int, desc: str) -> None:
        """
        Advance the bar to ``current_step`` and set its description.

        Args:
            current_step (int): Absolute number of steps collected so far
                (not an increment since the last call).
            desc (str): Description text displayed with the bar.
        """
        # refresh=False avoids forcing an extra redraw here; the update() call
        # below redraws the bar as needed.
        self.pbar.set_description(desc, refresh=False)
        self.pbar.update(n=current_step - self.prev_iteration)
        self.prev_iteration = current_step

    def close(self) -> None:
        """Close the underlying tqdm bar and release its output resources."""
        self.pbar.close()

    def __enter__(self) -> "ProgressBar":
        """Return this progress bar for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the bar on exit from a ``with`` block; exceptions propagate."""
        self.close()