Source code for lightning_ir.data.ir_datasets_utils

import codecs
import json
from pathlib import Path
from typing import Any, Dict, Literal, NamedTuple, Tuple, Type

import ir_datasets
from ir_datasets.datasets.base import Dataset
from ir_datasets.formats import BaseDocPairs, jsonl, trec, tsv
from ir_datasets.util import Cache, DownloadConfig, GzipExtract

# Maps each constituent type to the parser class used for a given file suffix.
CONSTITUENT_TYPE_MAP: Dict[str, Dict[str, Type]] = {
    "docs": {
        ".json": jsonl.JsonlDocs,
        ".jsonl": jsonl.JsonlDocs,
        ".tsv": tsv.TsvDocs,
    },
    "queries": {
        ".json": jsonl.JsonlQueries,
        ".jsonl": jsonl.JsonlQueries,
        ".tsv": tsv.TsvQueries,
    },
    "qrels": {".tsv": trec.TrecQrels, ".qrels": trec.TrecQrels},
    "scoreddocs": {".run": trec.TrecScoredDocs, ".tsv": trec.TrecScoredDocs},
    "docpairs": {".tsv": tsv.TsvDocPairs},
}

def load_constituent(
    constituent: str | None,
    constituent_type: Literal["docs", "queries", "qrels", "scoreddocs", "docpairs"],
    **kwargs,
) -> Any:
    if constituent is None:
        return None
    # Registered ir_datasets ids are resolved through the registry; the
    # ``*_handler()`` accessor must be called to obtain the handler object.
    if constituent in ir_datasets.registry._registered:
        return getattr(ir_datasets.load(constituent), f"{constituent_type}_handler")()
    # Everything else must be a path to a local file.
    constituent_path = Path(constituent)
    if not constituent_path.exists():
        raise ValueError(f"unable to load {constituent}, expected an `ir_datasets` id or valid path")
    suffix = constituent_path.suffixes[0]
    constituent_types = CONSTITUENT_TYPE_MAP[constituent_type]
    if suffix not in constituent_types:
        raise ValueError(f"Unknown file type: {suffix}, expected one of {constituent_types.keys()}")
    ConstituentType = constituent_types[suffix]
    return ConstituentType(Cache(None, constituent_path), **kwargs)

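# Usage sketch (illustrative, not part of the module): a constituent may be an
# ir_datasets id, in which case the corresponding handler is reused, or a local
# file whose suffix selects the parser from CONSTITUENT_TYPE_MAP. The local
# path below is a hypothetical placeholder.
#
#     queries = load_constituent("msmarco-passage/train", "queries")
#     docs = load_constituent("data/my_docs.tsv", "docs")  # parsed as tsv.TsvDocs
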
def register_local(
    dataset_id: str,
    docs: str | None = None,
    queries: str | None = None,
    qrels: str | None = None,
    docpairs: str | None = None,
    scoreddocs: str | None = None,
    qrels_defs: Dict[int, str] | None = None,
):
    if dataset_id in ir_datasets.registry._registered:
        return

    docs = load_constituent(docs, "docs")
    queries = load_constituent(queries, "queries")
    qrels = load_constituent(qrels, "qrels", qrels_defs=qrels_defs if qrels_defs is not None else {})
    docpairs = load_constituent(docpairs, "docpairs")
    scoreddocs = load_constituent(scoreddocs, "scoreddocs")

    ir_datasets.registry.register(dataset_id, Dataset(docs, queries, qrels, docpairs, scoreddocs))

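# Usage sketch (hypothetical paths): register a fully local dataset once, then
# load it back through the regular ir_datasets API.
#
#     register_local(
#         "local/example",
#         docs="data/docs.tsv",
#         queries="data/queries.tsv",
#         qrels="data/qrels.tsv",
#         qrels_defs={0: "not relevant", 1: "relevant"},
#     )
#     dataset = ir_datasets.load("local/example")
#     for query in dataset.queries_iter():
#         ...
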
class ScoredDocTuple(NamedTuple):
    query_id: str
    doc_ids: Tuple[str, ...]
    scores: Tuple[float, ...] | None
    num_docs: int

class ScoredDocTuples(BaseDocPairs):
    def __init__(self, docpairs_dlc):
        self._docpairs_dlc = docpairs_dlc

    def docpairs_path(self):
        return self._docpairs_dlc.path()

    def docpairs_iter(self):
        file_type = None
        if self._docpairs_dlc.path().suffix == ".json":
            file_type = "json"
        elif self._docpairs_dlc.path().suffix in (".tsv", ".run"):
            file_type = "tsv"
        else:
            raise ValueError(f"Unknown file type: {self._docpairs_dlc.path().suffix}")
        with self._docpairs_dlc.stream() as f:
            f = codecs.getreader("utf8")(f)
            for line in f:
                if file_type == "json":
                    # One JSON array per line: [qid, [pid, score], [pid, score], ...]
                    data = json.loads(line)
                    qid, *doc_data = data
                    pids, scores = zip(*doc_data)
                    pids = tuple(str(pid) for pid in pids)
                else:
                    # Whitespace-separated: pos_score neg_score qid pid1 pid2
                    cols = line.rstrip().split()
                    pos_score, neg_score, qid, pid1, pid2 = cols
                    pids = (pid1, pid2)
                    scores = (float(pos_score), float(neg_score))
                yield ScoredDocTuple(str(qid), pids, scores, len(pids))

    def docpairs_cls(self):
        return ScoredDocTuple

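# Input format sketch, derived from docpairs_iter above (ids and scores are
# made up). The whitespace-separated branch expects one tuple per line:
#
#     12.3 4.5 1185869 2022782 5437702
#
# yielding ScoredDocTuple("1185869", ("2022782", "5437702"), (12.3, 4.5), 2),
# while the JSON branch expects one array per line, a query id followed by
# [doc id, score] pairs:
#
#     ["1185869", ["2022782", 12.3], ["5437702", 4.5]]
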
def register_kd_docpairs():
    base_id = "msmarco-passage"
    split_id = "train"
    file_id = "kd-docpairs"
    cache_path = "bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv"
    dlc_contents = {
        "url": (
            "https://zenodo.org/record/4068216/files/bert_cat_ensemble_"
            "msmarcopassage_train_scores_ids.tsv?download=1"
        ),
        "expected_md5": "4d99696386f96a7f1631076bcc53ac3c",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.tsv"
    register_msmarco(base_id, split_id, file_id, cache_path, dlc_contents, file_name, ScoredDocTuples)

def register_colbert_docpairs():
    base_id = "msmarco-passage"
    split_id = "train"
    file_id = "colbert-docpairs"
    cache_path = "colbert_64way.json"
    dlc_contents = {
        "url": (
            "https://huggingface.co/colbert-ir/colbertv2.0_msmarco_64way/"
            "resolve/main/examples.json?download=true"
        ),
        "expected_md5": "8be0c71e330ac54dcd77fba058d291c7",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.json"
    register_msmarco(base_id, split_id, file_id, cache_path, dlc_contents, file_name, ScoredDocTuples)

def register_rank_distillm():
    base_id = "msmarco-passage"
    split_id = "train"
    file_id = "rank-distillm/rankzephyr"
    cache_path = "rank-distillm-rankzephyr.run"
    dlc_contents = {
        "url": (
            "https://zenodo.org/records/12528410/files/__rankzephyr-colbert-10000-"
            "sampled-100__msmarco-passage-train-judged.run?download=1"
        ),
        "expected_md5": "49f8dbf2c1ee7a2ca1fe517eda528af6",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.run"
    register_msmarco(
        base_id,
        split_id,
        file_id,
        cache_path,
        dlc_contents,
        file_name,
        trec.TrecScoredDocs,
    )

    file_id = "rank-distillm/set-encoder"
    cache_path = "rank-distillm-set-encoder.run.gz"
    dlc_contents = {
        "url": (
            "https://zenodo.org/records/12528410/files/__set-encoder-colbert__"
            "msmarco-passage-train-judged.run.gz?download=1"
        ),
        "expected_md5": "1f069d0daa9842a54a858cc660149e1a",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.run"
    register_msmarco(
        base_id,
        split_id,
        file_id,
        cache_path,
        dlc_contents,
        file_name,
        trec.TrecScoredDocs,
        extract=True,
    )

def register_msmarco(
    base_id: str,
    split_id: str,
    file_id: str,
    cache_path: str,
    dlc_contents: Dict[str, Any],
    file_name: str,
    ConstituentType: Type,
    extract: bool = False,
):
    dataset_id = f"{base_id}/{split_id}/{file_id}"
    if dataset_id in ir_datasets.registry._registered:
        return
    base_path = ir_datasets.util.home_path() / base_id
    # Inject the download location for the new file into the base dataset's
    # download config.
    dlc = DownloadConfig.context(base_id, base_path)
    dlc._contents[cache_path] = dlc_contents
    # Reuse docs, queries, and qrels from the existing MS MARCO dataset.
    ir_dataset = ir_datasets.load(f"{base_id}/{split_id}")
    collection = ir_dataset.docs_handler()
    queries = ir_dataset.queries_handler()
    qrels = ir_dataset.qrels_handler()
    _dlc = dlc[cache_path]
    if extract:
        _dlc = GzipExtract(_dlc)
    constituent = ConstituentType(Cache(_dlc, base_path / split_id / file_name))
    dataset = Dataset(collection, queries, qrels, constituent)
    ir_datasets.registry.register(dataset_id, Dataset(dataset))

register_kd_docpairs()
register_colbert_docpairs()
register_rank_distillm()
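# Usage sketch: the register_* calls above run at import time, so the datasets
# they add are available through the standard ir_datasets API afterwards (the
# first access downloads the underlying file):
#
#     import ir_datasets
#
#     dataset = ir_datasets.load("msmarco-passage/train/kd-docpairs")
#     for docpair in dataset.docpairs_iter():
#         print(docpair.query_id, docpair.doc_ids, docpair.scores)
#         break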