Source code for lightning_ir.base.validation_utils
1from typing import Dict, Sequence
2
3import ir_measures
4import numpy as np
5import pandas as pd
6import torch
7
8
[docs]
def create_run_from_scores(
    query_ids: Sequence[str], doc_ids: Sequence[Sequence[str]], scores: torch.Tensor
) -> pd.DataFrame:
    """Convenience function to create a run DataFrame from query and document ids and scores.

    Rows are ordered by the position of each query in ``query_ids`` and, within a query, by ascending rank.
    Rank 1 is assigned to the highest score of a query; tied scores are ranked in their input order.

    :param query_ids: Query ids
    :type query_ids: Sequence[str]
    :param doc_ids: Document ids, one sequence per query
    :type doc_ids: Sequence[Sequence[str]]
    :param scores: Scores, flattened to one score per document in the order given by ``doc_ids``
    :type scores: torch.Tensor
    :return: DataFrame with query_id, q0, doc_id, score, system, and rank columns
    :rtype: pd.DataFrame
    """
    # Explicit integer dtype: an empty list would otherwise be converted to a float array by numpy,
    # which is not a valid ``repeats`` argument.
    num_docs = np.asarray([len(ids) for ids in doc_ids], dtype=int)
    # Flatten in one linear pass; ``sum(lists, [])`` re-copies the accumulated list for every query.
    flat_doc_ids = [doc_id for ids in doc_ids for doc_id in ids]
    flat_scores = scores.float().cpu().detach().numpy().reshape(-1)
    df = pd.DataFrame(
        {
            "query_id": np.array(query_ids).repeat(num_docs),
            "q0": 0,
            "doc_id": np.array(flat_doc_ids),
            "score": flat_scores,
            "system": "lightning_ir",
        }
    )
    df["rank"] = df.groupby("query_id")["score"].rank(ascending=False, method="first")

    # Position of every query id in the input, used to keep the queries in their original order
    # instead of sorting them lexicographically.
    query_position = {query_id: i for i, query_id in enumerate(query_ids)}

    def key(series: pd.Series) -> pd.Series:
        # ``sort_values`` calls the key once per sort column; only query ids are remapped.
        if series.name == "query_id":
            return series.map(query_position)
        return series

    df = df.sort_values(["query_id", "rank"], ascending=[True, True], key=key)
    return df
43
44
[docs]
def create_qrels_from_dicts(qrels: Sequence[Dict[str, int]]) -> pd.DataFrame:
    """Convenience function to create a qrels DataFrame from a list of dictionaries.

    Every dictionary becomes one row and the dictionary keys become the column names.

    :param qrels: Relevance judgement records, one dictionary per row
        (presumably with query_id, doc_id, and relevance keys -- confirm against the caller)
    :type qrels: Sequence[Dict[str, int]]
    :return: DataFrame with one row per record and one column per key
    :rtype: pd.DataFrame
    """
    qrels_df = pd.DataFrame.from_records(qrels)
    return qrels_df
54
55
[docs]
def evaluate_run(run: pd.DataFrame, qrels: pd.DataFrame, measures: Sequence[str]) -> Dict[str, float]:
    """Convenience function to evaluate a run against qrels using a set of measures.

    .. _ir-measures: https://ir-measur.es/en/latest/index.html

    :param run: Parsed TREC run
    :type run: pd.DataFrame
    :param qrels: Parsed TREC qrels
    :type qrels: pd.DataFrame
    :param measures: Metrics corresponding to ir-measures_ measure strings
    :type measures: Sequence[str]
    :return: Aggregated value of every measure, keyed by the string form of the parsed measure
    :rtype: Dict[str, float]
    """
    # Parse every measure string up front so an invalid one fails before anything is computed.
    parsed_measures = list(map(ir_measures.parse_measure, measures))
    metrics: Dict[str, float] = {}
    for parsed_measure in parsed_measures:
        metrics[str(parsed_measure)] = parsed_measure.calc_aggregate(qrels, run)
    return metrics
72 return metrics