Skip to content

Commit 019ad5c

Browse files
authored
Add Elasticsearch example (#68)
This closes #67.
1 parent 8cd3f6e commit 019ad5c

File tree

5 files changed

+241
-0
lines changed

5 files changed

+241
-0
lines changed

examples/es_example/README.md

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Elasticsearch
2+
3+
This example illustrates how to use RocketQA along with [Elasticsearch](https://www.elastic.co/).
4+
5+
6+
## Prerequisites
7+
8+
### Install Dependencies
9+
10+
```console
11+
$ python3 -m venv venv
12+
$ source venv/bin/activate
13+
$ pip3 install -r requirements.txt
14+
```
15+
16+
### Run Elasticsearch
17+
18+
Run Elasticsearch in development mode:
19+
20+
```console
21+
$ docker run -d --name elasticsearch -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "ELASTIC_PASSWORD=123456" elasticsearch:8.4.2
22+
```
23+
24+
## Usage
25+
26+
### Index
27+
28+
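Prepare the data (stored at `data/test.tsv`) as one title/paragraph pair per line, in the following format (`\t` stands for a literal tab character and `\n` for the end of the line):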
29+
30+
```
31+
title_1\tparagraph_1\n
32+
title_2\tparagraph_2\n
33+
...
34+
```
35+
36+
Create the index and save the data into the index:
37+
38+
```console
39+
$ curl -XPUT -u elastic:123456 -k -H "Content-Type: application/json" https://localhost:9200/test-index -d @mappings.json
40+
$ python3 index.py zh data/test.tsv test-index
41+
```
42+
43+
### Query
44+
45+
```console
46+
$ python3 query.py
47+
```

examples/es_example/index.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import argparse
4+
import os
5+
import sys
6+
7+
import faiss
8+
import numpy as np
9+
import rocketqa
10+
from elasticsearch import Elasticsearch, helpers
11+
12+
13+
class Indexer:
    """Encode (title, paragraph) pairs with RocketQA and bulk-index them into Elasticsearch."""

    def __init__(self, es_client, index_name, model,
                 use_cuda=False, device_id=0, batch_size=32):
        """Keep the client and index name and load the dual encoder.

        Args:
            es_client: Elasticsearch client later handed to ``helpers.bulk``.
            index_name: Name of the index the documents are written to.
            model: Model name passed to ``rocketqa.load_model``.
            use_cuda: Forwarded to ``rocketqa.load_model``; set True to use a GPU.
            device_id: Forwarded to ``rocketqa.load_model``.
            batch_size: Forwarded to ``rocketqa.load_model``.
        """
        self.es_client = es_client
        self.index_name = index_name
        self.dual_encoder = rocketqa.load_model(
            model=model,
            use_cuda=use_cuda,
            device_id=device_id,
            batch_size=batch_size,
        )

    def index(self, tps):
        """Embed every (title, paragraph) pair and bulk-write the documents.

        Document ids are the 1-based positions of the rows in ``tps``.

        Args:
            tps: Iterable of two-item rows ``(title, paragraph)``.

        Returns:
            The result of ``helpers.bulk``. For empty input nothing is sent and
            ``(0, [])`` is returned (assumed to mirror the default
            ``(success_count, errors)`` shape of ``helpers.bulk``).

        Raises:
            ValueError: If a row does not hold exactly a title and a paragraph.
        """
        rows = list(tps)
        if not rows:
            # Nothing to encode or send; avoid the unpacking error on empty input.
            return 0, []

        # Validate up front: zip(*rows) would silently truncate mixed-width rows.
        for row_no, row in enumerate(rows, start=1):
            if len(row) != 2:
                raise ValueError(
                    f"row {row_no} must be a (title, paragraph) pair, "
                    f"got {len(row)} field(s)"
                )

        titles = [row[0] for row in rows]
        paras = [row[1] for row in rows]
        embs = self.dual_encoder.encode_para(para=paras, title=titles)

        def gen_actions():
            for i, emb in enumerate(embs):
                # Normalize the NumPy array to a unit vector to use `dot_product` similarity,
                # see https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html#dense-vector-params.
                emb = emb / np.linalg.norm(emb)
                yield dict(
                    _index=self.index_name,
                    _id=i + 1,
                    _source=dict(
                        title=titles[i],
                        paragraph=paras[i],
                        vector=emb,
                    ),
                )

        return helpers.bulk(self.es_client, gen_actions())
43+
44+
45+
def main():
    """Read a TSV of title/paragraph pairs and index it into Elasticsearch.

    Command-line arguments: the language (``zh`` or ``en``, which selects the
    RocketQA dual-encoder model), the data file, and the target index name.
    The result of the bulk indexing call is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('lang', choices=['zh', 'en'], help='The language')
    parser.add_argument('data_file', help='The data file')
    parser.add_argument('index_name', help='The index name')
    args = parser.parse_args()

    # `choices` above guarantees that the lookup succeeds.
    models = {
        'zh': 'zh_dureader_de_v2',
        'en': 'v1_marco_de',
    }
    model = models[args.lang]

    tps = []
    # Read as UTF-8 explicitly so non-ASCII text does not depend on the
    # platform's default encoding.
    with open(args.data_file, encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no document.
                continue
            # Split on the first tab only so a tab inside the paragraph is kept.
            fields = line.split('\t', 1)
            if len(fields) != 2:
                parser.error(
                    f'{args.data_file}:{line_no}: expected "title<TAB>paragraph"'
                )
            tps.append(fields)

    if not tps:
        parser.error(f'{args.data_file}: no title/paragraph pairs found')

    es_client = Elasticsearch(
        "https://localhost:9200",
        # `basic_auth` replaces the deprecated `http_auth` in the 8.x client.
        basic_auth=("elastic", "123456"),
        # Development setup only: certificate verification is disabled,
        # matching the `curl -k` call in the README.
        verify_certs=False,
    )

    indexer = Indexer(es_client, args.index_name, model)
    result = indexer.index(tps)
    print(result)


if __name__ == '__main__':
    main()

examples/es_example/mappings.json

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{
2+
"mappings": {
3+
"_source": {
4+
"excludes": [
5+
"vector"
6+
]
7+
},
8+
"properties": {
9+
"vector": {
10+
"type": "dense_vector",
11+
"dims": 768,
12+
"index": true,
13+
"similarity": "dot_product"
14+
},
15+
"title": {
16+
"type": "text"
17+
},
18+
"paragraph": {
19+
"type": "text"
20+
}
21+
}
22+
}
23+
}

examples/es_example/query.py

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import sys
4+
import time
5+
6+
import numpy as np
7+
import rocketqa
8+
from elasticsearch import Elasticsearch
9+
10+
11+
class Querier:
    """Retrieve candidates by kNN search in Elasticsearch and re-rank them with RocketQA."""

    def __init__(self, es_client, index_name, de_model, ce_model,
                 use_cuda=False, device_id=0, batch_size=32):
        """Keep the client and index name and load both RocketQA models.

        Args:
            es_client: Elasticsearch client used for the kNN search.
            index_name: Name of the index to search.
            de_model: Dual-encoder model name passed to ``rocketqa.load_model``.
            ce_model: Cross-encoder model name passed to ``rocketqa.load_model``.
            use_cuda: Forwarded to ``rocketqa.load_model``; set True to use a GPU.
            device_id: Forwarded to ``rocketqa.load_model``.
            batch_size: Forwarded to ``rocketqa.load_model``.
        """
        self.es_client = es_client
        self.index_name = index_name
        self.dual_encoder = rocketqa.load_model(
            model=de_model,
            use_cuda=use_cuda,
            device_id=device_id,
            batch_size=batch_size,
        )
        self.cross_encoder = rocketqa.load_model(
            model=ce_model,
            use_cuda=use_cuda,
            device_id=device_id,
            batch_size=batch_size,
        )

    def encode(self, query):
        """Return the unit-length embedding of ``query`` from the dual encoder."""
        embs = self.dual_encoder.encode_query(query=[query])
        # A single query was submitted, so only the first embedding is needed.
        vector = next(iter(embs))
        # Normalize the NumPy array to a unit vector to use `dot_product` similarity,
        # see https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html#dense-vector-params.
        return vector / np.linalg.norm(vector)

    def search(self, query, topk=10, num_candidates=100):
        """Run a kNN search for ``query`` and return the matching documents.

        Args:
            query: The query text.
            topk: Number of nearest neighbours requested (``k``).
            num_candidates: Requested candidate pool size; raised to ``topk``
                when smaller, because Elasticsearch rejects a kNN request
                whose ``num_candidates`` is below ``k``.

        Returns:
            A list of ``dict(title=..., para=...)`` in the order of the hits.
        """
        knn = dict(
            field="vector",
            query_vector=self.encode(query),
            k=topk,
            num_candidates=max(num_candidates, topk),
        )
        result = self.es_client.knn_search(index=self.index_name, knn=knn)

        return [
            dict(
                title=doc['_source']['title'],
                para=doc['_source']['paragraph'],
            )
            for doc in result['hits']['hits']
        ]

    def sort(self, query, candidates):
        """Score each candidate against ``query`` and return them best first.

        Args:
            query: The query text.
            candidates: List of dicts with ``title`` and ``para`` keys, as
                returned by ``search``.

        Returns:
            A list of ``dict(title=..., para=..., score=...)`` sorted by
            descending score; an empty list when there are no candidates.
        """
        if not candidates:
            # Nothing to rank; skip the cross-encoder call entirely.
            return []

        queries = [query] * len(candidates)
        titles = [c['title'] for c in candidates]
        paras = [c['para'] for c in candidates]
        ranking_score = self.cross_encoder.matching(query=queries, para=paras, title=titles)

        answers = [
            dict(
                title=titles[i],
                para=paras[i],
                score=score,
            )
            for i, score in enumerate(ranking_score)
        ]
        return sorted(answers, key=lambda a: a['score'], reverse=True)
70+
71+
72+
def main():
    """Run an interactive loop: read a query, show kNN candidates, then re-ranked answers.

    The loop ends cleanly on end-of-input (Ctrl-D) or Ctrl-C at the prompt.
    """
    es_client = Elasticsearch(
        "https://localhost:9200",
        # `basic_auth` replaces the deprecated `http_auth` in the 8.x client.
        basic_auth=("elastic", "123456"),
        # Development setup only: certificate verification is disabled,
        # matching the `curl -k` call in the README.
        verify_certs=False,
    )
    querier = Querier(es_client, "test-index", 'zh_dureader_de_v2', 'zh_dureader_ce_v2')

    while True:
        try:
            query = input('Query: ')
        except (EOFError, KeyboardInterrupt):
            # Leave the prompt line and stop instead of dumping a traceback.
            print()
            break

        if not query.strip():
            # Nothing to search for; prompt again.
            continue

        candidates = querier.search(query)
        print('Candidates:')
        for c in candidates:
            print(c['title'], '\t', c['para'])

        answers = querier.sort(query, candidates)
        print('Answers:')
        for a in answers:
            print(a['title'], '\t', a['para'], '\t', a['score'])


if __name__ == '__main__':
    main()

examples/es_example/requirements.txt

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
paddlepaddle==2.3.2
2+
rocketqa==1.1.0
3+
elasticsearch==8.5.0
4+
numpy==1.21.6

0 commit comments

Comments
 (0)