Skip to content

Commit ee096e4

Browse files
Michael Norris, facebook-github-bot
Michael Norris
authored and committed
Add rabitq bench to source control (#4307)
Summary: Pull Request resolved: #4307 Creating new source control notebook file Differential Revision: D73549740
1 parent 82cf65a commit ee096e4

File tree

1 file changed

+337
-0
lines changed

1 file changed

+337
-0
lines changed

benchs/bench_rabitq.py

+337
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
#!/usr/bin/env -S grimaldi --kernel faiss
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# fmt: off
8+
# flake8: noqa
9+
10+
# NOTEBOOK_NUMBER: N7030784 (685760243832285)
11+
12+
""":py"""
13+
import timeit
14+
from collections import defaultdict
15+
16+
import faiss
17+
from faiss.contrib.datasets import SyntheticDataset
18+
19+
""":py"""
20+
# Shared benchmark settings read by the cells below.
nlist: int = 1_000
qb: int = 8

# Synthetic dataset that supplies the train, database and query vectors and
# the ground truth (positional arguments are presumably d, nt, nb, nq --
# confirm against faiss.contrib.datasets.SyntheticDataset).
ds: SyntheticDataset = SyntheticDataset(256, 1_000_000, 1_000_000, 10_000)

# Maps "index name" -> [recalls, speeds, labels]; the labels record the k
# (and, for IVF indexes, the nprobe) of each measurement.
recall_speed_data = defaultdict(lambda: [[] for _ in range(3)])
# Maps "index name" -> [recalls, memory for this index].
recall_memory_data = defaultdict(lambda: [[] for _ in range(2)])
27+
28+
""":py"""
29+
# Helpers
30+
31+
32+
def trials(index, xq, k, n_trials=10):
    """Return the mean duration of ``index.search(xq, k)`` in milliseconds.

    Args:
        index: Any object exposing ``search(xq, k)``.
        xq: Query batch forwarded unchanged to ``index.search``.
        k: Number of neighbours requested per query.
        n_trials: How many times the search is repeated. Defaults to 10,
            the count that used to be hard-coded.

    Returns:
        Average time of one search call, in milliseconds.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")
    # A distinct local name avoids shadowing this function inside its own body.
    total_seconds = timeit.timeit(
        stmt="index.search(xq, k)",
        number=n_trials,
        globals={"index": index, "xq": xq, "k": k},
    )
    return total_seconds / n_trials * 1000.0  # ms
40+
41+
42+
def trials_ivf(index, xq, k, params=None, n_trials=10):
    """Return the mean duration of a parameterised search in milliseconds.

    Times ``faiss.search_with_parameters(index, xq, k, params)``.

    Args:
        index: Index handed to ``faiss.search_with_parameters``.
        xq: Query batch forwarded unchanged to the search call.
        k: Number of neighbours requested per query.
        params: Search-parameter object forwarded to the search call
            (``None`` is passed through as-is).
        n_trials: How many times the search is repeated. Defaults to 10,
            the count that used to be hard-coded.

    Returns:
        Average time of one search call, in milliseconds.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")
    # A distinct local name avoids shadowing the sibling ``trials`` helper.
    total_seconds = timeit.timeit(
        stmt="search_with_parameters(index, xq, k, params)",
        number=n_trials,
        globals={
            "search_with_parameters": faiss.search_with_parameters,
            "index": index,
            "xq": xq,
            "k": k,
            "params": params,
        },
    )
    return total_seconds / n_trials * 1000.0  # ms
56+
57+
58+
def compute_recall(ground_truth_I, predicted_I):
    """Return the intersection count normalised by the ground-truth size.

    The value of ``faiss.eval_intersection(ground_truth_I, predicted_I)`` is
    divided by the number of entries in ``ground_truth_I``, which must be
    two-dimensional (``n_queries`` x ``k``).
    """
    rows, cols = ground_truth_I.shape
    n_common = faiss.eval_intersection(ground_truth_I, predicted_I)
    return n_common / (rows * cols)
63+
64+
65+
def create_index(ds, factory_string):
    """Build an index from ``factory_string``, then train and populate it.

    The index dimensionality is ``ds.d``; training uses ``ds.get_train()``
    and the stored vectors come from ``ds.get_database()``. The populated
    index is returned.
    """
    built = faiss.index_factory(ds.d, factory_string)
    training_vectors = ds.get_train()
    built.train(training_vectors)
    database_vectors = ds.get_database()
    built.add(database_vectors)
    return built
70+
71+
72+
# pyre-ignore
73+
def handle_index(prefix, index, ds, mem, k):
    """Measure recall and search time of a non-IVF index for one ``k``.

    Runs one search to obtain the predicted ids, times repeated searches with
    ``trials``, prints a summary line and appends the measurements to the
    module-level ``recall_speed_data`` and ``recall_memory_data`` tables
    under ``prefix``. ``mem`` is reported in MB as ``mem/1e6``.
    """
    expected_ids = ds.get_groundtruth(k)
    found_ids = index.search(ds.get_queries(), k)[1]
    ms_per_search = trials(index, ds.get_queries(), k)
    hit_rate = compute_recall(expected_ids, found_ids)
    print(
        f"{prefix} recall@{k}: {hit_rate}. Average speed: {ms_per_search:.1f}ms. Memory: {mem/1e6:.3f}MB"
    )

    # Recall/speed table: [recalls, speeds, labels].
    speed_recalls, speed_values, speed_labels = recall_speed_data[prefix]
    speed_recalls.append(hit_rate)
    speed_values.append(ms_per_search)
    speed_labels.append(f"k={k}")

    # Recall/memory table: [recalls, memory].
    memory_recalls, memory_values = recall_memory_data[prefix]
    memory_recalls.append(hit_rate)
    memory_values.append(mem)
86+
87+
88+
# pyre-ignore
89+
def handle_ivf_index(prefix, index, ds, mem, k, params, nprobes=(4, 16, 32)):
    """Benchmark an IVF-style index at several ``nprobe`` settings for one ``k``.

    For every value in ``nprobes`` the function sets ``params.nprobe``, runs
    one search to measure recall against ``ds.get_groundtruth(k)``, times
    repeated searches with ``trials_ivf``, prints a summary line and appends
    the measurements to the module-level ``recall_speed_data`` and
    ``recall_memory_data`` tables under ``prefix``.

    Args:
        prefix: Name under which results are printed and stored.
        index: Index passed to ``faiss.search_with_parameters``.
        ds: Dataset exposing ``get_groundtruth(k)`` and ``get_queries()``.
        mem: Memory figure for the index; printed in MB as ``mem/1e6``.
        k: Number of neighbours requested per query.
        params: Search-parameter object; its ``nprobe`` attribute is
            overwritten on every iteration.
        nprobes: Iterable of ``nprobe`` values to sweep. Defaults to the
            previously hard-coded ``(4, 16, 32)``.
    """
    gt_I = ds.get_groundtruth(k)
    # The query set does not depend on nprobe, so fetch it once up front.
    xq = ds.get_queries()
    for nprobe in nprobes:
        params.nprobe = nprobe
        _, I_res = faiss.search_with_parameters(index, xq, k, params)
        avg_speed = trials_ivf(index, xq, k, params)
        recall = compute_recall(gt_I, I_res)
        print(
            f"{prefix} nprobe={nprobe}: recall@{k}: {recall}. Average speed: {avg_speed:.1f}ms. Memory: {mem/1e6:.3f}MB"
        )
        recall_speed_data[prefix][0].append(recall)
        recall_speed_data[prefix][1].append(avg_speed)
        recall_speed_data[prefix][2].append(f"k={k}, nprobe={nprobe}")
        recall_memory_data[prefix][0].append(recall)
        recall_memory_data[prefix][1].append(mem)
104+
105+
106+
# pyre-ignore
107+
def vary_k_nprobe_measuring_recall_and_memory(prefix, index, ds, mem):
    """Run the recall/speed/memory measurements for ``k`` in 1, 10 and 100.

    Dispatch is by the class name of ``index``:

    * flat indexes go through ``handle_index``;
    * IVF indexes go through ``handle_ivf_index`` with a fresh parameter
      object per ``k`` -- ``faiss.IVFSearchParameters`` for the PQ-fast-scan
      and scalar-quantizer variants, otherwise
      ``faiss.IVFRaBitQSearchParameters`` with ``qb`` taken from the
      module-level setting.

    An index of any other class is reported and skipped instead of silently
    producing no measurements.

    Args:
        prefix: Name under which results are printed and stored.
        index: The index to benchmark.
        ds: Dataset forwarded to the handlers.
        mem: Memory figure for the index, forwarded to the handlers.
    """
    flat_classes = (
        "IndexRaBitQ",
        "IndexPQFastScan",
        "IndexHNSWFlat",
        "IndexScalarQuantizer",
    )
    plain_ivf_classes = ("IndexIVFPQFastScan", "IndexIVFScalarQuantizer")
    # NOTE(review): IndexPreTransform is treated as wrapping an IVF RaBitQ
    # index, which is how this file uses it -- confirm before reusing.
    rabitq_ivf_classes = ("IndexIVFRaBitQ", "IndexPreTransform")

    classname = type(index).__name__
    if classname not in flat_classes + plain_ivf_classes + rabitq_ivf_classes:
        print(f"Skipping {prefix}: unsupported index class {classname}")
        return

    for k in 1, 10, 100:
        if classname in flat_classes:
            handle_index(prefix, index, ds, mem, k)
            continue
        if classname in plain_ivf_classes:
            params = faiss.IVFSearchParameters()
        else:
            params = faiss.IVFRaBitQSearchParameters()
            params.qb = qb
        handle_ivf_index(prefix, index, ds, mem, k, params)
132+
133+
""":py '605360559215064'"""
134+
# Flat (non-IVF) IndexRaBitQ

fac_s = "RaBitQ"
non_ivf_rbq = faiss.index_factory(ds.d, fac_s)
# Copy the shared qb setting onto the index before it is trained and filled.
non_ivf_rbq.qb = qb
non_ivf_rbq.train(ds.get_train())
non_ivf_rbq.add(ds.get_database())
# Memory is accounted as one code per stored vector.
mem = non_ivf_rbq.ntotal * non_ivf_rbq.code_size

vary_k_nprobe_measuring_recall_and_memory(fac_s, non_ivf_rbq, ds, mem)

# Drop the reference so the index can be released before the next cell.
del non_ivf_rbq
146+
147+
""":py '3928150077498381'"""
148+
# IndexIVFRaBitQ, without a random rotation in front of it

fac_s = f"IVF{nlist},RaBitQ"
rbq1 = faiss.index_factory(ds.d, fac_s)
# Copy the shared qb setting onto the index before it is trained and filled.
rbq1.qb = qb
rbq1.train(ds.get_train())
rbq1.add(ds.get_database())
# Memory is accounted as one code per stored vector.
mem = rbq1.ntotal * rbq1.code_size

vary_k_nprobe_measuring_recall_and_memory(fac_s, rbq1, ds, mem)

# Drop the reference so the index can be released before the next cell.
del rbq1
160+
161+
""":py '1484145352968190'"""
162+
# IndexIVFRaBitQ behind a random rotation (IndexPreTransform)

fac_s = f"IVF{nlist},RaBitQ"

# Square random rotation initialised with a fixed seed.
rrot = faiss.RandomRotationMatrix(ds.d, ds.d)
rrot.init(123)

rbq2 = faiss.index_factory(ds.d, fac_s)
rbq2.qb = qb

# Training and adding go through the wrapper so the rotation is applied.
index_pt = faiss.IndexPreTransform(rrot, rbq2)
index_pt.train(ds.get_train())
index_pt.add(ds.get_database())
# Only the codes are counted; the rotation matrix is not part of this figure.
mem = index_pt.ntotal * rbq2.code_size

vary_k_nprobe_measuring_recall_and_memory(f"{fac_s}_RROT", index_pt, ds, mem)

del index_pt
177+
178+
""":py '644702398382829'"""
179+
# IndexScalarQuantizer

for M in [4, 6, 8]:
    fac_s = f"SQ{M}"
    sq = create_index(ds, fac_s)
    # Memory is accounted as one code per stored vector.
    mem = sq.code_size * sq.ntotal
    vary_k_nprobe_measuring_recall_and_memory("Index" + fac_s, sq, ds, mem)
    # Release the index before the next one is built, as the PQFS, IVFPQFS
    # and HNSW cells of this file do; otherwise the previous index stays
    # alive while its successor is trained and filled.
    del sq
186+
187+
""":py '1347502839702520'"""
188+
# IndexIVFScalarQuantizer

for M in [4, 6]:  # 8 seems to have no recall improvement in this dataset.
    fac_s = f"IVF{nlist},SQ{M}"
    sq = create_index(ds, fac_s)
    # Memory is accounted as one code per stored vector.
    mem = sq.code_size * sq.ntotal
    vary_k_nprobe_measuring_recall_and_memory(fac_s, sq, ds, mem)
    # Release the index before the next one is built, as the PQFS, IVFPQFS
    # and HNSW cells of this file do; otherwise the previous index stays
    # alive while its successor is trained and filled.
    del sq
195+
196+
""":py '1350039419637535'"""
197+
# PQ fast-scan (flat) indexes, one per sub-quantizer count m

for m in (32, 64, 128):
    fac_s = f"PQ{m}x4fs"
    pqfs = create_index(ds, fac_s)
    # Memory is accounted as one code per stored vector.
    mem = pqfs.ntotal * pqfs.code_size
    vary_k_nprobe_measuring_recall_and_memory(fac_s, pqfs, ds, mem)
    # Release this index before the next one is built.
    del pqfs
205+
206+
""":py '2549074352105737'"""
207+
# IVF + PQ fast-scan indexes, one per sub-quantizer count m

for m in (32, 64, 128):
    fac_s = f"IVF{nlist},PQ{m}x4fs"
    ivf_pqfs = create_index(ds, fac_s)
    # Memory is accounted as one code per stored vector.
    mem = ivf_pqfs.ntotal * ivf_pqfs.code_size
    vary_k_nprobe_measuring_recall_and_memory(fac_s, ivf_pqfs, ds, mem)
    # Release this index before the next one is built.
    del ivf_pqfs
215+
216+
""":py '3933359133572530'"""
217+
# HNSW indexes, one per m value

for m in (8, 16, 32):
    fac_s = f"HNSW{m}"
    index = create_index(ds, fac_s)
    storage = faiss.downcast_index(index.storage)
    # Stored vectors plus the graph arrays; the factors 4 and 8 are
    # presumably the byte widths of a neighbor entry and an offset entry
    # -- confirm against the faiss HNSW struct.
    mem = (
        index.hnsw.offsets.size() * 8
        + index.hnsw.neighbors.size() * 4
        + storage.code_size * storage.ntotal
    )
    vary_k_nprobe_measuring_recall_and_memory(fac_s, index, ds, mem)
    # Release this index before the next one is built.
    del index
230+
231+
""":py"""
232+
import matplotlib.pyplot as plt
233+
from adjustText import adjust_text
234+
235+
236+
# Hand-picked palette: these named colors remain distinguishable from one
# another when this many series share a single plot. Grouped by hue family.
colors = [
    # neutrals
    "black", "darkgray",
    # reds, oranges and yellows
    "darkred", "red", "orange", "wheat", "olive", "yellow",
    # greens and cyans
    "lime", "teal", "cyan",
    # blues
    "skyblue", "royalblue", "navy",
    # violets and pinks
    "darkviolet", "fuchsia", "deeppink", "pink",
]
257+
258+
""":py '1023372579245229'"""
259+
# Largest average search time recorded for any series (0.0 when nothing has
# been recorded); it sets the top of the y axis.
slowest_speed = max(
    [0.0, *(speed for vals in recall_speed_data.values() for speed in vals[1])]
)

plt.axis([0, 1.0, 0, slowest_speed + 100.0])  # [xmin, xmax, ymin, ymax]
for i, (key, vals) in enumerate(recall_speed_data.items()):
    recalls = vals[0]
    speeds = vals[1]
    plt.plot(
        recalls,
        speeds,
        linestyle=" ",
        marker="o",
        # Wrap around the palette so that benchmarking more series than there
        # are colors reuses colors instead of raising IndexError.
        color=colors[i % len(colors)],
        label=key,
        markersize=15,
    )
    # Adding k and nprobe labels makes the diagram very busy, but can be enabled by uncommenting the following lines:
    # ks = vals[2]
    # texts = []
    # for j, (x_val, y_val) in enumerate(zip(recalls, speeds)):
    #     texts.append(plt.text(x_val, y_val, ks[j]))
    # # Adjust text labels
    # adjust_text(
    #     texts,
    #     arrowprops=dict(arrowstyle="-", color="black", lw=0.5),
    #     force_text=(0.1, 0.25),
    #     force_points=(0.2, 0.5),
    #     only_move={"points": "xy"},
    # )

plt.title("Recall vs Speed")
plt.xlabel("Recall")
plt.ylabel("Speed")
plt.legend()
plt.show()
296+
297+
""":py '1354989919068149'"""
298+
# Largest memory figure recorded for any series (0.0 when nothing has been
# recorded). Kept for inspection; the axis limits below are fixed.
largest_mem = max(
    [0.0, *(mem for vals in recall_memory_data.values() for mem in vals[1])]
)

plt.ylim(1e6, 1e10)
plt.yscale("log", base=10)

for i, (key, vals) in enumerate(recall_memory_data.items()):
    recalls = vals[0]
    mems = vals[1]
    plt.plot(
        recalls,
        mems,
        linestyle=" ",
        marker="o",
        # Wrap around the palette so that benchmarking more series than there
        # are colors reuses colors instead of raising IndexError.
        color=colors[i % len(colors)],
        label=key,
        markersize=10,
    )

    # Only the first series gets text annotations, on its first two points.
    # Slicing keeps this safe when that series holds fewer than two points.
    # NOTE(review): the scraped source lost its indentation; this assumes the
    # annotation code sat inside the loop -- confirm against the original.
    if i == 0:
        texts = [
            plt.text(x_val, y_val, "RaBitQ")
            for x_val, y_val in zip(recalls[:2], mems[:2])
        ]
        # Label adjustment is skipped when there is nothing to adjust.
        if texts:
            adjust_text(
                texts,
                arrowprops=dict(arrowstyle="-", color="black", lw=0.5),
                force_text=(0.5, 0.25),
                force_points=(1.0, 1.5),
                expand_points=(5.0, 10.0),
            )

plt.title("Recall vs Memory")
plt.xlabel("Recall")
plt.ylabel("Memory")
plt.legend()
plt.show()
336+
337+
""":py"""

0 commit comments

Comments
 (0)