1from __future__ import annotations
2
3from collections import defaultdict
4from pathlib import Path
5from typing import TYPE_CHECKING, Any, Dict, List, Literal, Sequence
6
7import torch
8from lightning import LightningDataModule
9from torch.utils.data import DataLoader, IterableDataset
10
11from ..base.config import LightningIRConfig
12from ..base.tokenizer import LightningIRTokenizer
13from .data import IndexBatch, RankBatch, SearchBatch, TrainBatch
14from .dataset import DocDataset, DocSample, QueryDataset, QuerySample, RankSample, RunDataset, TupleDataset
15
16if TYPE_CHECKING:
17 from ..base import LightningIRModule
18
19
class LightningIRDataModule(LightningDataModule):
    """Data module connecting Lightning IR datasets to a Lightning trainer.

    Resolves a :class:`LightningIRConfig` and tokenizer, holds one optional training
    dataset and any number of inference datasets, and builds
    :class:`torch.utils.data.DataLoader` instances whose collate function aggregates
    samples into Lightning IR batch objects (:class:`TrainBatch`, :class:`RankBatch`,
    :class:`IndexBatch`, or :class:`SearchBatch`).
    """

    def __init__(
        self,
        model_name_or_path: str | Path | None = None,
        config: LightningIRConfig | None = None,
        module: LightningIRModule | None = None,
        num_workers: int = 0,
        train_batch_size: int | None = None,
        shuffle_train: bool = True,
        inference_batch_size: int | None = None,
        train_dataset: RunDataset | TupleDataset | None = None,
        inference_datasets: Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None = None,
    ) -> None:
        """Initialize the data module.

        :param model_name_or_path: Model name or path used to load the config and tokenizer
            when neither ``config`` nor ``module`` supplies one, defaults to None
        :param config: Explicit Lightning IR config (takes precedence), defaults to None
        :param module: Module whose config is used when no explicit config is given, defaults to None
        :param num_workers: Number of dataloader worker processes, defaults to 0
        :param train_batch_size: Batch size for training, defaults to None
        :param shuffle_train: Whether to shuffle the training dataset (ignored for
            iterable datasets), defaults to True
        :param inference_batch_size: Batch size for inference, defaults to None
        :param train_dataset: Dataset used by :meth:`train_dataloader`, defaults to None
        :param inference_datasets: Datasets used by :meth:`val_dataloader` /
            :meth:`test_dataloader`, defaults to None
        :raises ValueError: If none of ``module``, ``config``, or ``model_name_or_path``
            is provided
        """
        super().__init__()
        # Config resolution precedence: explicit config > module's config > pretrained lookup.
        if config is not None:
            self.config = config
        elif module is not None:
            self.config = module.config
        elif model_name_or_path is not None:
            self.config = LightningIRConfig.from_pretrained(model_name_or_path)
        else:
            raise ValueError("Either module, config, or model_name_or_path must be provided.")

        # Fall back to the config's recorded checkpoint for tokenizer loading.
        if model_name_or_path is None:
            model_name_or_path = self.config.name_or_path
        self.tokenizer = LightningIRTokenizer.from_pretrained(model_name_or_path, config=self.config)
        self.num_workers = num_workers

        self.train_batch_size = train_batch_size
        self.shuffle_train = shuffle_train
        self.inference_batch_size = inference_batch_size
        self.train_dataset = train_dataset
        self.inference_datasets = inference_datasets

    def setup_inference(self, stage: Literal["validate", "test"]) -> None:
        """Validate that each inference dataset is usable for the given stage.

        :param stage: Inference stage being set up
        :raises ValueError: If a dataset type or sampling strategy is incompatible
            with inference
        """
        if self.inference_datasets is None:
            return
        for inference_dataset in self.inference_datasets:
            if isinstance(inference_dataset, TupleDataset):
                # Tuples carry no run structure to re-rank at test time.
                if stage == "test":
                    raise ValueError("Prediction cannot be performed with TupleDataset.")
            elif isinstance(inference_dataset, RunDataset):
                # single_relevant is a training-only sampling strategy.
                if inference_dataset.sampling_strategy == "single_relevant":
                    raise ValueError("Inference RunDataset cannot use the single_relevant sampling strategy.")
            elif isinstance(inference_dataset, (QueryDataset, DocDataset)):
                pass
            else:
                raise ValueError(
                    "Inference Dataset must be of type RunDataset, TupleDataset, QueryDataset, or DocDataset."
                )

    def setup(self, stage: Literal["fit", "validate", "test"]) -> None:
        """Set up the data module for the given trainer stage.

        :param stage: Trainer stage
        :raises ValueError: If stage is ``"fit"`` and no training dataset is set
        """
        if stage == "fit":
            if self.train_dataset is None:
                raise ValueError("A training dataset and config must be provided.")
            # Fitting also runs validation; check inference datasets accordingly.
            stage = "validate"
        self.setup_inference(stage)

    def train_dataloader(self) -> DataLoader:
        """Build the training dataloader.

        :raises ValueError: If no training dataset is set
        :return: Dataloader over the training dataset
        """
        if self.train_dataset is None:
            raise ValueError("No training dataset found.")
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            # Iterable datasets cannot be shuffled by the dataloader.
            shuffle=(False if isinstance(self.train_dataset, IterableDataset) else self.shuffle_train),
            prefetch_factor=16 if self.num_workers > 0 else None,
        )

    def val_dataloader(self) -> List[DataLoader]:
        """Build the validation dataloaders.

        :return: One dataloader per inference dataset
        """
        return self.inference_dataloader()

    def test_dataloader(self) -> List[DataLoader]:
        """Build the test dataloaders.

        :return: One dataloader per inference dataset
        """
        return self.inference_dataloader()

    def inference_dataloader(self) -> List[DataLoader]:
        """Build one dataloader per inference dataset.

        :return: List of inference dataloaders (empty if no inference datasets are set)
        """
        inference_datasets = self.inference_datasets or []
        return [
            DataLoader(
                dataset,
                batch_size=self.inference_batch_size,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                prefetch_factor=16 if self.num_workers > 0 else None,
            )
            for dataset in inference_datasets
        ]

    def _aggregate_samples(self, samples: Sequence[RankSample | QuerySample | DocSample]) -> Dict[str, Any]:
        """Aggregate per-sample fields into lists keyed by pluralized field name.

        :param samples: Samples to aggregate
        :return: Mapping from pluralized field name to aggregated values
        """
        aggregated = defaultdict(list)
        # Fields marked "extend" hold per-sample sequences that are flattened into
        # one list; the rest are appended as one entry per sample.
        field_options = {
            "query_id": {"extend": False},
            "query": {"extend": False},
            "doc_id": {"extend": False},
            "doc_ids": {"extend": False},
            "doc": {"extend": False},
            "docs": {"extend": False},
            "targets": {"extend": True},
            "qrels": {"extend": True},
        }
        for sample in samples:
            for field in sample.__dict__:
                extend = field_options[field]["extend"]
                key = field if field.endswith("s") else f"{field}s"
                value = getattr(sample, field)
                if value is None:
                    continue
                if extend:
                    aggregated[key].extend(value)
                else:
                    aggregated[key].append(value)
        return aggregated

    def _clean_sample(self, aggregated: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize aggregated fields into batch constructor keyword arguments.

        :param aggregated: Aggregated sample fields
        :return: Keyword arguments for a batch dataclass
        """
        kwargs: Dict[str, Any] = dict(aggregated)
        # Naive pluralization in _aggregate_samples yields "querys"; fix it here.
        if "querys" in kwargs:
            kwargs["queries"] = kwargs["querys"]
            del kwargs["querys"]
        if "targets" in kwargs:
            kwargs["targets"] = torch.stack(kwargs["targets"])
        return kwargs

    def _parse_batch(
        self, sample: RankSample | QuerySample | DocSample, **kwargs
    ) -> RankBatch | TrainBatch | IndexBatch | SearchBatch:
        """Select the batch type matching the sample type and construct it.

        :param sample: Representative sample used to determine the batch type
        :raises ValueError: If the sample type is not recognized
        :return: Constructed batch
        """
        if isinstance(sample, RankSample):
            # Rank samples with targets are training batches; without, rank batches.
            if "targets" in kwargs:
                return TrainBatch(**kwargs)
            else:
                return RankBatch(**kwargs)
        if isinstance(sample, QuerySample):
            return SearchBatch(**kwargs)
        if isinstance(sample, DocSample):
            return IndexBatch(**kwargs)
        raise ValueError("Invalid dataset configuration.")

    def collate_fn(
        self,
        samples: Sequence[RankSample | QuerySample | DocSample] | RankSample | QuerySample | DocSample,
    ) -> TrainBatch | RankBatch | IndexBatch | SearchBatch:
        """Collate one or more samples into a batch.

        :param samples: A single sample or a sequence of samples
        :raises ValueError: If no samples are provided
        :return: Batch of the type matching the sample type
        """
        if isinstance(samples, (RankSample, QuerySample, DocSample)):
            samples = [samples]
        if not samples:
            # Fail with a clear message instead of an IndexError on samples[0].
            raise ValueError("No samples to collate.")
        aggregated = self._aggregate_samples(samples)
        kwargs = self._clean_sample(aggregated)
        return self._parse_batch(samples[0], **kwargs)