Skip to content

Commit 4abdd8b

Browse files
author
sfwydyc
committed
first commit
1 parent 702fec2 commit 4abdd8b

File tree

175 files changed

+11132
-22
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

175 files changed

+11132
-22
lines changed

LICENSE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Apache License
1+
Apache License
22
Version 2.0, January 2004
33
http://www.apache.org/licenses/
44

README.md

+189-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,189 @@
1-
# RocketQA
2-
RocketQA
1+
# RocketQA End-to-End QA-system Development Tool
2+
3+
This repository provides a simple and efficient toolkit for running RocketQA models and build a Question Answering (QA) system.
4+
5+
## RocketQA
6+
**RocketQA** is a series of dense retrieval models for Open-Domain QA.
7+
8+
Open-Domain QA aims to find the answers of natural language questions from a large collection of documents. Common approaches often contain two stages, firstly a dense retriever selects a few relevant contexts, and then a neural reader extracts the answer.
9+
10+
RocketQA focuses on improving the dense contexts retrieval stage, and propose the following methods:
11+
#### 1. [RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/pdf/2010.08191.pdf)
12+
13+
#### 2. [PAIR: Leveraging Passage-Centric Similarity Relation for Improving Dense Passage Retrieval](https://aclanthology.org/2021.findings-acl.191.pdf)
14+
15+
#### 3. [RocketQAv2: A Joint Training Method for Dense Passage Retrieval and Passage Re-ranking](https://arxiv.org/pdf/2110.07367.pdf)
16+
17+
18+
## Features
19+
* ***State-of-the-art***, RocketQA models achieve SOTA performance on MSMARCO passage ranking dataset and Natural Question dataset.
20+
* ***First-Chinese-model***, RocketQA-zh is the first open source Chinese dense retrieval model.
21+
* ***Easy-to-use***, both python installation package and DOCKER environment are provided.
22+
* ***Solution-for-QA-system***, developers can build an End-to-End QA system with one line of code.
23+
24+
25+
26+
## Installation
27+
28+
### Install python package
29+
First, install [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html).
30+
```bash
31+
# GPU version:
32+
$ pip install paddlepaddle-gpu
33+
34+
# CPU version:
35+
$ pip install paddlepaddle
36+
```
37+
38+
Second, install rocketqa package:
39+
```bash
40+
$ pip install rocketqa
41+
```
42+
43+
NOTE: RocketQA package MUST be running on Python3.6+ with [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html) 2.0+ :
44+
45+
### Download Docker environment
46+
47+
```bash
48+
docker pull rocketqa/rocketqa
49+
50+
docker run -it docker.io/rocketqa/rocketqa bash
51+
```
52+
53+
54+
## API
55+
The RocketQA development tool supports two types of models, ERNIE-based dual encoder for answer retrieval and ERNIE-based cross encoder for answer re-ranking. And the development tool provides the following methods:
56+
57+
#### [`rocketqa.available_models()`](https://github.com/PaddlePaddle/RocketQA/blob/3a99cf2720486df8cc54acc0e9ce4cbcee993413/rocketqa/rocketqa.py#L17)
58+
59+
Returns the names of the available RocketQA models.
60+
61+
#### [`rocketqa.load_model(model, use_cuda=False, device_id=0, batch_size=1)`](https://github.com/PaddlePaddle/RocketQA/blob/3a99cf2720486df8cc54acc0e9ce4cbcee993413/rocketqa/rocketqa.py#L52)
62+
63+
Returns the model specified by the input parameter. Both dual encoder and cross encoder can be initialized by this method. With input parameter, developers can load RocketQA models returned by "available_models()" or their own checkpoints.
64+
65+
---
66+
67+
Dual-encoder returned by "load_model()" supports the following methods:
68+
69+
#### [`model.encode_query(query: List[str])`](https://github.com/PaddlePaddle/RocketQA/blob/3a99cf2720486df8cc54acc0e9ce4cbcee993413/rocketqa/predict/dual_encoder.py#L126)
70+
71+
Given a list of queries, returns their representation vectors encoded by model.
72+
73+
#### [`model.encode_para(para: List[str], title: List[str])`](https://github.com/PaddlePaddle/RocketQA/blob/3a99cf2720486df8cc54acc0e9ce4cbcee993413/rocketqa/predict/dual_encoder.py#L154)
74+
75+
Given a list of passages and their corresponding titles (optional), returns their representations vectors encoded by model.
76+
77+
#### [`model.matching(query: List[str], para: List[str], title: List[str])`](https://github.com/PaddlePaddle/RocketQA/blob/3a99cf2720486df8cc54acc0e9ce4cbcee993413/rocketqa/predict/dual_encoder.py#L187)
78+
79+
Given a list of queries and passages (and titles), returns their matching scores (dot product between two representation vectors).
80+
81+
---
82+
83+
Cross-encoder returned by "load_model()" supports the following method:
84+
85+
#### [`model.matching(query: List[str], para: List[str], title: List[str])`](https://github.com/PaddlePaddle/RocketQA/blob/3a99cf2720486df8cc54acc0e9ce4cbcee993413/rocketqa/predict/cross_encoder.py#L129)
86+
87+
Given a list of queries and passages (and titles), returns their matching scores (probability that the paragraph is the query's right answer).
88+
89+
90+
91+
## Examples
92+
93+
With the examples below, developers can run RocketQA models or their own checkpoints.
94+
95+
### Run RocketQA Model
96+
To run RocketQA models, developers should set the parameter `model` in 'load_model()' method with RocketQA model name return by 'available_models()' method.
97+
98+
```python
99+
import rocketqa
100+
101+
query_list = ["trigeminal definition"]
102+
para_list = [
103+
"Definition of TRIGEMINAL. : of or relating to the trigeminal nerve.ADVERTISEMENT. of or relating to the trigeminal nerve. ADVERTISEMENT."]
104+
105+
# init dual encoder
106+
dual_encoder = rocketqa.load_model(model="v1_marco_de", use_cuda=True, batch_size=16)
107+
108+
# encode query & para
109+
q_embs = dual_encoder.encode_question(query=query_list)
110+
p_embs = dual_encoder.encode_passage(para=para_list)
111+
# compute dot product of query representation and para representation
112+
dot_products = dual_encoder.matching(query=query_list, para=para_list)
113+
```
114+
115+
### Run Self-development Model
116+
To run checkpoints, developers should write a config file, and set the parameter `model` in 'load_model()' method with the path of the config file.
117+
118+
```python
119+
import rocketqa
120+
121+
query_list = ["交叉验证的作用"]
122+
title_list = ["交叉验证的介绍"]
123+
para_list = ["交叉验证(Cross-validation)主要用于建模应用中,例如PCR 、PLS回归建模中。在给定的建模样本中,拿出大部分样本进行建模型,留小部分样本用刚建立的模型进行预报,并求这小部分样本的预报误差,记录它们的平方加和。"]
124+
125+
# conf
126+
ce_conf = {
127+
"model": "./own_model/config.json", # path of config file
128+
"use_cuda": True,
129+
"device_id": 0,
130+
"batch_size": 16
131+
}
132+
133+
# init cross encoder
134+
cross_encoder = rocketqa.load_model(**ce_conf)
135+
136+
# compute matching score of query and para
137+
ranking_score = cross_encoder.matching(query=query_list, para=para_list, title=title_list)
138+
```
139+
140+
The config file is a JSON format file.
141+
```bash
142+
{
143+
"model_type": "cross_encoder",
144+
"max_seq_len": 160,
145+
"model_conf_path": "en_large_config.json", # path relative to config file
146+
"model_vocab_path": "en_vocab.txt", # path relative to config file
147+
"model_checkpoint_path": "marco_cross_encoder_large", # path relative to config file
148+
"joint_training": 0
149+
}
150+
```
151+
152+
153+
154+
## Start your QA-System
155+
156+
With the examples below, developers can build own QA-System
157+
158+
### Running with JINA
159+
```bash
160+
cd examples/jina_example/
161+
pip3 install -r requirements.txt
162+
163+
# Index
164+
python3 app.py index
165+
166+
# Search
167+
python3 app.py query
168+
169+
To know more, please visit [JINA example](https://github.com/PaddlePaddle/RocketQA/tree/main/examples/jina_example)
170+
```
171+
172+
173+
174+
### Running with Faiss
175+
176+
```bash
177+
cd examples/faiss_example/
178+
pip3 install -r requirements.txt
179+
180+
# Index
181+
python3 index.py ${language} ${data_file} ${index_file}
182+
183+
# Start service
184+
python3 rocketqa_service.py ${language} ${data_file} ${index_file}
185+
186+
# request
187+
python3 query.py
188+
```
189+

examples/example.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import os
2+
import sys
3+
import rocketqa
4+
5+
query_list = []
6+
para_list = []
7+
title_list = []
8+
marco_q_file = 'marco.q'
9+
for line in open(marco_q_file):
10+
query_list.append(line.strip())
11+
12+
marco_tp_file = 'marco.tp.1k'
13+
for line in open(marco_tp_file):
14+
t, p = line.strip().split('\t')
15+
para_list.append(p)
16+
title_list.append(t)
17+
18+
dual_encoder = rocketqa.load_model(model="v1_marco_de", use_cuda=True, device_id=0, batch_size=32)
19+
20+
q_embs = dual_encoder.encode_question(query=query_list)
21+
for q in q_embs:
22+
print (' '.join(str(ii) for ii in q))
23+
p_embs = dual_encoder.encode_passage(para=para_list, title=title_list)
24+
for p in p_embs:
25+
print (' '.join(str(ii) for ii in p))
26+
ips = dual_encoder.matching(query=query_list, \
27+
para=para_list[:len(query_list)], \
28+
title=title_list[:len(query_list)])
29+
for ip in ips:
30+
print (ip)
31+
32+
cross_encoder = rocketqa.load_model(model="v1_marco_ce", use_cuda=True, device_id=0, batch_size=32)
33+
ranking_score = cross_encoder.matching(query=query_list, \
34+
para=para_list[:len(query_list)], \
35+
title=title_list[:len(query_list)])
36+
for rs in ranking_score:
37+
print (rs)
38+

examples/faiss_example/index.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import os
2+
import sys
3+
import faiss
4+
import rocketqa
5+
6+
7+
def build_index(encoder_conf, index_file_name, title_list, para_list):
8+
9+
dual_encoder = rocketqa.load_model(**encoder_conf)
10+
para_embs = dual_encoder.encode_passage(para=para_list, title=title_list)
11+
12+
indexer = faiss.IndexFlatIP(768)
13+
indexer.add(para_embs.astype('float32'))
14+
faiss.write_index(indexer, index_file_name)
15+
16+
17+
if __name__ == '__main__':
18+
if len(sys.argv) != 4:
19+
print ("USAGE: ")
20+
print (" python3 index.py ${language} ${data_file} ${index_file}")
21+
print ("--For Example:")
22+
print (" python3 index.py zh ../marco.tp.1k marco_test.index")
23+
exit()
24+
25+
language = sys.argv[1]
26+
data_file = sys.argv[2]
27+
index_file = sys.argv[3]
28+
if language == 'zh':
29+
model = 'zh_dureader_de'
30+
elif language == 'en':
31+
model = 'v1_marco_de'
32+
else:
33+
print ("illegal language, only [zh] and [en] is supported", file=sys.stderr)
34+
exit()
35+
36+
para_list = []
37+
title_list = []
38+
for line in open(data_file):
39+
t, p = line.strip().split('\t')
40+
para_list.append(p)
41+
title_list.append(t)
42+
43+
de_conf = {
44+
"model": model,
45+
"use_cuda": True,
46+
"device_id": 0,
47+
"batch_size": 32
48+
}
49+
build_index(de_conf, index_file, title_list, para_list)

examples/faiss_example/query.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import sys
2+
import requests
3+
import json
4+
5+
SERVICE_ADD = 'http://localhost:8888/rocketqa'
6+
TOPK = 5
7+
8+
while 1:
9+
query = input("please input a query:\t")
10+
if query.strip() == '':
11+
break
12+
13+
input_data = {}
14+
input_data['query'] = query
15+
input_data['topk'] = TOPK
16+
json_str = json.dumps(input_data)
17+
18+
result = requests.post(SERVICE_ADD, json=input_data)
19+
res_json = json.loads(result.text)
20+
21+
print ("QUERY:\t" + query)
22+
for i in range(TOPK):
23+
title = res_json['answer'][i]['title']
24+
para = res_json['answer'][i]['para']
25+
score = res_json['answer'][i]['probability']
26+
print ('{}'.format(i + 1) + '\t' + title + '\t' + para + '\t' + str(score))
27+
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
faiss-cpu

0 commit comments

Comments
 (0)