Last active
August 22, 2024 14:04
-
-
Save Zendelo/396324fa553aaf85c3d6d7a47eef578d to your computer and use it in GitHub Desktop.
Script to calculate nDCG scores between a TREC format run file and a Qrels file
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse

import numpy as np
import pandas as pd

# Column layouts of the two whitespace-separated TREC files read by this script.
TREC_RES_COLUMNS = ['qid', 'iteration', 'docNo', 'rank', 'docScore', 'method']
TREC_QREL_COLUMNS = ['qid', 'iteration', 'docNo', 'rel']

# Command-line interface. The parsed namespace exposes:
#   qrels_file, run_file, cut_off (int), gdeval (bool), log_base (int)
parser = argparse.ArgumentParser(description='Calculates NDCG@k')
parser.add_argument('qrels_file', type=str, help='path to a qrels file')
parser.add_argument('run_file', type=str, help='path to a run file in TREC format')
parser.add_argument('-k', '--cut_off', default=10, type=int, help='k cutoff for NDCG@k, default is 10')
# store_false: the flag switches the gdeval-equivalent calculation OFF. The value is
# stored under `gdeval` because that is the attribute the entry point reads.
parser.add_argument('-no_gd', '--not_gdeval', dest='gdeval', action='store_false',
                    help='add this flag for not gdeval equivalent calculation')
# The base is used both as an array length and as a logarithm base, so it must be an int.
parser.add_argument('-b', '--log_base', default=2, type=int,
                    help='If not using gdeval, this can be used to specify the log base')
def calc_ndcg(qrels_file, results_file, k, base=2, gdeval=True):
    """
    Compute nDCG@k per query for a TREC run against a qrels file, print the table and the mean.

    Setting gdeval will produce identical results to the official evaluation script that was published for TREC.
    Note that the calculation in that script differs from (probably) any published research version.

    With gdeval=True the gain of a document is ``2 ** rel - 1``, the discount at rank ``i`` (1-based) is
    ``log(i + 1)`` and the run is ordered by ``docScore`` (ties broken by ``docNo``, both descending).
    With gdeval=False the gain is ``rel`` itself, the run is ordered by its ``rank`` column, the first
    ``base`` ranks are not discounted and rank ``i > base`` is discounted by ``log_base(i)``.

    Negative relevance labels are treated as 0. Documents missing from the qrels count as non-relevant.
    Queries that are absent from the qrels are reported and skipped. Queries whose ideal DCG is 0
    (no relevant judgments) get NaN, which the printed mean ignores.

    :param qrels_file: a path (or file-like object) to a TREC style qrels file
    :param results_file: a path (or file-like object) to a TREC style run file (result)
    :param k: integer to be used as cutoff
    :param base: an integer to be used as the log base for calculation if gdeval=False
    :param gdeval: boolean parameter that indicates whether to use the gdeval calculation or not
    :return: a DataFrame indexed by query id with a single column named ``nDCG@<k>``
    """
    # Reading and sorting the qrels, to later speed-up indexing and locating.
    # Within a query the rows end up ordered by descending relevance, i.e. in ideal ranking order.
    qrels_df = (
        pd.read_csv(qrels_file, sep=r'\s+', names=TREC_QREL_COLUMNS)
        .sort_values(['qid', 'rel', 'docNo'], ascending=[True, False, True])
        .set_index(['qid', 'docNo'])
    )
    # Plain assignment instead of a chained in-place call, which does not update the frame
    # under pandas copy-on-write.
    qrels_df['rel'] = qrels_df['rel'].clip(lower=0)

    # Store beginning and end positions for each query - used for speed up
    qid_res_len = qrels_df.groupby('qid').size()
    qid_end_loc = qid_res_len.cumsum()
    qid_start_loc = qid_end_loc - qid_res_len
    qrels_df = qrels_df.droplevel(0)

    results_df = pd.read_csv(results_file, sep=r'\s+', names=TREC_RES_COLUMNS)
    if gdeval:
        # gdeval ordering: by docScore, ties broken by docNo
        results_df = results_df.sort_values(
            ['qid', 'docScore', 'docNo'], ascending=[True, False, False]
        ).groupby('qid').head(k)
        discount = np.log(np.arange(1, k + 1) + 1)
    else:
        # Otherwise, sort by doc ranks
        results_df = results_df.sort_values(['qid', 'rank']).groupby('qid').head(k)
        discount = np.concatenate((np.ones(base), np.log(np.arange(base, k) + 1) / np.log(base)))

    result = {}
    for qid, query_run in results_df.groupby('qid'):
        try:
            start, end = qid_start_loc.loc[qid], qid_end_loc.loc[qid]
        except KeyError as err:
            print(f"query id {err} doesn't exist in the qrels file, skipping it")
            continue
        query_rels = qrels_df.iloc[start:end]['rel']

        # Relevance of the retrieved documents in run order; unjudged documents count as 0.
        run_rels = query_rels.reindex(query_run['docNo'].to_numpy()).fillna(0).to_numpy()
        # Relevance of the best possible top-k (qrels are already sorted by descending rel).
        ideal_rels = query_rels.head(k).to_numpy()
        if gdeval:
            run_gains = 2 ** run_rels - 1
            ideal_gains = 2 ** ideal_rels - 1
        else:
            run_gains = run_rels
            ideal_gains = ideal_rels

        # Slice the discount by the length of each gain vector so the shapes always agree,
        # including when the discount vector is longer than k (k < base in non-gdeval mode).
        dcg = (run_gains / discount[:len(run_gains)]).sum()
        idcg = (ideal_gains / discount[:len(ideal_gains)]).sum()
        # 0/0 would produce NaN with a RuntimeWarning; make that outcome explicit.
        result[qid] = dcg / idcg if idcg > 0 else np.nan

    res_df = pd.DataFrame.from_dict(result, orient='index', columns=[f'nDCG@{k}'])
    print(res_df.to_string(float_format='%.5f'))
    # .iloc[0]: positional access; a bare [0] on a label-indexed Series is deprecated.
    print(f'Mean: {res_df.mean().iloc[0]:.5f}')
    return res_df
if __name__ == '__main__':
    # Script entry point: parse the command line and run the evaluation,
    # which prints the per-query scores and their mean.
    args = parser.parse_args()
    calc_ndcg(
        args.qrels_file,
        args.run_file,
        args.cut_off,
        base=args.log_base,
        gdeval=args.gdeval,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
a usage example:
python calc_ndcg.py qrels_file results_file -k 100 > results_file.ndcg