Commit 6469632

Add example for multi turn end to end rag
1 parent 800b2ba commit 6469632

1 file changed: 161 additions, 0 deletions
@@ -0,0 +1,161 @@
from unitxt import get_logger
from unitxt.api import create_dataset, evaluate
from unitxt.task import Task
from unitxt.templates import JsonOutputTemplate

logger = get_logger()

# Contexts available to the RAG system.
contexts = [
    "Austin is the capital of Texas.",
    "Houston is in Texas.",
    "Houston is the largest city in the state but not the capital of it.",
]

# Set up question-answer pairs as a list of dictionaries; conversation_id
# and turn_id tie each turn to its conversation.
dataset = [
    {
        "question": "What is the capital of Texas?",
        "conversation_id": 0,
        "turn_id": 0,
        "reference_answers": ["Austin"],
        "reference_contexts": [contexts[0]],
        "reference_context_ids": [0],
        "is_answerable_label": True,
    },
    {
        "question": "Which is the largest city in the state?",
        "conversation_id": 0,
        "turn_id": 1,
        "reference_answers": ["Houston"],
        "reference_contexts": [contexts[1], contexts[2]],
        "reference_context_ids": [1, 2],
        "is_answerable_label": True,
    },
    {
        "question": "How much is 2+2?",
        "conversation_id": 1,
        "turn_id": 0,
        "reference_answers": ["4"],
        "reference_contexts": [""],
        "reference_context_ids": [],
        "is_answerable_label": True,
    },
    {
        "question": "Multiply the answer by 5",
        "conversation_id": 1,
        "turn_id": 1,
        "reference_answers": ["20"],
        "reference_contexts": [""],
        "reference_context_ids": [],
        "is_answerable_label": True,
    },
]
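
# Predictions from a hypothetical RAG system, one per dataset instance and
# in the same order. Some of them disagree with the references above
# (e.g. "Houston" for the capital question, "25" instead of "20"), so the
# correctness scores will be less than perfect.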
predictions = [
    {
        "answer": "Houston",
        "contexts": [contexts[2]],
        "context_ids": [2],
        "is_answerable": True,
    },
    {
        "answer": "Houston",
        "contexts": [contexts[2]],
        "context_ids": [2],
        "is_answerable": True,
    },
    {
        "answer": "4",
        "contexts": [""],
        "context_ids": [],
        "is_answerable": True,
    },
    {
        "answer": "25",
        "contexts": [""],
        "context_ids": [],
        "is_answerable": True,
    },
]

# Select the recommended metrics according to your available resources.
metrics = [
    "metrics.rag.end_to_end.recommended.cpu_only.all",
    # "metrics.rag.end_to_end.recommended.small_llm.all",
    # "metrics.rag.end_to_end.recommended.llmaj_watsonx.all",
    # "metrics.rag.end_to_end.recommended.llmaj_rits.all",
    # "metrics.rag.end_to_end.recommended.llmaj_azure.all",
]
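
# Only the cpu_only bundle is enabled above; the commented-out bundles
# score with an LLM as a judge (a small local model, or judges hosted on
# watsonx, RITS, or Azure) and require the corresponding resources.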

multi_turn_rag_task = Task(
    input_fields={
        "question": "Union[str, Dialog]",
        "conversation_id": "Any",
        "turn_id": "Any",
        "metadata_tags": "Dict[str, str]",
    },
    reference_fields={
        "reference_answers": "List[str]",
        "reference_contexts": "List[str]",
        "reference_context_ids": "Union[List[int], List[str]]",
        "is_answerable_label": "bool",
    },
    metrics=[
        "metrics.rag.end_to_end.answer_correctness",
        "metrics.rag.end_to_end.answer_faithfulness",
        "metrics.rag.end_to_end.answer_reward",
        "metrics.rag.end_to_end.context_correctness",
        "metrics.rag.end_to_end.context_relevance",
    ],
    prediction_type="RagResponse",
    augmentable_inputs=[
        "question",
    ],
    defaults={
        "metadata_tags": {},
        "reference_answers": [],
        "reference_contexts": [],
        "reference_context_ids": [],
        "is_answerable_label": True,
    },
)
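
# The task defines the schema: input fields identify the question and its
# place in the conversation, reference fields hold the gold answers and
# contexts, and predictions are expected to be RagResponse dictionaries
# like the ones constructed above.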

template = JsonOutputTemplate(
    input_format="Conversation: {conversation_id} Turn: {turn_id} Question: {question}",
    output_fields={
        "reference_answers": "answer",
        "reference_contexts": "contexts",
        "reference_context_ids": "context_ids",
    },
    wrap_with_list_fields=[
        "reference_contexts",
        "reference_context_ids",
    ],
    postprocessors=[
        "processors.load_json_predictions",
    ],
)
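
# The template renders each turn's input and serializes the reference
# fields into a JSON target whose keys ("answer", "contexts",
# "context_ids") mirror the prediction dictionaries above.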

dataset = create_dataset(
    task=multi_turn_rag_task,
    test_set=dataset,
    split="test",
    postprocessors=[],
    metrics=metrics,
    template=template,
    group_by=["conversation_id"],
)
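
# group_by=["conversation_id"] requests aggregated scores per
# conversation, surfaced later through results.groups_scores.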

results = evaluate(predictions, dataset)

# Print Results:

print("Global Results:")
print(results.global_scores.summary)

print("Instance Results:")
print(results.instance_scores.summary)

print("Group Results:")
print(results.groups_scores.summary)
