Source code for lightning_ir.data.dataset

import warnings
from itertools import islice
from pathlib import Path
from typing import Any, Dict, Iterator, Literal, Tuple

import ir_datasets
import numpy as np
import pandas as pd
import torch
from ir_datasets.formats import GenericDoc, GenericDocPair
from torch.distributed import get_rank, get_world_size
from torch.utils.data import Dataset, IterableDataset, get_worker_info

from .data import DocSample, QuerySample, RankSample
from .ir_datasets_utils import ScoredDocTuple

RUN_HEADER = ["query_id", "q0", "doc_id", "rank", "score", "system"]


class IRDataset:
    def __init__(self, dataset: str) -> None:
        super().__init__()
        if dataset in self.DASHED_DATASET_MAP:
            dataset = self.DASHED_DATASET_MAP[dataset]
        self.dataset = dataset
        try:
            self.ir_dataset = ir_datasets.load(dataset)
        except KeyError:
            self.ir_dataset = None
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def DASHED_DATASET_MAP(self) -> Dict[str, str]:
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

    @property
    def queries(self) -> pd.Series:
        if self._queries is None:
            if self.ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            queries_iter = self.ir_dataset.queries_iter()
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in queries_iter},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
        if self._docs is None:
            if self.ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = self.ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        if self._qrels is not None:
            return self._qrels
        if self.ir_dataset is None:
            return None
        qrels = pd.DataFrame(self.ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        qrels = qrels.droplevel(0, axis=1)
        self._qrels = qrels
        return self._qrels

    @property
    def dataset_id(self) -> str:
        if self.ir_dataset is None:
            return self.dataset
        return self.ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        return ir_datasets.docs_parent_id(self.dataset_id)
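

# --- Illustrative usage sketch (not part of the lightning_ir source): how the
# lazily cached IRDataset properties could be accessed. "vaswani" is merely an
# example of a valid ir_datasets id; the helper name is hypothetical.
def _example_ir_dataset_usage() -> None:
    dataset = IRDataset("vaswani")
    print(dataset.dataset_id, dataset.docs_dataset_id)
    print(dataset.queries.head())  # pd.Series of query texts indexed by query_id
    if dataset.qrels is not None:
        print(dataset.qrels.head())  # relevance judgments, one column per iteration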


class DataParallelIterableDataset(IterableDataset):
    # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
    def __init__(self) -> None:
        super().__init__()
        # TODO add support for multi-gpu and multi-worker inference; currently
        # doesn't work
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0

        try:
            world_size = get_world_size()
            process_rank = get_rank()
        except (RuntimeError, ValueError):
            world_size = 1
            process_rank = 0

        self.num_replicas = num_workers * world_size
        self.rank = process_rank * num_workers + worker_id


class QueryDataset(IRDataset, DataParallelIterableDataset):
    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        super().__init__(query_dataset)
        super(IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int:
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or self.ir_dataset.queries_count()

    def __iter__(self) -> Iterator[QuerySample]:
        start = self.rank
        stop = self.num_queries
        step = self.num_replicas
        for sample in islice(self.ir_dataset.queries_iter(), start, stop, step):
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            if self.qrels is not None:
                qrels = (
                    self.qrels.loc[[query_sample.query_id]]
                    .stack()
                    .rename("relevance")
                    .astype(int)
                    .reset_index()
                    .to_dict(orient="records")
                )
                query_sample.qrels = qrels
            yield query_sample
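

# --- Illustrative usage sketch (not part of the lightning_ir source): iterating
# a QueryDataset, e.g. for query encoding. Assumes QuerySample exposes
# `query_id` and `query` fields; the dataset id and helper name are examples.
def _example_query_dataset_usage() -> None:
    dataset = QueryDataset("vaswani", num_queries=10)
    for sample in dataset:
        print(sample.query_id, sample.query)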


class DocDataset(IRDataset, DataParallelIterableDataset):
    def __init__(self, doc_dataset: str, num_docs: int | None = None) -> None:
        super().__init__(doc_dataset)
        super(IRDataset, self).__init__()
        self.num_docs = num_docs

    def __len__(self) -> int:
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_docs or self.ir_dataset.docs_count()

    def __iter__(self) -> Iterator[DocSample]:
        start = self.rank
        stop = self.num_docs
        step = self.num_replicas
        for sample in islice(self.ir_dataset.docs_iter(), start, stop, step):
            yield DocSample.from_ir_dataset_sample(sample)
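

# --- Illustrative usage sketch (not part of the lightning_ir source): streaming
# documents, e.g. for indexing. Assumes DocSample exposes `doc_id` and `doc`
# fields; the dataset id and helper name are examples.
def _example_doc_dataset_usage() -> None:
    dataset = DocDataset("vaswani", num_docs=10)
    for sample in dataset:
        print(sample.doc_id, sample.doc[:80])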


class Sampler:

    @staticmethod
    def single_relevant(group: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        relevance = group.filter(like="relevance").max(axis=1).fillna(0)
        relevant = group.loc[relevance.gt(0)].sample(1)
        non_relevant_bool = relevance.eq(0) & ~group["rank"].isna()
        num_non_relevant = non_relevant_bool.sum()
        sample_non_relevant = min(sample_size - 1, num_non_relevant)
        non_relevant = group.loc[non_relevant_bool].sample(sample_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(group: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        return group.head(sample_size)

    @staticmethod
    def top_and_random(group: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        top_size = sample_size // 2
        random_size = sample_size - top_size
        top = group.head(top_size)
        random = group.iloc[top_size:].sample(random_size)
        return pd.concat([top, random])

    @staticmethod
    def random(group: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        return group.sample(sample_size)

    @staticmethod
    def log_random(group: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        weights = 1 / np.log1p(group["rank"])
        weights[weights.isna()] = weights.min()
        return group.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        if sample_size == -1:
            return df
        if hasattr(Sampler, sampling_strategy):
            return getattr(Sampler, sampling_strategy)(df, sample_size)
        raise ValueError("Invalid sampling strategy.")
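

# --- Illustrative usage sketch (not part of the lightning_ir source): applying
# Sampler.sample to a small synthetic run group. The column names mirror those
# RunDataset produces after merging qrels (rank plus "relevance_*" columns).
def _example_sampler_usage() -> None:
    group = pd.DataFrame(
        {
            "query_id": ["q1"] * 4,
            "doc_id": ["d1", "d2", "d3", "d4"],
            "rank": [1, 2, 3, 4],
            "score": [4.0, 3.0, 2.0, 1.0],
            "relevance_0": [1, 0, 0, 0],
        }
    )
    # keeps one relevant and one non-relevant document
    sampled = Sampler.sample(group, sample_size=2, sampling_strategy="single_relevant")
    print(sampled[["doc_id", "rank"]])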


class RunDataset(IRDataset, Dataset):
    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_non_retrieved_docs: bool = False,
    ) -> None:
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets

        if self.sampling_strategy == "top" and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels."
            )

        self.run = self.load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        if self.qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and add_non_retrieved_docs:
                how = "outer"
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.")

    @staticmethod
    def load_csv(path: Path) -> pd.DataFrame:
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
        )

    @staticmethod
    def load_parquet(path: Path) -> pd.DataFrame:
        return pd.read_parquet(path).rename(
            {
                "qid": "query_id",
                "docid": "doc_id",
                "docno": "doc_id",
            },
            axis=1,
        )

    @staticmethod
    def load_json(path: Path) -> pd.DataFrame:
        kwargs: Dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(
            path,
            **kwargs,
            dtype={
                "query_id": str,
                "qid": str,
                "doc_id": str,
                "docid": str,
                "docno": str,
            },
        )
        return run

    def _get_run_path(self) -> Path | None:
        run_path = self.run_path
        if run_path is None:
            if self.ir_dataset is None or not self.ir_dataset.has_scoreddocs():
                raise ValueError("Run file or dataset with scoreddocs required.")
            try:
                run_path = self.ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id"},
            axis=1,
        )
        if "query" in run.columns:
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        dtypes = {"rank": np.int32}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        return run

    def load_run(self) -> pd.DataFrame:

        suffix_load_map = {
            ".tsv": self.load_csv,
            ".run": self.load_csv,
            ".csv": self.load_csv,
            ".parquet": self.load_parquet,
            ".json": self.load_json,
            ".jsonl": self.load_json,
        }
        run = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            load_func = suffix_load_map.get(run_path.suffixes[0], None)
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception:
                    pass

        # try loading run from ir_datasets
        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
            run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.")

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        if self._qrels is not None:
            return self._qrels
        if "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            qrels = qrels.droplevel(0, axis=1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks so that higher is better (necessary for loss functions)
                targets = self.depth - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                targets_max = targets.max()
                targets = (targets - targets_min) / (targets_max - targets_min)
        qrels = None
        if self.qrels is not None:
            qrels = (
                self.qrels.loc[[query_id]]
                .stack()
                .rename("relevance")
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)
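

# --- Illustrative usage sketch (not part of the lightning_ir source): building
# a training dataset from a TREC-style run file. The path and helper name are
# placeholders; note that the constructor above derives the ir_datasets id from
# the file name (the part after the last "__") unless the run itself carries
# `query` and `text` columns.
def _example_run_dataset_usage() -> None:
    dataset = RunDataset(
        Path("path/to/run__vaswani.tsv"),  # placeholder; "vaswani" is the encoded dataset id
        depth=100,
        sample_size=8,
        sampling_strategy="log_random",
        targets="relevance",
    )
    sample = dataset[0]  # a RankSample with query, doc_ids, docs, targets, and qrels
    print(sample.query_id, len(sample.doc_ids))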


class TupleDataset(IRDataset, IterableDataset):
    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        super().__init__(tuples_dataset)
        super(IRDataset, self).__init__()
        self.targets = targets
        self.num_docs = num_docs

    def parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]:
        if isinstance(sample, GenericDocPair):
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            targets = (1.0, 0.0)
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
        elif isinstance(sample, ScoredDocTuple):
            doc_ids = sample.doc_ids[: self.num_docs]
            if self.targets == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                targets = sample.scores
            elif self.targets == "order":
                targets = tuple([1.0] + [0.0] * (sample.num_docs - 1))
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            targets = targets[: self.num_docs]
        else:
            raise ValueError("Invalid sample type.")
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        for sample in self.ir_dataset.docpairs_iter():
            query_id = sample.query_id
            query = self.queries.loc[query_id]
            doc_ids, docs, targets = self.parse_sample(sample)
            if targets is not None:
                targets = torch.tensor(targets)
            yield RankSample(query_id, query, doc_ids, docs, targets)
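

# --- Illustrative usage sketch (not part of the lightning_ir source): consuming
# a TupleDataset built from an ir_datasets docpairs collection. The dataset id
# and helper name are examples; RankSample is assumed to expose `query_id`,
# `doc_ids`, and `targets` fields.
def _example_tuple_dataset_usage() -> None:
    dataset = TupleDataset("msmarco-passage/train/triples-small", targets="order", num_docs=2)
    for sample in islice(dataset, 2):
        print(sample.query_id, sample.doc_ids, sample.targets)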