Source code for lightning_ir.base.validation_utils

 1from typing import Dict, Sequence
 2
 3import ir_measures
 4import numpy as np
 5import pandas as pd
 6import torch
 7
 8
def create_run_from_scores(
    query_ids: Sequence[str], doc_ids: Sequence[Sequence[str]], scores: torch.Tensor
) -> pd.DataFrame:
    """Convenience function to create a run DataFrame from query and document ids and scores.

    ``doc_ids[i]`` holds the documents scored for ``query_ids[i]``. ``scores`` is flattened and must hold
    exactly one score per document, in the order obtained by concatenating the lists in ``doc_ids``.

    The returned rows are ordered by query, following the order of ``query_ids`` (not lexicographic order),
    and within each query by ascending rank. Rank 1 is the highest score; tied scores are ranked in the
    order in which the documents appear.

    :param query_ids: Query ids
    :type query_ids: Sequence[str]
    :param doc_ids: Document ids, one sequence per query
    :type doc_ids: Sequence[Sequence[str]]
    :param scores: Scores, one per document (flattened before use)
    :type scores: torch.Tensor
    :raises ValueError: If ``query_ids`` and ``doc_ids`` differ in length, or if the number of scores does not
        match the total number of document ids
    :return: DataFrame with query_id, q0, doc_id, score, system, and rank columns
    :rtype: pd.DataFrame
    """
    num_docs = [len(ids) for ids in doc_ids]
    if len(query_ids) != len(num_docs):
        raise ValueError(
            f"Got {len(query_ids)} query ids but {len(num_docs)} document id lists; expected one list per query."
        )
    # Flatten with a comprehension: ``sum(lists, [])`` re-copies the growing accumulator on every step (quadratic).
    flat_doc_ids = [doc_id for ids in doc_ids for doc_id in ids]
    flat_scores = scores.float().cpu().detach().numpy().reshape(-1)
    if flat_scores.shape[0] != len(flat_doc_ids):
        raise ValueError(f"Got {flat_scores.shape[0]} scores but {len(flat_doc_ids)} document ids.")

    df = pd.DataFrame(
        {
            "query_id": np.array(query_ids).repeat(num_docs),
            "q0": 0,
            "doc_id": np.array(flat_doc_ids),
            "score": flat_scores,
            "system": "lightning_ir",
        }
    )
    # ``method="first"`` gives tied scores distinct ranks in order of appearance.
    df["rank"] = df.groupby("query_id")["score"].rank(ascending=False, method="first")

    # Sort queries by their position in ``query_ids`` instead of lexicographically. Built once, outside the
    # sort key; for duplicated query ids the last position wins (dict comprehension semantics).
    query_order = {query_id: i for i, query_id in enumerate(query_ids)}

    def key(series: pd.Series) -> pd.Series:
        # pandas applies the key to each sort column separately; only query_id needs remapping.
        if series.name == "query_id":
            return series.map(query_order)
        return series

    return df.sort_values(["query_id", "rank"], ascending=[True, True], key=key)
43 44
def create_qrels_from_dicts(qrels: Sequence[Dict[str, int]]) -> pd.DataFrame:
    """Convenience function to create a qrels DataFrame from a list of dictionaries.

    Each dictionary is treated as one record (row): its keys become column names and its values become the
    row's entries. Keys missing from a record are filled with ``NaN``. To obtain a DataFrame with
    ``query_id``, ``doc_id`` and ``relevance`` columns, each dictionary must contain those keys, e.g.
    ``{"query_id": "q1", "doc_id": "d1", "relevance": 1}``.

    Note that a plain ``doc_id -> relevance`` mapping is NOT supported: its document ids would become columns.

    :param qrels: One record per relevance judgment
    :type qrels: Sequence[Dict[str, int]]
    :return: DataFrame with one row per record and one column per distinct key
    :rtype: pd.DataFrame
    """
    # NOTE(review): the ``Dict[str, int]`` annotation is looser than the records it describes when id fields
    # hold strings; kept unchanged for backward compatibility.
    return pd.DataFrame.from_records(qrels)
54 55
def evaluate_run(run: pd.DataFrame, qrels: pd.DataFrame, measures: Sequence[str]) -> Dict[str, float]:
    """Convenience function to evaluate a run against qrels using a set of measures.

    .. _ir-measures: https://ir-measur.es/en/latest/index.html

    :param run: Parsed TREC run
    :type run: pd.DataFrame
    :param qrels: Parsed TREC qrels
    :type qrels: pd.DataFrame
    :param measures: Metrics corresponding to ir-measures_ measure strings
    :type measures: Sequence[str]
    :return: Aggregated value of each measure, keyed by the string form of the parsed measure
    :rtype: Dict[str, float]
    """
    parsed_measures = [ir_measures.parse_measure(measure) for measure in measures]
    if not parsed_measures:
        return {}
    # Evaluate all measures in one call so qrels and run are processed once instead of once per measure.
    aggregated = ir_measures.calc_aggregate(parsed_measures, qrels, run)
    return {str(measure): aggregated[measure] for measure in parsed_measures}