Source code for lightning_ir.lightning_utils.schedulers

  1from abc import ABC, abstractmethod
  2from typing import Any, Dict, Sequence
  3
  4from lightning import Callback, LightningModule, Trainer
  5
  6from ..base import LightningIRModule
  7
  8
class LambdaWarmupScheduler(ABC):
    """Base class for scalar schedules with an optional delay phase followed by a warmup phase.

    Subclasses implement :meth:`value_lambda`, which maps the current training
    step to a multiplicative schedule factor.
    """

    def __init__(
        self,
        num_warmup_steps: int,
        num_delay_steps: int = 0,
        *args,
        **kwargs,
    ) -> None:
        """Store the phase lengths of the schedule.

        :param num_warmup_steps: Number of steps over which the value ramps up.
        :param num_delay_steps: Number of initial steps before the warmup begins.
        """
        # Record the phase boundaries before delegating to the cooperative
        # __init__ chain so mixins later in the MRO receive the remaining args.
        self.num_warmup_steps = num_warmup_steps
        self.num_delay_steps = num_delay_steps
        super().__init__(*args, **kwargs)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float:
        """Return the schedule factor for ``current_step``."""
        ...

    def check_delay(self, current_step: int) -> bool:
        """Return True while ``current_step`` is still inside the delay phase."""
        return current_step < self.num_delay_steps

    def check_warmup(self, current_step: int) -> bool:
        """Return True while ``current_step`` is inside the delay or warmup phase."""
        return current_step < self.num_delay_steps + self.num_warmup_steps
class LinearSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    """Schedule that is zero during the delay phase, ramps up linearly to 1 during
    warmup, then decays linearly towards ``final_value`` over the remaining
    training steps.
    """

    def __init__(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        final_value: float = 0.0,
        num_delay_steps: int = 0,
        *args,
        **kwargs,
    ) -> None:
        """Configure the linear warmup/decay schedule.

        :param num_warmup_steps: Number of steps over which the value ramps up to 1.
        :param num_training_steps: Total number of training steps (delay + warmup + decay).
        :param final_value: Value the schedule decays to (and is clamped at).
        :param num_delay_steps: Number of initial steps during which the value is 0.
        """
        self.num_training_steps = num_training_steps
        self.final_value = final_value
        super().__init__(num_warmup_steps, num_delay_steps, *args, **kwargs)

    def value_lambda(self, current_step: int) -> float:
        """Return the schedule factor for ``current_step``."""
        if self.check_delay(current_step):
            return 0.0
        if self.check_warmup(current_step):
            # Linear ramp from 0 to 1 across the warmup phase.
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        current_step = current_step - self.num_delay_steps - self.num_warmup_steps
        remaining_steps = self.num_training_steps - self.num_delay_steps - self.num_warmup_steps
        if remaining_steps <= 0:
            # Bug fix: previously raised ZeroDivisionError when training ends
            # at (or before) the end of warmup; the decayed value is simply the
            # final value in that case.
            return self.final_value
        step_size = (1 - self.final_value) / remaining_steps
        # Linear decay, clamped so the factor never drops below final_value.
        return max(self.final_value, 1 - step_size * current_step)
class ConstantSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    """Schedule that is zero during the delay phase, ramps up linearly during
    warmup, and then stays constant at 1.
    """

    def value_lambda(self, current_step: int) -> float:
        """Return the schedule factor for ``current_step``."""
        if self.check_delay(current_step):
            return 0.0
        if not self.check_warmup(current_step):
            return 1.0
        # Linear ramp from 0 to 1 across the warmup phase.
        return (current_step - self.num_delay_steps) / self.num_warmup_steps
class ConstantSchedulerWithQuadraticWarmup(LambdaWarmupScheduler):
    """Schedule that is zero during the delay phase, ramps up quadratically
    during warmup, and then stays constant at 1.
    """

    def value_lambda(self, current_step: int) -> float:
        """Return the schedule factor for ``current_step``."""
        if self.check_delay(current_step):
            return 0.0
        if not self.check_warmup(current_step):
            return 1.0
        # Quadratic ramp: square of the linear warmup fraction.
        fraction = (current_step - self.num_delay_steps) / self.num_warmup_steps
        return fraction**2
class GenericScheduler(Callback, ABC):
    """Callback that schedules arbitrary attributes of a module during training.

    ``keys`` are dotted paths (attribute names, or integer indices for
    subscriptable parts) resolved against the module. The original values are
    captured at train start, replaced every batch with the value returned by
    :meth:`step`, and restored at train end.
    """

    def __init__(self, keys: Sequence[str], *args, **kwargs) -> None:
        """Store the dotted attribute paths to schedule.

        :param keys: Dotted paths into the module, e.g. ``"loss_functions.0.weight"``.
        """
        super().__init__(*args, **kwargs)
        self.keys = keys
        # key -> value observed at the start of training (restored at the end)
        self.values: Dict[str, float] = {}

    @abstractmethod
    def step(self, key: str, current_step: int) -> float:
        """Return the scheduled value for ``key`` at ``current_step``."""
        ...

    def get_value(self, sub_keys: Sequence[str], obj: object) -> object:
        """Walk ``obj`` along ``sub_keys``, indexing for numeric parts and
        falling back to attribute access otherwise."""
        current = obj
        for part in sub_keys:
            try:
                current = current[int(part)]
            except ValueError:
                # Non-numeric path component -> plain attribute lookup.
                current = getattr(current, part)
        return current

    def set_value(self, sub_keys: Sequence[str], obj: object, value: float) -> None:
        """Assign ``value`` to the attribute addressed by the last sub-key."""
        parent = self.get_value(sub_keys[:-1], obj)
        setattr(parent, sub_keys[-1], value)

    def on_train_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # Snapshot the untouched values so they can be restored after training.
        for key in self.keys:
            self.values[key] = float(self.get_value(key.split("."), pl_module))

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        # global_step is incremented after the batch, hence the +1 offset.
        current_step = trainer.global_step + 1
        for key in self.keys:
            self.set_value(key.split("."), pl_module, self.step(key, current_step))

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Restore the original (pre-scheduling) values.
        for key in self.keys:
            self.set_value(key.split("."), pl_module, self.values[key])
class GenericLinearSchedulerWithLinearWarmup(GenericScheduler, LinearSchedulerWithLinearWarmup):
    """Schedules module attributes with a linear warmup followed by linear decay."""

    def step(self, key: str, current_step: int) -> float:
        """Return the initial value of ``key`` scaled by the schedule factor."""
        return self.values[key] * self.value_lambda(current_step)
class GenericConstantSchedulerWithLinearWarmup(GenericScheduler, ConstantSchedulerWithLinearWarmup):
    """Schedules module attributes with a linear warmup to their initial value."""

    def step(self, key: str, current_step: int) -> float:
        """Return the initial value of ``key`` scaled by the schedule factor."""
        return self.values[key] * self.value_lambda(current_step)
class GenericConstantSchedulerWithQuadraticWarmup(GenericScheduler, ConstantSchedulerWithQuadraticWarmup):
    """Schedules module attributes with a quadratic warmup to their initial value."""

    def step(self, key: str, current_step: int) -> float:
        """Return the initial value of ``key`` scaled by the schedule factor."""
        return self.values[key] * self.value_lambda(current_step)