Source code for lightning_ir.data.datamodule

from __future__ import annotations

from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Sequence

import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, IterableDataset

from ..base.config import LightningIRConfig
from ..base.tokenizer import LightningIRTokenizer
from .data import IndexBatch, RankBatch, SearchBatch, TrainBatch
from .dataset import DocDataset, DocSample, QueryDataset, QuerySample, RankSample, RunDataset, TupleDataset

if TYPE_CHECKING:
    from ..base import LightningIRModule

class LightningIRDataModule(LightningDataModule):
    """Lightning DataModule that wires Lightning IR datasets to dataloaders for
    fine-tuning, re-ranking, indexing, and searching."""

    def __init__(
        self,
        model_name_or_path: str | Path | None = None,
        config: LightningIRConfig | None = None,
        module: LightningIRModule | None = None,
        num_workers: int = 0,
        train_batch_size: int | None = None,
        shuffle_train: bool = True,
        inference_batch_size: int | None = None,
        train_dataset: RunDataset | TupleDataset | None = None,
        inference_datasets: Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None = None,
    ) -> None:
        super().__init__()
        # Resolve the configuration from whichever source was provided,
        # preferring an explicit config over a module over a checkpoint.
        if config is not None:
            self.config = config
        elif module is not None:
            self.config = module.config
        elif model_name_or_path is not None:
            self.config = LightningIRConfig.from_pretrained(model_name_or_path)
        else:
            raise ValueError("Either module, config, or model_name_or_path must be provided.")

        if model_name_or_path is None:
            model_name_or_path = self.config.name_or_path
        self.tokenizer = LightningIRTokenizer.from_pretrained(model_name_or_path, config=self.config)
        self.num_workers = num_workers

        self.train_batch_size = train_batch_size
        self.shuffle_train = shuffle_train
        self.inference_batch_size = inference_batch_size
        self.train_dataset = train_dataset
        self.inference_datasets = inference_datasets

    def setup_inference(self, stage: Literal["validate", "test"]) -> None:
        """Validate that each inference dataset is usable for the given stage."""
        if self.inference_datasets is None:
            return
        for inference_dataset in self.inference_datasets:
            if isinstance(inference_dataset, TupleDataset):
                if stage == "test":
                    raise ValueError("Prediction cannot be performed with TupleDataset.")
            elif isinstance(inference_dataset, RunDataset):
                if inference_dataset.sampling_strategy == "single_relevant":
                    raise ValueError("Inference RunDataset cannot use the single_relevant sampling strategy.")
            elif isinstance(inference_dataset, (QueryDataset, DocDataset)):
                pass
            else:
                raise ValueError(
                    "Inference dataset must be of type RunDataset, TupleDataset, QueryDataset, or DocDataset."
                )

    def setup(self, stage: Literal["fit", "validate", "test"]) -> None:
        if stage == "fit":
            if self.train_dataset is None:
                raise ValueError("A training dataset must be provided for the fit stage.")
            # Fitting also runs validation, so check the inference datasets
            # against the validation constraints.
            stage = "validate"
        self.setup_inference(stage)

    def train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("No training dataset found.")
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            # An IterableDataset controls its own sample order and cannot be shuffled by the DataLoader.
            shuffle=(False if isinstance(self.train_dataset, IterableDataset) else self.shuffle_train),
            # prefetch_factor must be None when no worker processes are used.
            prefetch_factor=16 if self.num_workers > 0 else None,
        )

    def val_dataloader(self) -> List[DataLoader]:
        return self.inference_dataloader()

    def test_dataloader(self) -> List[DataLoader]:
        return self.inference_dataloader()

    def inference_dataloader(self) -> List[DataLoader]:
        inference_datasets = self.inference_datasets or []
        return [
            DataLoader(
                dataset,
                batch_size=self.inference_batch_size,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                prefetch_factor=16 if self.num_workers > 0 else None,
            )
            for dataset in inference_datasets
        ]

    def _aggregate_samples(self, samples: Sequence[RankSample | QuerySample | DocSample]) -> Dict[str, Any]:
        """Collect the fields of all samples into per-field lists, pluralizing the field names."""
        aggregated = defaultdict(list)
        field_options = {
            "query_id": {"extend": False},
            "query": {"extend": False},
            "doc_id": {"extend": False},
            "doc_ids": {"extend": False},
            "doc": {"extend": False},
            "docs": {"extend": False},
            "targets": {"extend": True},
            "qrels": {"extend": True},
        }
        for sample in samples:
            for field in sample.__dict__:
                extend = field_options[field]["extend"]
                key = field if field.endswith("s") else f"{field}s"
                value = getattr(sample, field)
                if value is None:
                    continue
                if extend:
                    aggregated[key].extend(value)
                else:
                    aggregated[key].append(value)
        return aggregated

    def _clean_sample(self, aggregated: Dict[str, Any]) -> Dict[str, Any]:
        """Fix the irregular plural of "query" and stack per-sample targets into one tensor."""
        kwargs: Dict[str, Any] = dict(aggregated)
        if "querys" in kwargs:
            kwargs["queries"] = kwargs.pop("querys")
        if "targets" in kwargs:
            kwargs["targets"] = torch.stack(kwargs["targets"])
        return kwargs

    def _parse_batch(
        self, sample: RankSample | QuerySample | DocSample, **kwargs
    ) -> RankBatch | TrainBatch | IndexBatch | SearchBatch:
        """Choose the batch type from the sample type and the presence of training targets."""
        if isinstance(sample, RankSample):
            if "targets" in kwargs:
                return TrainBatch(**kwargs)
            return RankBatch(**kwargs)
        if isinstance(sample, QuerySample):
            return SearchBatch(**kwargs)
        if isinstance(sample, DocSample):
            return IndexBatch(**kwargs)
        raise ValueError("Invalid dataset configuration.")
ValueError("Invalid dataset configuration.") 157 158 def collate_fn( 159 self, 160 samples: Sequence[RankSample | QuerySample | DocSample] | RankSample | QuerySample | DocSample, 161 ) -> TrainBatch | RankBatch | IndexBatch | SearchBatch: 162 if isinstance(samples, (RankSample, QuerySample, DocSample)): 163 samples = [samples] 164 aggregated = self._aggregate_samples(samples) 165 kwargs = self._clean_sample(aggregated) 166 return self._parse_batch(samples[0], **kwargs)