Source code for lightning_ir.main

import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Mapping, Set

import torch
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
from lightning.pytorch.loggers import WandbLogger
from typing_extensions import override

import lightning_ir  # noqa: F401
from lightning_ir.lightning_utils.lr_schedulers import LR_SCHEDULERS, WarmupLRScheduler

if torch.cuda.is_available():
    torch.set_float32_matmul_precision("medium")

sys.path.append(str(Path.cwd()))

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class LightningIRSaveConfigCallback(SaveConfigCallback):
    """Saves the CLI configuration, but only during training and when a logger is attached."""

    @override
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # Skip saving the config outside of `fit` or when no logger is configured
        if stage != "fit" or trainer.logger is None:
            return
        return super().setup(trainer, pl_module, stage)

class LightningIRWandbLogger(WandbLogger):
    @property
    def save_dir(self) -> str | None:
        """Gets the save directory.

        Returns:
            The path to the save directory, or None if the experiment has not
            been initialized yet (i.e., it is still a DummyExperiment).
        """
        if isinstance(self.experiment, DummyExperiment):
            return None
        return self.experiment.dir

class LightningIRTrainer(Trainer):
    # TODO check that correct callbacks are registered for each subcommand

    def index(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Index a collection of documents."""
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def search(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Search for relevant documents."""
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def re_rank(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Re-rank a set of retrieved documents."""
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

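# Usage note: the three retrieval entry points above all delegate to
# Lightning's `test` loop, so they can also be driven programmatically.
# A minimal sketch, assuming lightning_ir's BiEncoderModule, DocDataset, and
# LightningIRDataModule accept the constructor arguments shown here (the
# model name and dataset id are placeholders, and a real indexing run would
# also register an index callback on the trainer):
#
#     from lightning_ir import BiEncoderModule, DocDataset, LightningIRDataModule
#
#     module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
#     datamodule = LightningIRDataModule(
#         inference_datasets=[DocDataset("msmarco-passage")],
#         inference_batch_size=32,
#     )
#     trainer = LightningIRTrainer(logger=False)
#     trainer.index(module, datamodule=datamodule)
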
class LightningIRCLI(LightningCLI):
    @staticmethod
    def configure_optimizers(
        lightning_module: LightningModule,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: WarmupLRScheduler | None = None,
    ) -> Any:
        if lr_scheduler is None:
            return optimizer

        return [optimizer], [{"scheduler": lr_scheduler, "interval": lr_scheduler.interval}]

    def add_arguments_to_parser(self, parser):
        parser.add_lr_scheduler_args(tuple(LR_SCHEDULERS))
        parser.link_arguments("model.init_args.model_name_or_path", "data.init_args.model_name_or_path")
        parser.link_arguments("model.init_args.config", "data.init_args.config")
        parser.link_arguments("model", "data.init_args.model", apply_on="instantiate")
        parser.link_arguments("trainer.max_steps", "lr_scheduler.init_args.num_training_steps")

    @staticmethod
    def subcommands() -> Dict[str, Set[str]]:
        return {
            "fit": LightningCLI.subcommands()["fit"],
            "index": {"model", "dataloaders", "datamodule"},
            "search": {"model", "dataloaders", "datamodule"},
            "re_rank": {"model", "dataloaders", "datamodule"},
        }

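# Note: `configure_optimizers` above is a plain static method, so its
# scheduler-less path can be exercised directly. A minimal sketch (the
# linear layer is an illustrative stand-in, not part of lightning_ir):
#
#     import torch
#     from torch import nn
#
#     optimizer = torch.optim.AdamW(nn.Linear(4, 1).parameters(), lr=1e-4)
#     # Without a scheduler, the optimizer is returned unchanged:
#     assert LightningIRCLI.configure_optimizers(None, optimizer) is optimizer
#     # With a WarmupLRScheduler, the method instead returns Lightning's
#     # ([optimizers], [scheduler_configs]) pair, stepping the scheduler
#     # at the scheduler's own `interval`.
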
def main():
    """Entry point for the Lightning IR CLI.

    Generate a config using ``python main.py fit --print_config > config.yaml``.
    Additional callbacks are listed at:
    https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks

    Example:
        To obtain a default config:

            python main.py fit \
                --trainer.callbacks=ModelCheckpoint \
                --optimizer AdamW \
                --trainer.logger LightningIRWandbLogger \
                --print_config > default.yaml

        To run with the default config:

            python main.py fit --config default.yaml
    """
    LightningIRCLI(
        trainer_class=LightningIRTrainer,
        save_config_callback=LightningIRSaveConfigCallback,
        save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True},
    )


if __name__ == "__main__":
    main()