1
+ import sys
2
+ sys .path .append ('..' )
3
+
4
+ import time
5
+ import json
6
+ from tqdm import tqdm
7
+
8
+ from src .llm import LLM
9
+ import prompts .answer_FactKG
10
+ import prompts .rewrite_FactKG
11
+ from src .dataset import FactKG
12
+ from src .retriever import Retriever
13
+ from src .utils import check_answer
14
+ from src .utils import extract_graph
15
+
16
+ # load configs
17
+ configs = json .load (open ('../configs/FactKG.json' ))
18
+
19
+ # load dataset
20
+ dataset = FactKG (configs )
21
+ KG = dataset .get_KG ()
22
+ type_to_nodes = dataset .get_type_to_nodes ()
23
+ all_queries = dataset .get_queries ()
24
+ all_groundtruths = dataset .get_groundtruths ()
25
+
26
+ # load LLM
27
+ llm = LLM (configs )
28
+
29
+ # load retriever
30
+ retriever = Retriever (configs , KG , type_to_nodes )
31
+
32
+ # run for each query
33
+ def run (query , groundtruths ):
34
+ res = {
35
+ 'query' : query ,
36
+ 'groundtruths' : groundtruths ,
37
+ 'retriever_configs' : configs ['retriever' ],
38
+ 'llm_configs' : configs ['llm' ],
39
+ 'rewrite_shot' : configs ['rewrite_shot' ],
40
+ 'answer_shot' : configs ['answer_shot' ],
41
+ }
42
+
43
+ try :
44
+ # rewrite
45
+ start = time .time ()
46
+ res ['rewrite_prompt' ] = prompts .rewrite_FactKG .get (query , shot = res ['rewrite_shot' ])
47
+ res ['rewrite_llm_output' ] = llm .chat (res ['rewrite_prompt' ])
48
+ res ['rewrite_time' ] = time .time () - start
49
+
50
+ # extract graph
51
+ res ['query_graph' ] = extract_graph (res ['rewrite_llm_output' ])
52
+
53
+ # subgraph matching
54
+ start = time .time ()
55
+ res ['retrieval_details' ] = retriever .retrieve (res ['query_graph' ], mode = 'greedy' )
56
+ res ['evidences' ] = [each [1 ] for each in res ['retrieval_details' ]['results' ]]
57
+ res ['retrieval_time' ] = time .time () - start
58
+
59
+ # answer
60
+ start = time .time ()
61
+ res ['answer_prompt' ] = prompts .answer_FactKG .get (res ['query' ], res ['evidences' ], shot = res ['answer_shot' ])
62
+ res ['answer_llm_output' ] = llm .chat (res ['answer_prompt' ])
63
+ res ['answer_time' ] = time .time () - start
64
+
65
+ # check answer
66
+ res ['correct' ] = check_answer (res ['answer_llm_output' ], groundtruths )
67
+
68
+ except Exception as e :
69
+ res ['error_message' ] = str (e )
70
+
71
+ return res
72
+
73
+ # run for all queries
74
+ result_file = configs ["output_filename" ]
75
+ for query , groundtruths in tqdm (zip (all_queries , all_groundtruths ), total = len (all_queries )):
76
+ res = run (query , groundtruths )
77
+ with open (result_file , 'a' , encoding = 'utf-8' ) as f :
78
+ f .write (json .dumps (res , ensure_ascii = False ) + '\n ' )
0 commit comments