Source code for prt_rl.common.schedulers

from abc import ABC, abstractmethod
import numpy as np
from typing import Union, Tuple, List

class ParameterScheduler(ABC):
    """
    Abstract base for schedulers that drive a single parameter on a target object.

    The parameter is addressed by ``parameter_name``, which is either a plain
    name or a dotted path. While walking the path, each container is indexed by
    key when it is a ``dict`` and accessed by attribute otherwise.

    Args:
        obj (object): Object to which the parameter belongs
        parameter_name (str): Name (or dotted path) of the parameter to schedule
    """
    def __init__(self,
                 obj: object,
                 parameter_name: str
                 ):
        self.obj = obj
        self.parameter_name = parameter_name

    def _resolve_target(self) -> tuple[object, str]:
        """
        Walk the dotted parameter path and return the owning container plus the
        final attribute/key name.

        Examples:
            parameter_name='epsilon' -> (obj, 'epsilon')
            parameter_name='config.epsilon' -> (obj.config, 'epsilon')
        """
        # Everything before the last dot is a container to descend into;
        # the last segment is the name of the scheduled field itself.
        *parents, leaf = self.parameter_name.split(".")

        container = self.obj
        for segment in parents:
            container = container[segment] if isinstance(container, dict) else getattr(container, segment)
        return container, leaf

    def _set_value(self, value: float) -> None:
        """
        Write ``value`` to the scheduled parameter (dict key or attribute).
        """
        container, key = self._resolve_target()
        if not isinstance(container, dict):
            setattr(container, key, value)
            return
        container[key] = value

    def get_value(self) -> float:
        """
        Read the current value of the scheduled parameter (dict key or attribute).
        """
        container, key = self._resolve_target()
        return container[key] if isinstance(container, dict) else getattr(container, key)

    @abstractmethod
    def update(self,
               current_step: int
               ) -> None:
        """
        Update the scheduled parameter for the given step. Subclasses must override.

        Args:
            current_step (int): Current step number
        """
        raise NotImplementedError
