@@ -44,6 +44,24 @@ func TestVectorSearch_Basic(t *testing.T) {
4444 Build ().
4545 SqlOfVectorSearch ()
4646
47+ t .Logf ("=== SELECT 向量检索测试 ===" )
48+ t .Logf ("SQL: %s" , sql )
49+ t .Logf ("Args count: %d" , len (args ))
50+ if len (args ) > 0 {
51+ t .Logf ("Args[0] type: %T" , args [0 ])
52+ t .Logf ("Args[0] value: %v" , args [0 ])
53+
54+ // 检查查询参数类型
55+ switch args [0 ].(type ) {
56+ case Vector :
57+ t .Logf ("✅ 查询参数是 Vector 类型,driver.Valuer 会被调用" )
58+ case string :
59+ t .Logf ("⚠️ 查询参数是 string 类型" )
60+ default :
61+ t .Logf ("❓ 查询参数是未知类型: %T" , args [0 ])
62+ }
63+ }
64+
4765 expectedSQL := "SELECT *, embedding <-> ? AS distance FROM code_vectors ORDER BY distance LIMIT 10"
4866
4967 if sql != expectedSQL {
@@ -66,6 +84,9 @@ func TestVectorSearch_WithScalarFilter(t *testing.T) {
6684 Build ().
6785 SqlOfVectorSearch ()
6886
87+ t .Logf ("SQL: %s" , sql )
88+ t .Logf ("Args: %d" , len (args ))
89+
6990 // SQL 应该包含 WHERE 条件
7091 if ! containsString (sql , "WHERE" ) {
7192 t .Errorf ("Expected WHERE clause in SQL: %s" , sql )
@@ -93,12 +114,16 @@ func TestVectorSearch_WithScalarFilter(t *testing.T) {
93114func TestVectorSearch_L2Distance (t * testing.T ) {
94115 queryVector := Vector {0.1 , 0.2 , 0.3 }
95116
96- sql , _ := Of (& CodeVector {}).
117+ sql , args := Of (& CodeVector {}).
97118 VectorSearch ("embedding" , queryVector , 10 ).
98119 VectorDistance (L2Distance ).
99120 Build ().
100121 SqlOfVectorSearch ()
101122
123+ t .Logf ("SQL: %s" , sql )
124+ t .Logf ("Distance Metric: L2Distance (<#>)" )
125+ t .Logf ("Args: %d" , len (args ))
126+
102127 // SQL 应该使用 <#> 运算符
103128 if ! containsString (sql , "<#>" ) {
104129 t .Errorf ("Expected <#> (L2 distance) in SQL: %s" , sql )
@@ -115,6 +140,10 @@ func TestVectorDistanceFilter(t *testing.T) {
115140 Build ().
116141 SqlOfVectorSearch ()
117142
143+ t .Logf ("SQL: %s" , sql )
144+ t .Logf ("Threshold: < 0.3" )
145+ t .Logf ("Args: %d" , len (args ))
146+
118147 // SQL 应该包含距离过滤条件
119148 if ! containsString (sql , "<-> ?" ) {
120149 t .Errorf ("Expected distance filter in SQL: %s" , sql )
@@ -143,6 +172,10 @@ func TestVectorSearch_AutoIgnoreNil(t *testing.T) {
143172 Build ().
144173 SqlOfVectorSearch ()
145174
175+ t .Logf ("SQL: %s" , sql )
176+ t .Logf ("Note: Empty language auto-ignored" )
177+ t .Logf ("Args: %d" , len (args ))
178+
146179 // SQL 不应该包含 language(因为是空字符串)
147180 // 但应该包含 layer
148181 if containsString (sql , "language" ) {
@@ -166,12 +199,14 @@ func TestVector_Distance(t *testing.T) {
166199
167200 // 余弦距离
168201 cosDist := vec1 .Distance (vec2 , CosineDistance )
202+ t .Logf ("Cosine Distance: %.4f" , cosDist )
169203 if cosDist != 1.0 {
170204 t .Errorf ("Expected cosine distance 1.0, got %f" , cosDist )
171205 }
172206
173207 // L2 距离
174208 l2Dist := vec1 .Distance (vec2 , L2Distance )
209+ t .Logf ("L2 Distance: %.4f" , l2Dist )
175210 expected := float32 (1.414213 ) // sqrt(2)
176211 if abs (l2Dist - expected ) > 0.001 {
177212 t .Errorf ("Expected L2 distance ~1.414, got %f" , l2Dist )
@@ -183,6 +218,9 @@ func TestVector_Normalize(t *testing.T) {
183218 vec := Vector {3.0 , 4.0 } // 长度为 5
184219 normalized := vec .Normalize ()
185220
221+ t .Logf ("Original: %v" , vec )
222+ t .Logf ("Normalized: %v" , normalized )
223+
186224 // 归一化后长度应该为 1
187225 expected := Vector {0.6 , 0.8 }
188226
@@ -195,6 +233,113 @@ func TestVector_Normalize(t *testing.T) {
195233 }
196234}
197235
236+ // 测试向量插入
237+ func TestVector_Insert (t * testing.T ) {
238+ code := & CodeVector {
239+ Content : "func main() { fmt.Println(\" Hello\" ) }" ,
240+ Embedding : Vector {0.1 , 0.2 , 0.3 , 0.4 },
241+ Language : "golang" ,
242+ Layer : "main" ,
243+ }
244+
245+ sql , args := Of (code ).
246+ Insert (func (ib * InsertBuilder ) {
247+ ib .Set ("content" , code .Content ).
248+ Set ("embedding" , code .Embedding ).
249+ Set ("language" , code .Language ).
250+ Set ("layer" , code .Layer )
251+ }).
252+ Build ().
253+ SqlOfInsert ()
254+
255+ t .Logf ("=== INSERT 测试 ===" )
256+ t .Logf ("SQL: %s" , sql )
257+ t .Logf ("Args count: %d" , len (args ))
258+ for i , arg := range args {
259+ t .Logf ("Args[%d] type: %T" , i , arg )
260+ t .Logf ("Args[%d] value: %v" , i , arg )
261+ }
262+
263+ // 验证 SQL 包含所有字段
264+ if ! containsString (sql , "content" ) {
265+ t .Errorf ("Expected content field in SQL: %s" , sql )
266+ }
267+ if ! containsString (sql , "embedding" ) {
268+ t .Errorf ("Expected embedding field in SQL: %s" , sql )
269+ }
270+ if ! containsString (sql , "language" ) {
271+ t .Errorf ("Expected language field in SQL: %s" , sql )
272+ }
273+ }
274+
275+ // 测试向量更新
276+ func TestVector_Update (t * testing.T ) {
277+ newEmbedding := Vector {0.5 , 0.6 , 0.7 , 0.8 }
278+
279+ sql , args := Of (& CodeVector {}).
280+ Update (func (ub * UpdateBuilder ) {
281+ ub .Set ("embedding" , newEmbedding ).
282+ Set ("language" , "golang" )
283+ }).
284+ Eq ("id" , 123 ).
285+ Build ().
286+ SqlOfUpdate ()
287+
288+ t .Logf ("=== UPDATE 测试 ===" )
289+ t .Logf ("SQL: %s" , sql )
290+ t .Logf ("Args count: %d" , len (args ))
291+ for i , arg := range args {
292+ t .Logf ("Args[%d] type: %T" , i , arg )
293+ t .Logf ("Args[%d] value: %v" , i , arg )
294+ }
295+
296+ // 验证 SQL
297+ if ! containsString (sql , "UPDATE" ) {
298+ t .Errorf ("Expected UPDATE in SQL: %s" , sql )
299+ }
300+ if ! containsString (sql , "embedding" ) {
301+ t .Errorf ("Expected embedding field in SQL: %s" , sql )
302+ }
303+ if ! containsString (sql , "WHERE" ) {
304+ t .Errorf ("Expected WHERE clause in SQL: %s" , sql )
305+ }
306+ }
307+
308+ // 测试向量类型在 Set() 中的处理
309+ func TestVector_SetBehavior (t * testing.T ) {
310+ vec := Vector {1.0 , 2.0 , 3.0 }
311+
312+ // 测试 InsertBuilder.Set()
313+ sql , args := Of (& CodeVector {}).
314+ Insert (func (ib * InsertBuilder ) {
315+ ib .Set ("embedding" , vec )
316+ }).
317+ Build ().
318+ SqlOfInsert ()
319+
320+ t .Logf ("=== Vector Set() 行为测试 ===" )
321+ t .Logf ("原始 Vector: %v (类型: %T)" , vec , vec )
322+ t .Logf ("SQL: %s" , sql )
323+
324+ if len (args ) > 0 {
325+ t .Logf ("Set() 后 args[0] 类型: %T" , args [0 ])
326+ t .Logf ("Set() 后 args[0] 值: %v" , args [0 ])
327+
328+ // 关键检查:args[0] 是 Vector 还是 string?
329+ switch args [0 ].(type ) {
330+ case Vector :
331+ t .Logf ("✅ args[0] 是 Vector 类型,driver.Valuer 会被调用" )
332+ case string :
333+ t .Logf ("⚠️ args[0] 是 string 类型,已被 JSON Marshal" )
334+ t .Logf ("⚠️ driver.Valuer 不会被调用" )
335+ case []float32 :
336+ t .Logf ("✅ args[0] 是 []float32 类型" )
337+ default :
338+ t .Logf ("❓ args[0] 是未知类型: %T" , args [0 ])
339+ }
340+ }
341+ }
342+
198343// 辅助函数
199344func containsString (s , substr string ) bool {
200345 return len (s ) >= len (substr ) && (s == substr || len (s ) > len (substr ) && findSubstring (s , substr ))
@@ -215,4 +360,3 @@ func abs(x float32) float32 {
215360 }
216361 return x
217362}
218-
0 commit comments