
Commit 5045faa

init
1 parent 7b78c64 commit 5045faa

24 files changed: +2022 -1 lines changed

.gitignore

+8
@@ -0,0 +1,8 @@
data/
results/
test/

Milvus/

__pycache__/
.vscode/

README.md

+93 -1
@@ -1 +1,93 @@
-test
# SimGRAG

This is the repository for the paper "SimGraphRAG: Leveraging Similar Subgraphs for Knowledge Graphs Driven Retrieval-Augmented Generation".
SimGRAG is a KG-driven RAG approach that supports various KG-based tasks, such as question answering and fact verification.

## Prerequisites

It supports plug-and-play usability with the following three components:
- Large language model: for generation.
- Embedding model: for node and relation embedding.
- Vector database: stores the embeddings of the nodes and relations in the knowledge graph, supporting efficient similarity search.

This repository is built on open-source solutions for these components:
- Ollama for running the Llama 3 70B large language model
- the Nomic embedding model for node and relation embedding
- Milvus as the vector database

You can replace these components with your preferred alternatives; all you need to do is prepare the corresponding APIs (a sketch of the assumed interfaces follows).
Next, we provide the preparation steps for the components we used.
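
A hypothetical sketch of those interfaces (not code from this repository; only `LLM.chat` actually appears in the pipeline scripts, the other names are illustrative):
```
# hypothetical interface sketch: a replacement component would need to
# offer roughly these operations; method names besides chat() are assumptions
from typing import Protocol, Sequence

class LLMAPI(Protocol):
    def chat(self, prompt: str) -> str: ...  # text generation

class EmbeddingAPI(Protocol):
    def encode(self, texts: Sequence[str]) -> list[list[float]]: ...  # embeddings

class VectorStoreAPI(Protocol):
    def insert(self, collection: str, vectors: list[list[float]]) -> None: ...
    def search(self, collection: str, query: list[float], topk: int) -> list[int]: ...
```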

### Ollama

Please visit the [Ollama](https://ollama.com/) website to install Ollama in your local environment.
After installation, you can use the following command to run the Llama 3 70B model:
```
ollama run llama3:70b
```
Then, you can use the following command to start the service needed by SimGRAG:
```
bash ollama_server.sh
```
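
The script binds the server to port 11451, matching the `base_url` in the config files. As a quick sanity check (a minimal sketch, assuming the `openai` Python package is installed; Ollama exposes an OpenAI-compatible endpoint):
```
# query the locally served model through Ollama's OpenAI-compatible API
from openai import OpenAI

client = OpenAI(base_url="http://localhost:11451/v1", api_key="ollama")
reply = client.chat.completions.create(
    model="llama3:70b",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(reply.choices[0].message.content)
```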

### Nomic Embedding Model

You can clone the model from [here](https://huggingface.co/nomic-ai/nomic-embed-text-v1) with the following commands:
```
mkdir -p data/raw
cd data/raw
git clone https://huggingface.co/nomic-ai/nomic-embed-text-v1
```
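
To verify the model loads, here is a minimal sketch assuming the `sentence-transformers` package is installed (loading nomic-embed-text-v1 requires `trust_remote_code=True`, and the model expects task prefixes such as `search_query:`):
```
# load the cloned embedding model and embed a sample query
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("data/raw/nomic-embed-text-v1", trust_remote_code=True)
emb = model.encode(["search_query: who directed the movie Inception"])
print(emb.shape)  # expected: (1, 768)
```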

### Milvus

Please visit the [Milvus](https://milvus.io/) website to install Milvus in your local environment.
After installation, you can follow its documentation to start the service needed by SimGRAG.
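
Once the service is up, a minimal connectivity check with `pymilvus` (a sketch assuming a default standalone deployment listening on port 19530):
```
# connect to Milvus and list existing collections
from pymilvus import connections, utility

connections.connect(alias="default", host="localhost", port="19530")
print(utility.list_collections())  # after indexing, e.g. ['FactKG_node', 'FactKG_relation', 'FactKG_type']
```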

## Data preparation

### MetaQA
Please download the MetaQA dataset following the URL in its [repository](https://github.com/yuyuz/MetaQA) and put it in the `data/raw` folder.

### FactKG
Please download the FactKG dataset following the URL in its [repository](https://github.com/jiho283/FactKG) and put it in the `data/raw` folder.

### Directory structure
After preparation, the directories should be organized as follows:
```
SimGraphRAG
├── data
│   └── raw
│       ├── nomic-embed-text-v1
│       ├── MetaQA
│       └── FactKG
├── configs
├── pipeline
├── prompts
└── src
```

## Configuration

You can find the configuration files in the `configs` folder and modify them to fit your needs.
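
For example, to switch the embedding model to a different GPU programmatically (a sketch using only keys that appear in `configs/FactKG.json`, run from the repository root):
```
# load a config, tweak a field, and write it back
import json

with open("configs/FactKG.json") as f:
    configs = json.load(f)
configs["embedding_model"]["device"] = "cuda:1"  # move embedding to another GPU
with open("configs/FactKG.json", "w") as f:
    json.dump(configs, f, indent=4)
```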

## Running the pipeline

For MetaQA, you can run the following commands:
```
cd pipeline
python metaQA_index.py
python metaQA_query1hop.py
python metaQA_query2hop.py
python metaQA_query3hop.py
```

For FactKG, you can run the following commands:
```
cd pipeline
python FactKG_index.py
python FactKG_query.py
```

The results can be found in the file assigned to "output_filename" in the configuration file, for example "results/FactKG_query.txt".
Each line of the result file is a dictionary, in which the key "correct" indicates the correctness of the final answer.
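
Because the result file is one JSON object per line, overall accuracy can be computed with a short script (a sketch; queries that raised an error carry an "error_message" instead of "correct" and are counted as incorrect here):
```
# compute accuracy over a result file
import json

with open("results/FactKG_query.txt", encoding="utf-8") as f:
    results = [json.loads(line) for line in f]
correct = sum(1 for r in results if r.get("correct"))
print(f"accuracy: {correct / len(results):.4f} ({correct}/{len(results)})")
```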

configs/FactKG.json

+31
@@ -0,0 +1,31 @@
{
    "raw_data_dir": "../data/raw/FactKG",
    "processed_data_dir": "../data/FactKG",
    "embedding_model": {
        "model_path": "../data/raw/nomic-embed-text-v1",
        "device": "cuda:0"
    },
    "vector_store_names": {
        "node": "FactKG_node",
        "relation": "FactKG_relation",
        "type": "FactKG_type"
    },
    "retriever": {
        "node_sim_topk": 16384,
        "relation_sim_topk": 512,
        "type_sim_topk": 16,
        "final_topk": 3,
        "timeout": 1800
    },
    "llm": {
        "model": "llama3:70b",
        "base_url": "http://localhost:11451/v1",
        "api_key": "ollama",
        "temperature": 0.2,
        "top_p": 0.1,
        "max_tokens": 1024
    },
    "rewrite_shot": 12,
    "answer_shot": 12,
    "output_filename": "../results/FactKG_query.txt"
}

configs/metaQA_1hop.json

+30
@@ -0,0 +1,30 @@
{
    "raw_data_dir": "../data/raw/metaQA",
    "processed_data_dir": "../data/metaQA",
    "hop": 1,
    "embedding_model": {
        "model_path": "../data/raw/nomic-embed-text-v1",
        "device": "cuda:0"
    },
    "vector_store_names": {
        "node": "metaQA_node",
        "relation": "metaQA_relation"
    },
    "retriever": {
        "node_sim_topk": 16,
        "relation_sim_topk": 16,
        "final_topk": 3,
        "timeout": 600
    },
    "llm": {
        "model": "llama3:70b",
        "base_url": "http://localhost:11451/v1",
        "api_key": "ollama",
        "temperature": 0.2,
        "top_p": 0.1,
        "max_tokens": 1024
    },
    "rewrite_shot": 12,
    "answer_shot": 12,
    "output_filename": "../results/metaQA_1hop_query.txt"
}

configs/metaQA_2hop.json

+30
@@ -0,0 +1,30 @@
{
    "raw_data_dir": "../data/raw/metaQA",
    "processed_data_dir": "../data/metaQA",
    "hop": 2,
    "embedding_model": {
        "model_path": "../data/raw/nomic-embed-text-v1",
        "device": "cuda:0"
    },
    "vector_store_names": {
        "node": "metaQA_node",
        "relation": "metaQA_relation"
    },
    "retriever": {
        "node_sim_topk": 16,
        "relation_sim_topk": 16,
        "final_topk": 3,
        "timeout": 600
    },
    "llm": {
        "model": "llama3:70b",
        "base_url": "http://localhost:11451/v1",
        "api_key": "ollama",
        "temperature": 0.2,
        "top_p": 0.1,
        "max_tokens": 1024
    },
    "rewrite_shot": 12,
    "answer_shot": 12,
    "output_filename": "../results/metaQA_2hop_query.txt"
}

configs/metaQA_3hop.json

+30
@@ -0,0 +1,30 @@
{
    "raw_data_dir": "../data/raw/metaQA",
    "processed_data_dir": "../data/metaQA",
    "hop": 3,
    "embedding_model": {
        "model_path": "../data/raw/nomic-embed-text-v1",
        "device": "cuda:0"
    },
    "vector_store_names": {
        "node": "metaQA_node",
        "relation": "metaQA_relation"
    },
    "retriever": {
        "node_sim_topk": 16,
        "relation_sim_topk": 16,
        "final_topk": 3,
        "timeout": 600
    },
    "llm": {
        "model": "llama3:70b",
        "base_url": "http://localhost:11451/v1",
        "api_key": "ollama",
        "temperature": 0.2,
        "top_p": 0.1,
        "max_tokens": 1024
    },
    "rewrite_shot": 12,
    "answer_shot": 12,
    "output_filename": "../results/metaQA_3hop_query.txt"
}

ollama_server.sh

+7
@@ -0,0 +1,7 @@
#!/bin/bash

export CUDA_VISIBLE_DEVICES=0
export OLLAMA_MODELS=/usr/share/ollama/.ollama/models
# serve on port 11451 to match the "base_url" in the config files
export OLLAMA_HOST=http://127.0.0.1:11451

ollama serve

pipeline/FactKG_index.py

+18
@@ -0,0 +1,18 @@
import sys
sys.path.append('..')

import json
from src.dataset import FactKG
from src.indexer import Indexer

# load configs
configs = json.load(open('../configs/FactKG.json'))

# load dataset
dataset = FactKG(configs)
KG = dataset.get_KG()
type_to_nodes = dataset.get_type_to_nodes()

# build index
indexer = Indexer(configs)
indexer.build_index(KG, type_to_nodes)

pipeline/FactKG_query.py

+78
@@ -0,0 +1,78 @@
import sys
sys.path.append('..')

import time
import json
from tqdm import tqdm

from src.llm import LLM
import prompts.answer_FactKG
import prompts.rewrite_FactKG
from src.dataset import FactKG
from src.retriever import Retriever
from src.utils import check_answer
from src.utils import extract_graph

# load configs
configs = json.load(open('../configs/FactKG.json'))

# load dataset
dataset = FactKG(configs)
KG = dataset.get_KG()
type_to_nodes = dataset.get_type_to_nodes()
all_queries = dataset.get_queries()
all_groundtruths = dataset.get_groundtruths()

# load LLM
llm = LLM(configs)

# load retriever
retriever = Retriever(configs, KG, type_to_nodes)

# run for each query
def run(query, groundtruths):
    res = {
        'query': query,
        'groundtruths': groundtruths,
        'retriever_configs': configs['retriever'],
        'llm_configs': configs['llm'],
        'rewrite_shot': configs['rewrite_shot'],
        'answer_shot': configs['answer_shot'],
    }

    try:
        # rewrite
        start = time.time()
        res['rewrite_prompt'] = prompts.rewrite_FactKG.get(query, shot=res['rewrite_shot'])
        res['rewrite_llm_output'] = llm.chat(res['rewrite_prompt'])
        res['rewrite_time'] = time.time() - start

        # extract graph
        res['query_graph'] = extract_graph(res['rewrite_llm_output'])

        # subgraph matching
        start = time.time()
        res['retrieval_details'] = retriever.retrieve(res['query_graph'], mode='greedy')
        res['evidences'] = [each[1] for each in res['retrieval_details']['results']]
        res['retrieval_time'] = time.time() - start

        # answer
        start = time.time()
        res['answer_prompt'] = prompts.answer_FactKG.get(res['query'], res['evidences'], shot=res['answer_shot'])
        res['answer_llm_output'] = llm.chat(res['answer_prompt'])
        res['answer_time'] = time.time() - start

        # check answer
        res['correct'] = check_answer(res['answer_llm_output'], groundtruths)

    except Exception as e:
        res['error_message'] = str(e)

    return res

# run for all queries
result_file = configs["output_filename"]
for query, groundtruths in tqdm(zip(all_queries, all_groundtruths), total=len(all_queries)):
    res = run(query, groundtruths)
    with open(result_file, 'a', encoding='utf-8') as f:
        f.write(json.dumps(res, ensure_ascii=False) + '\n')

pipeline/metaQA_index.py

+17
@@ -0,0 +1,17 @@
import sys
sys.path.append('..')

import json
from src.dataset import MetaQA
from src.indexer import Indexer

# load configs
# (all metaQA hop configs share the same vector store names, so indexing once suffices)
configs = json.load(open('../configs/metaQA_3hop.json'))

# load dataset
dataset = MetaQA(configs)
KG = dataset.get_KG()

# build index
indexer = Indexer(configs)
indexer.build_index(KG)
