1import warnings
2from itertools import islice
3from pathlib import Path
4from typing import Any, Dict, Iterator, Literal, Tuple
5
6import ir_datasets
7import numpy as np
8import pandas as pd
9import torch
10from ir_datasets.formats import GenericDoc, GenericDocPair
11from torch.distributed import get_rank, get_world_size
12from torch.utils.data import Dataset, IterableDataset, get_worker_info
13
14from .data import DocSample, QuerySample, RankSample
15from .ir_datasets_utils import ScoredDocTuple
16
# Column names of a TREC-style run file ("query_id Q0 doc_id rank score system").
RUN_HEADER = ["query_id", "q0", "doc_id", "rank", "score", "system"]
18
19
[docs]
class IRDataset:
    """Base wrapper around an ``ir_datasets`` dataset.

    Lazily loads and caches queries, documents, and relevance judgments
    (qrels). Subclasses mix this with torch ``Dataset`` variants.
    """

    def __init__(self, dataset: str) -> None:
        """Resolve and load the ir-datasets dataset ``dataset``.

        :param dataset: ir-datasets id; dashed aliases (slashes replaced by
            dashes) are mapped back to the canonical id first.
        """
        super().__init__()
        # Accept dashed aliases such as "msmarco-passage-train" for
        # "msmarco/passage/train".
        if dataset in self.DASHED_DATASET_MAP:
            dataset = self.DASHED_DATASET_MAP[dataset]
        self.dataset = dataset
        try:
            self.ir_dataset = ir_datasets.load(dataset)
        except KeyError:
            # Unknown id — keep going; subclasses may supply data from a run
            # file instead (see RunDataset).
            self.ir_dataset = None
        # Lazy caches populated by the properties below.
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def DASHED_DATASET_MAP(self) -> Dict[str, str]:
        """Map dashed dataset ids (slashes replaced by dashes) to canonical ids."""
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

    @property
    def queries(self) -> pd.Series:
        """Query texts as a Series indexed by ``query_id`` (cached).

        :raises ValueError: if the dataset could not be loaded from ir-datasets.
        """
        if self._queries is None:
            if self.ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            queries_iter = self.ir_dataset.queries_iter()
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in queries_iter},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
        """Document store for ``doc_id`` lookups (cached).

        :raises ValueError: if the dataset could not be loaded from ir-datasets.
        """
        if self._docs is None:
            if self.ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = self.ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Relevance judgments indexed by ``(query_id, doc_id)`` with one
        column per iteration/subtopic, or ``None`` if the dataset is
        unavailable (cached)."""
        if self._qrels is not None:
            return self._qrels
        if self.ir_dataset is None:
            return None
        qrels = pd.DataFrame(self.ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        # Pivot iteration values into columns; droplevel removes the
        # top "relevance" level of the resulting MultiIndex columns.
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        qrels = qrels.droplevel(0, axis=1)
        self._qrels = qrels
        return self._qrels

    @property
    def dataset_id(self) -> str:
        """Canonical ir-datasets id, or the raw dataset string if unavailable."""
        if self.ir_dataset is None:
            return self.dataset
        return self.ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """Id of the dataset that provides the document collection."""
        return ir_datasets.docs_parent_id(self.dataset_id)
83
84
[docs]
class DataParallelIterableDataset(IterableDataset):
    # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
    def __init__(self) -> None:
        """Determine this replica's shard index and the total shard count.

        Combines the dataloader worker id/count with the distributed process
        rank/world size so iterable subclasses can shard their stream.
        """
        super().__init__()
        # TODO add support for multi-gpu and multi-worker inference; currently
        # doesn't work
        info = get_worker_info()
        if info is None:
            num_workers, worker_id = 1, 0
        else:
            num_workers, worker_id = info.num_workers, info.id

        try:
            world_size = get_world_size()
            process_rank = get_rank()
        except (RuntimeError, ValueError):
            # torch.distributed not initialized — single-process fallback.
            world_size, process_rank = 1, 0

        # Total number of shards and this dataset instance's shard index.
        self.num_replicas = num_workers * world_size
        self.rank = process_rank * num_workers + worker_id
104
105
[docs]
class QueryDataset(IRDataset, DataParallelIterableDataset):
    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        """Iterable dataset over the queries of an ir-datasets dataset.

        :param query_dataset: ir-datasets id to pull queries from.
        :param num_queries: optional cap on the number of queries iterated.
        """
        super().__init__(query_dataset)
        # Explicitly run DataParallelIterableDataset.__init__ (next after
        # IRDataset in the MRO) to set up rank / num_replicas.
        super(IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int:
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or self.ir_dataset.queries_count()

    def __iter__(self) -> Iterator[QuerySample]:
        # Shard the stream: replica `rank` takes every `num_replicas`-th
        # query, stopping after `num_queries` overall.
        sharded = islice(self.ir_dataset.queries_iter(), self.rank, self.num_queries, self.num_replicas)
        for sample in sharded:
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            if self.qrels is not None:
                # Attach this query's judgments as a list of
                # {query_id, doc_id, iteration, relevance} records.
                query_sample.qrels = (
                    self.qrels.loc[[query_sample.query_id]]
                    .stack()
                    .rename("relevance")
                    .astype(int)
                    .reset_index()
                    .to_dict(orient="records")
                )
            yield query_sample
133
134
[docs]
class DocDataset(IRDataset, DataParallelIterableDataset):
    def __init__(self, doc_dataset: str, num_docs: int | None = None) -> None:
        """Iterable dataset over the documents of an ir-datasets dataset.

        :param doc_dataset: ir-datasets id to pull documents from.
        :param num_docs: optional cap on the number of documents iterated.
        """
        super().__init__(doc_dataset)
        # Explicitly run DataParallelIterableDataset.__init__ (next after
        # IRDataset in the MRO) to set up rank / num_replicas.
        super(IRDataset, self).__init__()
        self.num_docs = num_docs

    def __len__(self) -> int:
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_docs or self.ir_dataset.docs_count()

    def __iter__(self) -> Iterator[DocSample]:
        # Shard the stream: replica `rank` takes every `num_replicas`-th
        # document, stopping after `num_docs` overall.
        sharded = islice(self.ir_dataset.docs_iter(), self.rank, self.num_docs, self.num_replicas)
        yield from map(DocSample.from_ir_dataset_sample, sharded)
151
152
[docs]
class Sampler:
    """Strategies for sub-sampling the documents of one query's ranking."""

    @staticmethod
    def single_relevant(group: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """One random relevant document plus up to ``sample_size - 1`` random
        non-relevant retrieved documents."""
        relevance = group.filter(like="relevance").max(axis=1).fillna(0)
        relevant = group.loc[relevance.gt(0)].sample(1)
        # Non-relevant candidates must actually appear in the run (have a rank).
        non_relevant_mask = relevance.eq(0) & ~group["rank"].isna()
        num_non_relevant = min(sample_size - 1, non_relevant_mask.sum())
        non_relevant = group.loc[non_relevant_mask].sample(num_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(group: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """The ``sample_size`` highest-ranked documents."""
        return group.head(sample_size)

    @staticmethod
    def top_and_random(group: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Half the sample from the top of the ranking, the remainder drawn
        uniformly at random from the rest."""
        num_top = sample_size // 2
        top_part = group.head(num_top)
        random_part = group.iloc[num_top:].sample(sample_size - num_top)
        return pd.concat([top_part, random_part])

    @staticmethod
    def random(group: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """A uniform random sample of ``sample_size`` documents."""
        return group.sample(sample_size)

    @staticmethod
    def log_random(group: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """A random sample biased towards higher-ranked documents, using
        weight 1 / log1p(rank); unranked rows get the minimum weight."""
        weights = 1 / np.log1p(group["rank"])
        weights[weights.isna()] = weights.min()
        return group.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """Dispatch to the sampling strategy named ``sampling_strategy``.

        A ``sample_size`` of -1 disables sampling and returns ``df`` as-is.

        :raises ValueError: if the strategy name is unknown.
        """
        if sample_size == -1:
            return df
        strategy = getattr(Sampler, sampling_strategy, None)
        if strategy is None:
            raise ValueError("Invalid sampling strategy.")
        return strategy(df, sample_size)
198
199
[docs]
class RunDataset(IRDataset, Dataset):
    """Map-style dataset over a run file or a dataset's scored documents.

    Each item is a :class:`RankSample` with a query, a (sub-)sampled group of
    its retrieved documents, optional training targets, and optional qrels.
    """

    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_non_retrieved_docs: bool = False,
    ) -> None:
        """Load and preprocess a run.

        :param run_path_or_id: path to a run file, or an ir-datasets id whose
            dataset provides scoreddocs.
        :param depth: maximum rank to keep; -1 keeps all ranks.
        :param sample_size: documents sampled per query; -1 keeps all.
        :param sampling_strategy: name of the :class:`Sampler` strategy.
        :param targets: which run/qrels columns to expose as training targets.
        :param normalize_targets: min-max normalize targets.
        :param add_non_retrieved_docs: also keep judged documents missing from
            the run (only when documents come from an ir-datasets docstore).
        """
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            # Run file names encode the dataset id as ...__<dataset>.<ext>.
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets

        if self.sampling_strategy == "top" and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels."
            )

        self.run = self.load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        if self.qrels is not None:
            # Restrict the run to queries that have judgments, then merge the
            # (prefixed) relevance columns into the run.
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and add_non_retrieved_docs:
                how = "outer"
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            # Intended to drop queries with fewer docs than the sample size.
            # NOTE(review): transform("size") on the grouped frame returns a
            # DataFrame, and indexing with a boolean DataFrame masks cells
            # rather than dropping rows — verify this actually filters queries.
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.")

    @staticmethod
    def load_csv(path: Path) -> pd.DataFrame:
        """Load a whitespace-separated TREC-style run file."""
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 2, 3, 4],  # query_id, doc_id, rank, score
            dtype={"query_id": str, "doc_id": str},
        )

    @staticmethod
    def load_parquet(path: Path) -> pd.DataFrame:
        """Load a parquet run file, normalizing common column aliases."""
        return pd.read_parquet(path).rename(
            {
                "qid": "query_id",
                "docid": "doc_id",
                "docno": "doc_id",
            },
            axis=1,
        )

    @staticmethod
    def load_json(path: Path) -> pd.DataFrame:
        """Load a JSON or JSON-lines run file."""
        kwargs: Dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(
            path,
            **kwargs,
            dtype={
                "query_id": str,
                "qid": str,
                "doc_id": str,
                "docid": str,
                "docno": str,
            },
        )
        return run

    def _get_run_path(self) -> Path | None:
        """Resolve the run file path, falling back to the dataset's scoreddocs.

        :raises ValueError: if neither a run file nor scoreddocs are available.
        """
        run_path = self.run_path
        if run_path is None:
            if self.ir_dataset is None or not self.ir_dataset.has_scoreddocs():
                raise ValueError("Run file or dataset with scoreddocs required.")
            try:
                run_path = self.ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        """Normalize column names/dtypes, extract any inline query/document
        text, and truncate the run to ``self.depth``."""
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id"},
            axis=1,
        )
        if "query" in run.columns:
            # Query text shipped inside the run file; cache it and drop the column.
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            # Document text shipped inside the run file; cache it and drop the column.
            self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        dtypes = {"rank": np.int32}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        return run

    def load_run(self) -> pd.DataFrame:
        """Load the run from file or from ir-datasets scoreddocs, then clean it.

        :raises ValueError: if no loader matches the file format.
        """
        suffix_load_map = {
            ".tsv": self.load_csv,
            ".run": self.load_csv,
            ".csv": self.load_csv,
            ".parquet": self.load_parquet,
            ".json": self.load_json,
            ".jsonl": self.load_json,
        }
        run = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            load_func = suffix_load_map.get(run_path.suffixes[0], None)
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception:
                    # Fall through to loading from ir-datasets scoreddocs.
                    pass

        # try loading run from ir_datasets
        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
            # Derive ranks from scores (highest score gets rank 1).
            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
            run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.")

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels pivoted by iteration; prefers judgments embedded in the run
        file, otherwise falls back to the dataset's qrels (cached)."""
        if self._qrels is not None:
            return self._qrels
        if "relevance" in self.run:
            # Judgments are embedded in the run file itself; extract them and
            # remove the columns from the run.
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            qrels = qrels.droplevel(0, axis=1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the run."""
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Assemble the sampled ranking for the ``idx``-th query.

        :raises ValueError: if ``targets`` is set but no matching column exists.
        """
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            # Select all target columns matching the requested target name.
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks to be higher is better (necessary for loss functions)
                targets = self.depth - targets + 1
            if self.normalize_targets:
                # Min-max normalize targets.
                targets_min = targets.min()
                targets_max = targets.max()
                targets = (targets - targets_min) / (targets_max - targets_min)
        qrels = None
        if self.qrels is not None:
            # This query's judgments as a list of records.
            qrels = (
                self.qrels.loc[[query_id]]
                .stack()
                .rename("relevance")
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)
420
421
[docs]
class TupleDataset(IRDataset, IterableDataset):
    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        """Iterable dataset over query-document tuples (training pairs/tuples).

        :param tuples_dataset: ir-datasets id providing docpairs.
        :param targets: "order" yields 1.0 for the first doc and 0.0 for the
            rest; "score" uses the tuple's stored scores.
        :param num_docs: optional cap on the number of documents per tuple.
        """
        super().__init__(tuples_dataset)
        # Explicitly run IterableDataset.__init__ (next after IRDataset in the MRO).
        super(IRDataset, self).__init__()
        self.targets = targets
        self.num_docs = num_docs

    def parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]:
        """Extract doc ids, doc texts, and targets from one tuples sample.

        :raises ValueError: on unsupported sample types or target settings.
        """
        if isinstance(sample, GenericDocPair):
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            # A doc pair always lists the positive document first.
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
            targets = (1.0, 0.0)
        elif isinstance(sample, ScoredDocTuple):
            doc_ids = sample.doc_ids[: self.num_docs]
            if self.targets == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                targets = sample.scores
            elif self.targets == "order":
                # First document is the positive, all others are negatives.
                targets = tuple([1.0] + [0.0] * (sample.num_docs - 1))
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            targets = targets[: self.num_docs]
        else:
            raise ValueError("Invalid sample type.")
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        for sample in self.ir_dataset.docpairs_iter():
            doc_ids, docs, targets = self.parse_sample(sample)
            target_tensor = torch.tensor(targets) if targets is not None else None
            yield RankSample(sample.query_id, self.queries.loc[sample.query_id], doc_ids, docs, target_tensor)