Skip to content

Commit dfbb945

Browse files
committed
Vector type insert/update filter
1 parent 3e6a2d4 commit dfbb945

File tree

3 files changed

+154
-2
lines changed

3 files changed

+154
-2
lines changed

builder_insert.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ func (b *InsertBuilder) Set(k string, v interface{}) *InsertBuilder {
6262
case time.Time:
6363
ts := v.(time.Time).Format("2006-01-02 15:04:05")
6464
v = ts
65+
case Vector:
66+
// Vector 类型:不做处理,保持原样
67+
// 让 database/sql 调用 driver.Valuer 接口
68+
// Vector.Value() 会返回正确的数据库格式
6569
case interface{}:
6670
bytes, _ := json.Marshal(v)
6771
v = string(bytes)

builder_update.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ func (ub *UpdateBuilder) Set(k string, v interface{}) *UpdateBuilder {
6565
case time.Time:
6666
ts := v.(time.Time).Format("2006-01-02 15:04:05")
6767
v = ts
68+
case Vector:
69+
// Vector 类型:不做处理,保持原样
70+
// 让 database/sql 调用 driver.Valuer 接口
71+
// Vector.Value() 会返回正确的数据库格式
6872
case interface{}:
6973
bytes, _ := json.Marshal(v)
7074
v = string(bytes)

vector_test.go

Lines changed: 146 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,24 @@ func TestVectorSearch_Basic(t *testing.T) {
4444
Build().
4545
SqlOfVectorSearch()
4646

47+
t.Logf("=== SELECT 向量检索测试 ===")
48+
t.Logf("SQL: %s", sql)
49+
t.Logf("Args count: %d", len(args))
50+
if len(args) > 0 {
51+
t.Logf("Args[0] type: %T", args[0])
52+
t.Logf("Args[0] value: %v", args[0])
53+
54+
// 检查查询参数类型
55+
switch args[0].(type) {
56+
case Vector:
57+
t.Logf("✅ 查询参数是 Vector 类型,driver.Valuer 会被调用")
58+
case string:
59+
t.Logf("⚠️ 查询参数是 string 类型")
60+
default:
61+
t.Logf("❓ 查询参数是未知类型: %T", args[0])
62+
}
63+
}
64+
4765
expectedSQL := "SELECT *, embedding <-> ? AS distance FROM code_vectors ORDER BY distance LIMIT 10"
4866

4967
if sql != expectedSQL {
@@ -66,6 +84,9 @@ func TestVectorSearch_WithScalarFilter(t *testing.T) {
6684
Build().
6785
SqlOfVectorSearch()
6886

87+
t.Logf("SQL: %s", sql)
88+
t.Logf("Args: %d", len(args))
89+
6990
// SQL 应该包含 WHERE 条件
7091
if !containsString(sql, "WHERE") {
7192
t.Errorf("Expected WHERE clause in SQL: %s", sql)
@@ -93,12 +114,16 @@ func TestVectorSearch_WithScalarFilter(t *testing.T) {
93114
func TestVectorSearch_L2Distance(t *testing.T) {
94115
queryVector := Vector{0.1, 0.2, 0.3}
95116

96-
sql, _ := Of(&CodeVector{}).
117+
sql, args := Of(&CodeVector{}).
97118
VectorSearch("embedding", queryVector, 10).
98119
VectorDistance(L2Distance).
99120
Build().
100121
SqlOfVectorSearch()
101122

123+
t.Logf("SQL: %s", sql)
124+
t.Logf("Distance Metric: L2Distance (<#>)")
125+
t.Logf("Args: %d", len(args))
126+
102127
// SQL 应该使用 <#> 运算符
103128
if !containsString(sql, "<#>") {
104129
t.Errorf("Expected <#> (L2 distance) in SQL: %s", sql)
@@ -115,6 +140,10 @@ func TestVectorDistanceFilter(t *testing.T) {
115140
Build().
116141
SqlOfVectorSearch()
117142

143+
t.Logf("SQL: %s", sql)
144+
t.Logf("Threshold: < 0.3")
145+
t.Logf("Args: %d", len(args))
146+
118147
// SQL 应该包含距离过滤条件
119148
if !containsString(sql, "<-> ?") {
120149
t.Errorf("Expected distance filter in SQL: %s", sql)
@@ -143,6 +172,10 @@ func TestVectorSearch_AutoIgnoreNil(t *testing.T) {
143172
Build().
144173
SqlOfVectorSearch()
145174

175+
t.Logf("SQL: %s", sql)
176+
t.Logf("Note: Empty language auto-ignored")
177+
t.Logf("Args: %d", len(args))
178+
146179
// SQL 不应该包含 language(因为是空字符串)
147180
// 但应该包含 layer
148181
if containsString(sql, "language") {
@@ -166,12 +199,14 @@ func TestVector_Distance(t *testing.T) {
166199

167200
// 余弦距离
168201
cosDist := vec1.Distance(vec2, CosineDistance)
202+
t.Logf("Cosine Distance: %.4f", cosDist)
169203
if cosDist != 1.0 {
170204
t.Errorf("Expected cosine distance 1.0, got %f", cosDist)
171205
}
172206

173207
// L2 距离
174208
l2Dist := vec1.Distance(vec2, L2Distance)
209+
t.Logf("L2 Distance: %.4f", l2Dist)
175210
expected := float32(1.414213) // sqrt(2)
176211
if abs(l2Dist-expected) > 0.001 {
177212
t.Errorf("Expected L2 distance ~1.414, got %f", l2Dist)
@@ -183,6 +218,9 @@ func TestVector_Normalize(t *testing.T) {
183218
vec := Vector{3.0, 4.0} // 长度为 5
184219
normalized := vec.Normalize()
185220

221+
t.Logf("Original: %v", vec)
222+
t.Logf("Normalized: %v", normalized)
223+
186224
// 归一化后长度应该为 1
187225
expected := Vector{0.6, 0.8}
188226

@@ -195,6 +233,113 @@ func TestVector_Normalize(t *testing.T) {
195233
}
196234
}
197235

236+
// 测试向量插入
237+
func TestVector_Insert(t *testing.T) {
238+
code := &CodeVector{
239+
Content: "func main() { fmt.Println(\"Hello\") }",
240+
Embedding: Vector{0.1, 0.2, 0.3, 0.4},
241+
Language: "golang",
242+
Layer: "main",
243+
}
244+
245+
sql, args := Of(code).
246+
Insert(func(ib *InsertBuilder) {
247+
ib.Set("content", code.Content).
248+
Set("embedding", code.Embedding).
249+
Set("language", code.Language).
250+
Set("layer", code.Layer)
251+
}).
252+
Build().
253+
SqlOfInsert()
254+
255+
t.Logf("=== INSERT 测试 ===")
256+
t.Logf("SQL: %s", sql)
257+
t.Logf("Args count: %d", len(args))
258+
for i, arg := range args {
259+
t.Logf("Args[%d] type: %T", i, arg)
260+
t.Logf("Args[%d] value: %v", i, arg)
261+
}
262+
263+
// 验证 SQL 包含所有字段
264+
if !containsString(sql, "content") {
265+
t.Errorf("Expected content field in SQL: %s", sql)
266+
}
267+
if !containsString(sql, "embedding") {
268+
t.Errorf("Expected embedding field in SQL: %s", sql)
269+
}
270+
if !containsString(sql, "language") {
271+
t.Errorf("Expected language field in SQL: %s", sql)
272+
}
273+
}
274+
275+
// 测试向量更新
276+
func TestVector_Update(t *testing.T) {
277+
newEmbedding := Vector{0.5, 0.6, 0.7, 0.8}
278+
279+
sql, args := Of(&CodeVector{}).
280+
Update(func(ub *UpdateBuilder) {
281+
ub.Set("embedding", newEmbedding).
282+
Set("language", "golang")
283+
}).
284+
Eq("id", 123).
285+
Build().
286+
SqlOfUpdate()
287+
288+
t.Logf("=== UPDATE 测试 ===")
289+
t.Logf("SQL: %s", sql)
290+
t.Logf("Args count: %d", len(args))
291+
for i, arg := range args {
292+
t.Logf("Args[%d] type: %T", i, arg)
293+
t.Logf("Args[%d] value: %v", i, arg)
294+
}
295+
296+
// 验证 SQL
297+
if !containsString(sql, "UPDATE") {
298+
t.Errorf("Expected UPDATE in SQL: %s", sql)
299+
}
300+
if !containsString(sql, "embedding") {
301+
t.Errorf("Expected embedding field in SQL: %s", sql)
302+
}
303+
if !containsString(sql, "WHERE") {
304+
t.Errorf("Expected WHERE clause in SQL: %s", sql)
305+
}
306+
}
307+
308+
// 测试向量类型在 Set() 中的处理
309+
func TestVector_SetBehavior(t *testing.T) {
310+
vec := Vector{1.0, 2.0, 3.0}
311+
312+
// 测试 InsertBuilder.Set()
313+
sql, args := Of(&CodeVector{}).
314+
Insert(func(ib *InsertBuilder) {
315+
ib.Set("embedding", vec)
316+
}).
317+
Build().
318+
SqlOfInsert()
319+
320+
t.Logf("=== Vector Set() 行为测试 ===")
321+
t.Logf("原始 Vector: %v (类型: %T)", vec, vec)
322+
t.Logf("SQL: %s", sql)
323+
324+
if len(args) > 0 {
325+
t.Logf("Set() 后 args[0] 类型: %T", args[0])
326+
t.Logf("Set() 后 args[0] 值: %v", args[0])
327+
328+
// 关键检查:args[0] 是 Vector 还是 string?
329+
switch args[0].(type) {
330+
case Vector:
331+
t.Logf("✅ args[0] 是 Vector 类型,driver.Valuer 会被调用")
332+
case string:
333+
t.Logf("⚠️ args[0] 是 string 类型,已被 JSON Marshal")
334+
t.Logf("⚠️ driver.Valuer 不会被调用")
335+
case []float32:
336+
t.Logf("✅ args[0] 是 []float32 类型")
337+
default:
338+
t.Logf("❓ args[0] 是未知类型: %T", args[0])
339+
}
340+
}
341+
}
342+
198343
// 辅助函数
199344
func containsString(s, substr string) bool {
200345
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr))
@@ -215,4 +360,3 @@ func abs(x float32) float32 {
215360
}
216361
return x
217362
}
218-

0 commit comments

Comments
 (0)