class LinearScheduler(ParameterScheduler):
    """
    Linear schedule that moves a parameter from a start value to one or more end values over
    one or more step intervals.

    Args:
        obj (object): Object to which the parameter belongs
        parameter_name (str): Name of the parameter to schedule
        start_value (float): Value of the parameter at the start of the first interval
        end_value (Union[float, List[float]]): Value of the parameter at the end of each interval. A single
            number (int or float) is treated as a one-element list.
        interval (Union[int , Tuple[int, int], List[Tuple[int, int]]]): Interval to schedule the parameter over.
            Can be a single integer, a tuple of integers, or a list of tuples. If a single integer is provided,
            the parameter is scheduled over (0, interval). If a tuple is provided, the parameter is scheduled
            over that (start, end) range. If a list of tuples is provided, the parameter is scheduled over each
            interval in the list, in order.

    Raises:
        ValueError: If an interval does not have a length greater than 0, if the length of end_value and
            interval are not the same, or if any two intervals overlap

    Example:
        .. code-block:: python

            from prt_rl.common.schedulers import LinearScheduler
            from prt_rl.common.epsilon_greedy import EpsilonGreedy

            eg = EpsilonGreedy()

            # Schedule epsilon from 0.2 to 0.1 over 10 steps starting from step 0
            s = LinearScheduler(obj=eg, parameter_name='epsilon', start_value=0.2, end_value=0.1, interval=10)

            # Schedule epsilon over an interval of (4, 10) from 0.2 to 0.1
            s = LinearScheduler(obj=eg, parameter_name='epsilon', start_value=0.2, end_value=0.1, interval=(4, 10))

            # Piecewise schedule epsilon over multiple intervals
            s = LinearScheduler(obj=eg, parameter_name='epsilon', start_value=1.0, end_value=[0.1, 0.01], interval=[(0, 10), (15, 20)])
    """
    def __init__(self,
                 obj: object,
                 parameter_name: str,
                 start_value: float,
                 end_value: Union[float, List[float]],
                 interval: Union[int, Tuple[int, int], List[Tuple[int, int]]],
                 ) -> None:
        super(LinearScheduler, self).__init__(obj=obj, parameter_name=parameter_name)

        # A bare number is a single-segment schedule. Integers (e.g. end_value=0)
        # must be wrapped too, otherwise len() below fails with a TypeError.
        if isinstance(end_value, (int, float)):
            end_value = [end_value]
        else:
            # Copy into a list so tuples are accepted and list concatenation below works.
            end_value = list(end_value)

        if isinstance(interval, int):
            if interval <= 0:
                raise ValueError("Interval must be greater than 0")
            interval = (0, interval)

        if isinstance(interval, tuple):
            interval = [interval]

        if len(end_value) != len(interval):
            raise ValueError(f"Length of end_value {len(end_value)} and interval {len(interval)} must be the same")

        # Reject empty or reversed intervals up front: a zero-length interval would
        # divide by zero when computing the rate, and a reversed one can never
        # match a step in update().
        for start, end in interval:
            if end <= start:
                raise ValueError(f"Interval ({start}, {end}) must have an end greater than its start")

        self.check_intervals(interval)

        self.start_value = start_value
        self.end_value = end_value
        self.interval = interval
        self.current_value = start_value

        # Per-interval slope: change in value divided by the interval length.
        # Segment i starts at the previous segment's end value (or start_value for i == 0).
        values = [self.start_value] + end_value
        self.rates = [
            (next_value - prev_value) / (end - start)
            for (prev_value, next_value), (start, end) in zip(zip(values, values[1:]), interval)
        ]

    def update(self,
               current_step: int
               ) -> None:
        """
        Compute the linearly scheduled value for the current step and write it to the target parameter.

        If the step lies outside every interval, the most recently computed value is written again.

        Args:
            current_step (int): Current step number
        """
        for i, (start, end) in enumerate(self.interval):
            if start <= current_step <= end:
                segment_start = self.start_value if i == 0 else self.end_value[i - 1]
                value = (current_step - start) * self.rates[i] + segment_start

                # Clamp at the segment's end value so floating point error never overshoots it.
                if self.rates[i] < 0:
                    value = max(value, self.end_value[i])
                else:
                    value = min(value, self.end_value[i])

                self.current_value = value
                break

        self._set_value(self.current_value)

    def check_intervals(self, intervals: list[tuple[int, int]]) -> None:
        """
        Check if the intervals are overlapping.

        Args:
            intervals (list[tuple[int, int]]): List of intervals to check

        Raises:
            ValueError: If any of the intervals overlap
        """
        # Compare every unordered pair once.
        for i, first in enumerate(intervals):
            for second in intervals[i + 1:]:
                if self._is_interval_overlapping(first, second):
                    raise ValueError(f"Intervals {first} and {second} overlap.")

    def _is_interval_overlapping(self, interval1: tuple[int, int], interval2: tuple[int, int]) -> bool:
        """
        Check if two intervals overlap. Intervals that only share an endpoint do not count as overlapping.

        Args:
            interval1 (tuple[int, int]): First interval
            interval2 (tuple[int, int]): Second interval

        Returns:
            bool: True if the intervals overlap, False otherwise
        """
        a_start, a_end = interval1
        b_start, b_end = interval2
        # The intersection is (max of starts, min of ends); it is non-empty only if strictly ordered.
        return max(a_start, b_start) < min(a_end, b_end)
class ExponentialScheduler(ParameterScheduler):
    """
    Exponential scheduler that decays a parameter from a start value toward an end value.

    The scheduled value at step ``t`` is
    ``end_value + (start_value - end_value) * exp(-decay_rate * t)``.

    Args:
        obj (object): Object to which the parameter belongs
        parameter_name (str): Name of the parameter to schedule
        start_value (float): Value of the parameter at step 0
        end_value (float): Value the parameter decays toward
        decay_rate (float): Exponential decay rate for the parameter
    """
    def __init__(self,
                 obj: object,
                 parameter_name: str,
                 start_value: float,
                 end_value: float,
                 decay_rate: float,
                 ) -> None:
        super(ExponentialScheduler, self).__init__(obj=obj, parameter_name=parameter_name)
        self.decay_rate = decay_rate
        self.end_value = end_value
        self.start_value = start_value

    def update(self,
               current_step: int
               ) -> None:
        """
        Compute the exponentially decayed value for the current step and write it to the target parameter.

        Args:
            current_step (int): Current step number
        """
        # Distance still to cover, scaled by the decay factor for this step.
        span = self.start_value - self.end_value
        decay_factor = np.exp(-self.decay_rate * current_step)
        self._set_value(self.end_value + span * decay_factor)