11
11
// under the License.
12
12
//
13
13
14
+ #include < queue>
15
+
14
16
#include " yb/client/snapshot_test_util.h"
15
17
16
18
#include " yb/consensus/consensus.h"
28
30
#include " yb/util/backoff_waiter.h"
29
31
#include " yb/util/test_thread_holder.h"
30
32
33
+ #include " yb/vector_index/usearch_include_wrapper_internal.h"
34
+
31
35
#include " yb/yql/pgwrapper/pg_mini_test_base.h"
32
36
33
37
DECLARE_bool (TEST_skip_process_apply);
@@ -38,6 +42,12 @@ DECLARE_uint32(vector_index_concurrent_writes);
38
42
39
43
namespace yb ::pgwrapper {
40
44
45
+ using FloatVector = std::vector<float >;
46
+
47
+ const unum::usearch::byte_t * VectorToBytePtr (const FloatVector& vector) {
48
+ return pointer_cast<const unum::usearch::byte_t *>(vector.data ());
49
+ }
50
+
41
51
class PgVectorIndexTest : public PgMiniTestBase , public testing ::WithParamInterface<bool > {
42
52
protected:
43
53
void SetUp () override {
@@ -57,43 +67,45 @@ class PgVectorIndexTest : public PgMiniTestBase, public testing::WithParamInterf
57
67
return IsColocated () ? ConnectToDB (" colocated_db" ) : PgMiniTestBase::Connect ();
58
68
}
59
69
60
- Result<PGConn> MakeIndex (int num_tablets = 0 ) {
70
+ Result<PGConn> MakeIndex (size_t dimensions = 3 ) {
61
71
auto colocated = IsColocated ();
62
72
auto conn = VERIFY_RESULT (PgMiniTestBase::Connect ());
63
73
std::string create_suffix;
64
74
if (colocated) {
65
75
create_suffix = " WITH (COLOCATED = 1)" ;
66
76
RETURN_NOT_OK (conn.ExecuteFormat (" CREATE DATABASE colocated_db COLOCATION = true" ));
67
77
conn = VERIFY_RESULT (Connect ());
68
- } else if (num_tablets) {
69
- create_suffix = Format (" SPLIT INTO $0 TABLETS" , num_tablets);
70
78
}
71
79
RETURN_NOT_OK (conn.Execute (" CREATE EXTENSION vector" ));
72
- RETURN_NOT_OK (conn.Execute (
73
- " CREATE TABLE test (id bigserial PRIMARY KEY, embedding vector(3))" + create_suffix));
80
+ RETURN_NOT_OK (conn.ExecuteFormat (
81
+ " CREATE TABLE test (id bigserial PRIMARY KEY, embedding vector($0))$1" ,
82
+ dimensions, create_suffix));
74
83
75
84
RETURN_NOT_OK (conn.Execute (" CREATE INDEX ON test USING ybhnsw (embedding vector_l2_ops)" ));
76
85
77
86
return conn;
78
87
}
79
88
80
- Status WaitForLoadBalance (int num_tablet_servers) {
81
- return WaitFor (
82
- [&]() -> Result<bool > { return client_->IsLoadBalanced (num_tablet_servers); },
83
- 60s * kTimeMultiplier ,
84
- Format (" Wait for load balancer to balance to $0 tservers." , num_tablet_servers));
85
- }
86
-
87
- Result<PGConn> MakeIndexAndFill (int num_rows, int num_tablets = 0 );
88
- Status InsertRows (PGConn& conn, int start_row, int end_row);
89
+ Result<PGConn> MakeIndexAndFill (size_t num_rows);
90
+ Result<PGConn> MakeIndexAndFillRandom (size_t num_rows, size_t dimensions);
91
+ Status InsertRows (PGConn& conn, size_t start_row, size_t end_row);
92
+ Status InsertRandomRows (PGConn& conn, size_t num_rows, size_t dimensions);
89
93
90
- void VerifyRead (PGConn& conn, int limit, bool add_filter);
94
+ void VerifyRead (PGConn& conn, size_t limit, bool add_filter);
91
95
void VerifyRows (
92
- PGConn& conn, bool add_filter, const std::vector<std::string>& expected, int limit = -1 );
96
+ PGConn& conn, bool add_filter, const std::vector<std::string>& expected, int64_t limit = -1 );
93
97
94
98
void TestSimple ();
95
99
void TestManyRows (bool add_filter);
96
100
void TestRestart (tablet::FlushFlags flush_flags);
101
+
102
+ FloatVector RandomVector (size_t dimensions) {
103
+ return RandomFloatVector (dimensions, distribution_, &rng_);
104
+ }
105
+
106
+ std::vector<FloatVector> vectors_;
107
+ std::uniform_real_distribution<> distribution_;
108
+ std::mt19937_64 rng_{42 };
97
109
};
98
110
99
111
void PgVectorIndexTest::TestSimple () {
@@ -167,53 +179,70 @@ std::string ExpectedRow(int64_t id) {
167
179
return BuildRow (id, VectorAsString (id));
168
180
}
169
181
170
- Status PgVectorIndexTest::InsertRows (PGConn& conn, int start_row, int end_row) {
182
+ Status PgVectorIndexTest::InsertRows (PGConn& conn, size_t start_row, size_t end_row) {
171
183
RETURN_NOT_OK (conn.StartTransaction (IsolationLevel::SNAPSHOT_ISOLATION));
172
- for (int i = start_row; i <= end_row; ++i) {
184
+ for (auto i = start_row; i <= end_row; ++i) {
173
185
RETURN_NOT_OK (conn.ExecuteFormat (
174
186
" INSERT INTO test VALUES ($0, '$1')" , i, VectorAsString (i)));
175
187
}
176
188
return conn.CommitTransaction ();
177
189
}
178
190
179
- Result<PGConn> PgVectorIndexTest::MakeIndexAndFill (int num_rows, int num_tablets) {
180
- auto conn = VERIFY_RESULT (MakeIndex (num_tablets));
191
+ Status PgVectorIndexTest::InsertRandomRows (PGConn& conn, size_t num_rows, size_t dimensions) {
192
+ RETURN_NOT_OK (conn.StartTransaction (IsolationLevel::SNAPSHOT_ISOLATION));
193
+ for (size_t i = 0 ; i != num_rows; ++i) {
194
+ auto vector = RandomVector (dimensions);
195
+ RETURN_NOT_OK (conn.ExecuteFormat (
196
+ " INSERT INTO test VALUES ($0, '$1')" , vectors_.size (), AsString (vector)));
197
+ vectors_.push_back (std::move (vector));
198
+ }
199
+ return conn.CommitTransaction ();
200
+ }
201
+
202
+ Result<PGConn> PgVectorIndexTest::MakeIndexAndFill (size_t num_rows) {
203
+ auto conn = VERIFY_RESULT (MakeIndex ());
181
204
RETURN_NOT_OK (InsertRows (conn, 1 , num_rows));
182
205
return conn;
183
206
}
184
207
208
+ Result<PGConn> PgVectorIndexTest::MakeIndexAndFillRandom (size_t num_rows, size_t dimensions) {
209
+ auto conn = VERIFY_RESULT (MakeIndex (dimensions));
210
+ RETURN_NOT_OK (InsertRandomRows (conn, num_rows, dimensions));
211
+ return conn;
212
+ }
213
+
185
214
void PgVectorIndexTest::VerifyRows (
186
- PGConn& conn, bool add_filter, const std::vector<std::string>& expected, int limit) {
215
+ PGConn& conn, bool add_filter, const std::vector<std::string>& expected, int64_t limit) {
187
216
auto result = ASSERT_RESULT ((conn.FetchRows <RowAsString>(Format (
188
217
" SELECT * FROM test $0 ORDER BY embedding <-> '[0.0, 0.0, 0.0]' LIMIT $1" ,
189
218
add_filter ? " WHERE id + 3 <= 5" : " " ,
190
- limit == - 1 ? expected.size () : make_unsigned (limit)))));
219
+ limit < 0 ? expected.size () : make_unsigned (limit)))));
191
220
EXPECT_EQ (result.size (), expected.size ());
192
221
for (size_t i = 0 ; i != std::min (result.size (), expected.size ()); ++i) {
193
222
SCOPED_TRACE (Format (" Row $0" , i));
194
223
EXPECT_EQ (result[i], expected[i]);
195
224
}
196
225
}
197
226
198
- void PgVectorIndexTest::VerifyRead (PGConn& conn, int limit, bool add_filter) {
227
+ void PgVectorIndexTest::VerifyRead (PGConn& conn, size_t limit, bool add_filter) {
199
228
std::vector<std::string> expected;
200
- for (int i = 1 ; i <= limit; ++i) {
229
+ for (size_t i = 1 ; i <= limit; ++i) {
201
230
expected.push_back (ExpectedRow (i));
202
231
}
203
232
VerifyRows (conn, add_filter, expected);
204
233
}
205
234
206
235
void PgVectorIndexTest::TestManyRows (bool add_filter) {
207
- constexpr int kNumRows = RegularBuildVsSanitizers (2000 , 64 );
208
- const int query_limit = add_filter ? 1 : 5 ;
236
+ constexpr size_t kNumRows = RegularBuildVsSanitizers (2000 , 64 );
237
+ const size_t query_limit = add_filter ? 1 : 5 ;
209
238
210
239
auto conn = ASSERT_RESULT (MakeIndexAndFill (kNumRows ));
211
240
ASSERT_NO_FATALS (VerifyRead (conn, query_limit, add_filter));
212
241
}
213
242
214
243
TEST_P (PgVectorIndexTest, Split) {
215
- constexpr int kNumRows = RegularBuildVsSanitizers (500 , 64 );
216
- constexpr int kQueryLimit = 5 ;
244
+ constexpr size_t kNumRows = RegularBuildVsSanitizers (500 , 64 );
245
+ constexpr size_t kQueryLimit = 5 ;
217
246
218
247
auto conn = ASSERT_RESULT (MakeIndexAndFill (kNumRows ));
219
248
ASSERT_OK (cluster_->FlushTablets ());
@@ -236,17 +265,17 @@ TEST_P(PgVectorIndexTest, ManyReads) {
236
265
ANNOTATE_UNPROTECTED_WRITE (FLAGS_vector_index_concurrent_reads) = 1 ;
237
266
ANNOTATE_UNPROTECTED_WRITE (FLAGS_vector_index_concurrent_writes) = 1 ;
238
267
239
- constexpr int kNumRows = 64 ;
240
- constexpr int kNumReads = 16 ;
268
+ constexpr size_t kNumRows = 64 ;
269
+ constexpr size_t kNumReads = 16 ;
241
270
242
271
auto conn = ASSERT_RESULT (MakeIndexAndFill (kNumRows ));
243
272
244
273
TestThreadHolder threads;
245
- for (int i = 1 ; i <= kNumReads ; ++i) {
274
+ for (size_t i = 1 ; i <= kNumReads ; ++i) {
246
275
threads.AddThreadFunctor ([this , &stop_flag = threads.stop_flag ()] {
247
276
auto conn = ASSERT_RESULT (Connect ());
248
277
while (!stop_flag.load ()) {
249
- auto id = RandomUniformInt (1 , kNumRows );
278
+ auto id = RandomUniformInt< size_t > (1 , kNumRows );
250
279
auto vector = VectorAsString (id);
251
280
auto rows = ASSERT_RESULT (conn.FetchAllAsString (Format (
252
281
" SELECT * FROM test ORDER BY embedding <-> '$0' LIMIT 1" , vector)));
@@ -259,8 +288,8 @@ TEST_P(PgVectorIndexTest, ManyReads) {
259
288
}
260
289
261
290
void PgVectorIndexTest::TestRestart (tablet::FlushFlags flush_flags) {
262
- constexpr int kNumRows = 64 ;
263
- constexpr int kQueryLimit = 5 ;
291
+ constexpr size_t kNumRows = 64 ;
292
+ constexpr size_t kQueryLimit = 5 ;
264
293
265
294
auto conn = ASSERT_RESULT (MakeIndexAndFill (kNumRows ));
266
295
ASSERT_NO_FATALS (VerifyRead (conn, kQueryLimit , false ));
@@ -284,7 +313,7 @@ TEST_P(PgVectorIndexTest, BootstrapFlushedIntentsDB) {
284
313
}
285
314
286
315
TEST_P (PgVectorIndexTest, DeleteAndUpdate) {
287
- constexpr int kNumRows = 64 ;
316
+ constexpr size_t kNumRows = 64 ;
288
317
const std::string kDistantVector = " [100, 500, 9000]" ;
289
318
const std::string kCloseVector = " [0.125, 0.25, 0.375]" ;
290
319
@@ -309,12 +338,12 @@ TEST_P(PgVectorIndexTest, DeleteAndUpdate) {
309
338
}
310
339
311
340
TEST_P (PgVectorIndexTest, RemoteBootstrap) {
312
- constexpr int kNumRows = 64 ;
313
- constexpr int kQueryLimit = 5 ;
341
+ constexpr size_t kNumRows = 64 ;
342
+ constexpr size_t kQueryLimit = 5 ;
314
343
315
344
auto * mts = cluster_->mini_tablet_server (2 );
316
345
mts->Shutdown ();
317
- auto conn = ASSERT_RESULT (MakeIndexAndFill (kNumRows , 3 ));
346
+ auto conn = ASSERT_RESULT (MakeIndexAndFill (kNumRows ));
318
347
const auto table_id = ASSERT_RESULT (GetTableIDFromTableName (" test" ));
319
348
ASSERT_OK (cluster_->FlushTablets ());
320
349
for (const auto & peer : ListTableActiveTabletPeers (cluster_.get (), table_id)) {
@@ -355,8 +384,8 @@ TEST_P(PgVectorIndexTest, RemoteBootstrap) {
355
384
}
356
385
357
386
TEST_P (PgVectorIndexTest, SnapshotSchedule) {
358
- constexpr int kNumRows = 128 ;
359
- constexpr int kQueryLimit = 5 ;
387
+ constexpr size_t kNumRows = 128 ;
388
+ constexpr size_t kQueryLimit = 5 ;
360
389
361
390
client::SnapshotTestUtil snapshot_util;
362
391
snapshot_util.SetProxy (&client_->proxy_cache ());
@@ -383,6 +412,58 @@ TEST_P(PgVectorIndexTest, SnapshotSchedule) {
383
412
ASSERT_NO_FATALS (VerifyRead (conn, kQueryLimit , false ));
384
413
}
385
414
415
+ TEST_P (PgVectorIndexTest, Random) {
416
+ constexpr size_t kLimit = 10 ;
417
+ constexpr size_t kDimensions = 64 ;
418
+ constexpr size_t kNumRows = RegularBuildVsDebugVsSanitizers (10000 , 1000 , 100 );
419
+ constexpr int kNumIterations = RegularBuildVsDebugVsSanitizers (100 , 20 , 10 );
420
+
421
+ unum::usearch::metric_punned_t metric (
422
+ kDimensions , unum::usearch::metric_kind_t ::l2sq_k, unum::usearch::scalar_kind_t ::f32_k);
423
+
424
+ auto conn = ASSERT_RESULT (MakeIndexAndFillRandom (kNumRows , kDimensions ));
425
+ size_t sum_missing = 0 ;
426
+ std::vector<size_t > counts;
427
+ for (int i = 0 ; i != kNumIterations ; ++i) {
428
+ auto query_vector = RandomVector (kDimensions );
429
+ auto rows = ASSERT_RESULT (conn.FetchRows <int64_t >(Format (
430
+ " SELECT id FROM test ORDER BY embedding <-> '$0' LIMIT $1" , query_vector, kLimit )));
431
+ std::vector<int64_t > expected (vectors_.size ());
432
+ std::generate (expected.begin (), expected.end (), [n{0LL }]() mutable { return n++; });
433
+ std::sort (
434
+ expected.begin (), expected.end (),
435
+ [&metric, &query_vector, &vectors = vectors_](size_t li, size_t ri) {
436
+ const auto & lhs = vectors[li];
437
+ const auto & rhs = vectors[ri];
438
+ return metric (VectorToBytePtr (query_vector), VectorToBytePtr (lhs)) <
439
+ metric (VectorToBytePtr (query_vector), VectorToBytePtr (rhs));
440
+ });
441
+ size_t ep = 0 ;
442
+ for (int64_t id : rows) {
443
+ while (ep < expected.size () && id != expected[ep]) {
444
+ ++ep;
445
+ }
446
+ ASSERT_LT (ep, expected.size ());
447
+ ASSERT_EQ (id, expected[ep]);
448
+ ++ep;
449
+ }
450
+ size_t missing = ep - kLimit ;
451
+ if (missing > counts.size ()) {
452
+ LOG (INFO)
453
+ << " New max: " << missing << " , fetched: " << AsString (rows) << " , expected: "
454
+ << AsString (boost::make_iterator_range (
455
+ expected.begin (), expected.begin () + kLimit + missing));
456
+ }
457
+ counts.resize (std::max (counts.size (), missing + 1 ));
458
+ ++counts[missing];
459
+ sum_missing += missing;
460
+ }
461
+ LOG (INFO)
462
+ << " Counts: " << AsString (counts)
463
+ << " , recall: " << 1.0 - sum_missing * 1.0 / (kLimit * kNumIterations );
464
+ ASSERT_LE (sum_missing * 50 , kLimit * kNumIterations );
465
+ }
466
+
386
467
std::string ColocatedToString (const testing::TestParamInfo<bool >& param_info) {
387
468
return param_info.param ? " Colocated" : " Distributed" ;
388
469
}
0 commit comments