
Commit 3795036

[Core][Retrieval] Implement NDCG metric (run-llama#14100)
* Implement NDCG metric
* Update notebook
1 parent f175139 commit 3795036

3 files changed: +216 -36 lines changed

docs/docs/examples/evaluation/retrieval/retriever_eval.ipynb (+57 -32)
@@ -18,7 +18,7 @@
 "\n",
 "This notebook uses our `RetrieverEvaluator` to evaluate the quality of any Retriever module defined in LlamaIndex.\n",
 "\n",
-"We specify a set of different evaluation metrics: this includes hit-rate and MRR. For any given question, these will compare the quality of retrieved results from the ground-truth context.\n",
+"We specify a set of different evaluation metrics: this includes hit-rate, MRR, and NDCG. For any given question, these will compare the quality of retrieved results from the ground-truth context.\n",
 "\n",
 "To ease the burden of creating the eval dataset in the first place, we can rely on synthetic data generation."
 ]
@@ -40,13 +40,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install llama-index-llms-openai"
+"%pip install llama-index-llms-openai\n",
+"%pip install llama-index-readers-file"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "bb6fecf4-7215-4ae9-b02b-3cb7c6000f2c",
+"id": "285cfab2",
 "metadata": {},
 "outputs": [],
 "source": [
@@ -62,7 +63,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from llama_index.core.evaluation import generate_question_context_pairs\n",
 "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader\n",
 "from llama_index.core.node_parser import SentenceSplitter\n",
 "from llama_index.llms.openai import OpenAI"
@@ -82,7 +82,25 @@
 "execution_count": null,
 "id": "589c112d",
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"--2024-06-12 23:57:02-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt\n",
+"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...\n",
+"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
+"HTTP request sent, awaiting response... 200 OK\n",
+"Length: 75042 (73K) [text/plain]\n",
+"Saving to: ‘data/paul_graham/paul_graham_essay.txt’\n",
+"\n",
+"data/paul_graham/pa 100%[===================>] 73.28K --.-KB/s in 0.08s \n",
+"\n",
+"2024-06-12 23:57:03 (864 KB/s) - ‘data/paul_graham/paul_graham_essay.txt’ saved [75042/75042]\n",
+"\n"
+]
+}
+],
 "source": [
 "!mkdir -p 'data/paul_graham/'\n",
 "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'"
@@ -171,15 +189,11 @@
 {
 "data": {
 "text/markdown": [
-"**Node ID:** node_0<br>**Similarity:** 0.8181379514114543<br>**Text:** What I Worked On\n",
-"\n",
-"February 2021\n",
+"**Node ID:** node_38<br>**Similarity:** 0.814377909267451<br>**Text:** I also worked on spam filters, and did some more painting. I used to have dinners for a group of friends every thursday night, which taught me how to cook for groups. And I bought another building in Cambridge, a former candy factory (and later, twas said, porn studio), to use as an office.\n",
 "\n",
-"Before college the two main things I worked on, outside of school, were writing and programming. I didn't write essays. I wrote what beginning writers were supposed to write then, and probably still are: short stories. My stories were awful. They had hardly any plot, just characters with strong feelings, which I imagined made them deep.\n",
-"\n",
-"The first programs I tried writing were on the IBM 1401 that our school district used for what was then called \"data processing.\" This was in 9th grade, so I was 13 or 14. The school district's 1401 happened to be in the basement of our junior high school, and my friend Rich Draves and I got permission to use it. It was like a mini Bond villain's lair down there, with all these alien-looking machines — CPU, disk drives, printer, card reader — sitting up on a raised floor under bright fluorescent lights.\n",
+"One night in October 2003 there was a big party at my house. It was a clever idea of my friend Maria Daniels, who was one of the thursday diners. Three separate hosts would all invite their friends to one party. So for every guest, two thirds of the other guests would be people they didn't know but would probably like. One of the guests was someone I didn't know but would turn out to like a lot: a woman called Jessica Livingston. A couple days later I asked her out.\n",
 "\n",
-"The language we used was an early version of Fortran. You had to type programs on punch cards, then stack them in ...<br>"
+"Jessica was in charge of marketing at a Boston investment bank. This bank thought it understood startups, but over the next year, as she met friends of mine from the startup world, she was surprised how different reality was. And ho...<br>"
 ],
 "text/plain": [
 "<IPython.core.display.Markdown object>"
@@ -191,13 +205,15 @@
 {
 "data": {
 "text/markdown": [
-"**Node ID:** node_52<br>**Similarity:** 0.8143530600618721<br>**Text:** It felt like I was doing life right. I remember that because I was slightly dismayed at how novel it felt. The good news is that I had more moments like this over the next few years.\n",
+"**Node ID:** node_0<br>**Similarity:** 0.8122448657654567<br>**Text:** What I Worked On\n",
 "\n",
-"In the summer of 2016 we moved to England. We wanted our kids to see what it was like living in another country, and since I was a British citizen by birth, that seemed the obvious choice. We only meant to stay for a year, but we liked it so much that we still live there. So most of Bel was written in England.\n",
+"February 2021\n",
+"\n",
+"Before college the two main things I worked on, outside of school, were writing and programming. I didn't write essays. I wrote what beginning writers were supposed to write then, and probably still are: short stories. My stories were awful. They had hardly any plot, just characters with strong feelings, which I imagined made them deep.\n",
 "\n",
-"In the fall of 2019, Bel was finally finished. Like McCarthy's original Lisp, it's a spec rather than an implementation, although like McCarthy's Lisp it's a spec expressed as code.\n",
+"The first programs I tried writing were on the IBM 1401 that our school district used for what was then called \"data processing.\" This was in 9th grade, so I was 13 or 14. The school district's 1401 happened to be in the basement of our junior high school, and my friend Rich Draves and I got permission to use it. It was like a mini Bond villain's lair down there, with all these alien-looking machines — CPU, disk drives, printer, card reader — sitting up on a raised floor under bright fluorescent lights.\n",
 "\n",
-"Now that I could write essays again, I wrote a bunch about topics I'd had stacked up. I kept writing essays through 2020, but I also started to think about other things I could work on. How should I choose what to do? Well, how had I chosen what to work on in the past? I wrote an essay for myself to answer that ques...<br>"
+"The language we used was an early version of Fortran. You had to type programs on punch cards, then stack them in ...<br>"
 ],
 "text/plain": [
 "<IPython.core.display.Markdown object>"
@@ -246,7 +262,15 @@
 "execution_count": null,
 "id": "2d29a159-9a4f-4d44-9c0d-1cd683f8bb9b",
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"100%|██████████| 61/61 [04:59<00:00, 4.91s/it]\n"
+]
+}
+],
 "source": [
 "qa_dataset = generate_question_context_pairs(\n",
 " nodes, llm=llm, num_questions_per_chunk=2\n",
@@ -263,7 +287,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"\"Describe the transition from using the IBM 1401 to microcomputers, as mentioned in the text. What were the key differences in terms of user interaction and programming capabilities?\"\n"
+"\"Describe the transition from using the IBM 1401 to microcomputers, as mentioned in the text. How did this change impact the way programs were written and executed?\"\n"
 ]
 }
 ],
@@ -319,7 +343,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"include_cohere_rerank = True\n",
+"include_cohere_rerank = False\n",
 "\n",
 "if include_cohere_rerank:\n",
 " !pip install cohere -q"
@@ -334,7 +358,7 @@
 "source": [
 "from llama_index.core.evaluation import RetrieverEvaluator\n",
 "\n",
-"metrics = [\"mrr\", \"hit_rate\"]\n",
+"metrics = [\"mrr\", \"hit_rate\", \"ndcg\"]\n",
 "\n",
 "if include_cohere_rerank:\n",
 " metrics.append(\n",
@@ -356,8 +380,8 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Query: In the context provided, the author describes his early experiences with programming on an IBM 1401. Based on his description, what were some of the limitations and challenges he faced while trying to write programs on this machine?\n",
-"Metrics: {'mrr': 1.0, 'hit_rate': 1.0, 'cohere_rerank_relevancy': 0.99620515}\n",
+"Query: In the context, the author mentions his early experiences with programming on an IBM 1401. Describe the process he used to write and run a program on this machine, and explain why he found it challenging to create meaningful programs on this system.\n",
+"Metrics: {'mrr': 1.0, 'hit_rate': 1.0, 'ndcg': 0.6131471927654584}\n",
 "\n"
 ]
 }
@@ -402,9 +426,10 @@
 "\n",
 " full_df = pd.DataFrame(metric_dicts)\n",
 "\n",
-" hit_rate = full_df[\"hit_rate\"].mean()\n",
-" mrr = full_df[\"mrr\"].mean()\n",
-" columns = {\"retrievers\": [name], \"hit_rate\": [hit_rate], \"mrr\": [mrr]}\n",
+" columns = {\n",
+" \"retrievers\": [name],\n",
+" **{k: [full_df[k].mean()] for k in metrics},\n",
+" }\n",
 "\n",
 " if include_cohere_rerank:\n",
 " crr_relevancy = full_df[\"cohere_rerank_relevancy\"].mean()\n",
@@ -443,26 +468,26 @@
443468
" <tr style=\"text-align: right;\">\n",
444469
" <th></th>\n",
445470
" <th>retrievers</th>\n",
446-
" <th>hit_rate</th>\n",
447471
" <th>mrr</th>\n",
448-
" <th>cohere_rerank_relevancy</th>\n",
472+
" <th>hit_rate</th>\n",
473+
" <th>ndcg</th>\n",
449474
" </tr>\n",
450475
" </thead>\n",
451476
" <tbody>\n",
452477
" <tr>\n",
453478
" <th>0</th>\n",
454479
" <td>top-2 eval</td>\n",
455-
" <td>0.801724</td>\n",
456-
" <td>0.685345</td>\n",
457-
" <td>0.946009</td>\n",
480+
" <td>0.643443</td>\n",
481+
" <td>0.745902</td>\n",
482+
" <td>0.410976</td>\n",
458483
" </tr>\n",
459484
" </tbody>\n",
460485
"</table>\n",
461486
"</div>"
462487
],
463488
"text/plain": [
464-
" retrievers hit_rate mrr cohere_rerank_relevancy\n",
465-
"0 top-2 eval 0.801724 0.685345 0.946009"
489+
" retrievers mrr hit_rate ndcg\n",
490+
"0 top-2 eval 0.643443 0.745902 0.410976"
466491
]
467492
},
468493
"execution_count": null,

llama-index-core/llama_index/core/evaluation/retrieval/metrics.py (+91 -3)
@@ -1,3 +1,4 @@
+import math
 import os
 from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type

@@ -7,6 +8,7 @@
     BaseRetrievalMetric,
     RetrievalMetricResult,
 )
+from typing_extensions import assert_never
 
 _AGG_FUNC: Dict[str, Callable] = {"mean": np.mean, "median": np.median, "max": np.max}

@@ -18,8 +20,8 @@ class HitRate(BaseRetrievalMetric):
     - The more granular method checks for all potential matches between retrieved docs and expected docs.
 
     Attributes:
-        use_granular_hit_rate (bool): Determines whether to use the granular method for calculation.
         metric_name (str): The name of the metric.
+        use_granular_hit_rate (bool): Determines whether to use the granular method for calculation.
     """
 
     metric_name: ClassVar[str] = "hit_rate"
@@ -77,11 +79,11 @@ class MRR(BaseRetrievalMetric):
     - The more granular method sums the reciprocal ranks of all relevant retrieved documents and divides by the count of relevant documents.
 
     Attributes:
-        use_granular_mrr (bool): Determines whether to use the granular method for calculation.
         metric_name (str): The name of the metric.
+        use_granular_mrr (bool): Determines whether to use the granular method for calculation.
     """
 
-    metric_name: str = "mrr"
+    metric_name: ClassVar[str] = "mrr"
     use_granular_mrr: bool = False
 
     def compute(
@@ -140,6 +142,91 @@ def compute(
         return RetrievalMetricResult(score=mrr_score)
 
 
+DiscountedGainMode = Literal["linear", "exponential"]
+
+
+def discounted_gain(*, rel: float, i: int, mode: DiscountedGainMode) -> float:
+    # Avoid unnecessary calculations. Note that `False == 0` and `True == 1`.
+    if rel == 0:
+        return 0
+    if rel == 1:
+        return 1 / math.log2(i + 1)
+
+    if mode == "linear":
+        return rel / math.log2(i + 1)
+    elif mode == "exponential":
+        return (2**rel - 1) / math.log2(i + 1)
+    else:
+        assert_never(mode)
+
+
+class NDCG(BaseRetrievalMetric):
+    """NDCG (Normalized Discounted Cumulative Gain) metric.
+
+    The position `p` is taken as the size of the query results (which is usually
+    `top_k` of the retriever).
+
+    Currently only supports binary relevance
+    (``rel=1`` if document is in ``expected_ids``, otherwise ``rel=0``)
+    since we assume that ``expected_ids`` is unordered.
+
+    Attributes:
+        metric_name (str): The name of the metric.
+        mode (DiscountedGainMode): Determines the formula for each item in the summation.
+    """
+
+    metric_name: ClassVar[str] = "ndcg"
+    mode: DiscountedGainMode = "linear"
+
+    def compute(
+        self,
+        query: Optional[str] = None,
+        expected_ids: Optional[List[str]] = None,
+        retrieved_ids: Optional[List[str]] = None,
+        expected_texts: Optional[List[str]] = None,
+        retrieved_texts: Optional[List[str]] = None,
+    ) -> RetrievalMetricResult:
+        """Compute NDCG based on the provided inputs and selected method.
+
+        Parameters:
+            query (Optional[str]): The query string (not used in the current implementation).
+            expected_ids (Optional[List[str]]): Expected document IDs, unordered by relevance.
+            retrieved_ids (Optional[List[str]]): Retrieved document IDs, ordered by relevance from highest to lowest.
+            expected_texts (Optional[List[str]]): Expected texts (not used in the current implementation).
+            retrieved_texts (Optional[List[str]]): Retrieved texts (not used in the current implementation).
+
+        Raises:
+            ValueError: If the necessary IDs are not provided.
+
+        Returns:
+            RetrievalMetricResult: The result with the computed NDCG score.
+        """
+        # Checking for the required arguments
+        if (
+            retrieved_ids is None
+            or expected_ids is None
+            or not retrieved_ids
+            or not expected_ids
+        ):
+            raise ValueError("Retrieved ids and expected ids must be provided")
+
+        mode = self.mode
+        expected_set = set(expected_ids)
+
+        dcg = sum(
+            discounted_gain(rel=docid in expected_set, i=i, mode=mode)
+            for i, docid in enumerate(retrieved_ids, start=1)
+        )
+        idcg = sum(
+            discounted_gain(rel=True, i=i, mode=mode)
+            for i in range(1, len(retrieved_ids) + 1)
+        )
+
+        ndcg_score = dcg / idcg
+
+        return RetrievalMetricResult(score=ndcg_score)
+
+
 class CohereRerankRelevancyMetric(BaseRetrievalMetric):
     """Cohere rerank relevancy metric."""
 
@@ -209,6 +296,7 @@ def compute(
 METRIC_REGISTRY: Dict[str, Type[BaseRetrievalMetric]] = {
     "hit_rate": HitRate,
     "mrr": MRR,
+    "ndcg": NDCG,
     "cohere_rerank_relevancy": CohereRerankRelevancyMetric,
 }
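To sanity-check the new metric, here is a small sketch that exercises the committed `NDCG` class on a toy query. The document IDs are made up, the import path simply follows the file location shown above, and the manual DCG/IDCG computation mirrors the linear-mode formula from `discounted_gain`; it is an illustration, not part of the commit:

```python
import math

from llama_index.core.evaluation.retrieval.metrics import NDCG

expected_ids = ["doc_1", "doc_2"]  # ground-truth nodes (unordered)
retrieved_ids = ["doc_3", "doc_1", "doc_2", "doc_4"]  # retriever output, best first

# Metric as registered by this commit (binary relevance, linear discounted gain).
result = NDCG().compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids)

# Manual check: rel_i is 1 when the i-th retrieved doc is expected, else 0;
# the ideal ranking treats every one of the p = len(retrieved_ids) positions as relevant.
dcg = sum(
    (doc_id in expected_ids) / math.log2(i + 1)
    for i, doc_id in enumerate(retrieved_ids, start=1)
)
idcg = sum(1 / math.log2(i + 1) for i in range(1, len(retrieved_ids) + 1))

print(result.score, dcg / idcg)  # both are approximately 0.44 for this toy ranking
```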
