@@ -26,7 +26,7 @@ pub struct EmbeddingParameters {
26
26
}
27
27
28
28
#[ tracing:: instrument]
29
- pub async fn create_embedding (
29
+ pub async fn get_dense_vector (
30
30
message : String ,
31
31
distance_phrase : Option < DistancePhrase > ,
32
32
embed_type : & str ,
@@ -35,11 +35,11 @@ pub async fn create_embedding(
35
35
let parent_span = sentry:: configure_scope ( |scope| scope. get_span ( ) ) ;
36
36
let transaction: sentry:: TransactionOrSpan = match & parent_span {
37
37
Some ( parent) => parent
38
- . start_child ( "create_embedding " , "Create semantic dense embedding" )
38
+ . start_child ( "get_dense_vector " , "Create semantic dense embedding" )
39
39
. into ( ) ,
40
40
None => {
41
41
let ctx = sentry:: TransactionContext :: new (
42
- "create_embedding " ,
42
+ "get_dense_vector " ,
43
43
"Create semantic dense embedding" ,
44
44
) ;
45
45
sentry:: start_transaction ( ctx) . into ( )
@@ -266,20 +266,20 @@ pub async fn get_sparse_vector(
266
266
}
267
267
268
268
#[ tracing:: instrument]
269
- pub async fn create_embeddings (
270
- content_and_boosts : Vec < ( String , Option < DistancePhrase > ) > ,
269
+ pub async fn get_dense_vectors (
270
+ content_and_distances : Vec < ( String , Option < DistancePhrase > ) > ,
271
271
embed_type : & str ,
272
272
dataset_config : DatasetConfiguration ,
273
273
reqwest_client : reqwest:: Client ,
274
274
) -> Result < Vec < Vec < f32 > > , ServiceError > {
275
275
let parent_span = sentry:: configure_scope ( |scope| scope. get_span ( ) ) ;
276
276
let transaction: sentry:: TransactionOrSpan = match & parent_span {
277
277
Some ( parent) => parent
278
- . start_child ( "create_embedding " , "Create semantic dense embedding" )
278
+ . start_child ( "get_dense_vector " , "Create semantic dense embedding" )
279
279
. into ( ) ,
280
280
None => {
281
281
let ctx = sentry:: TransactionContext :: new (
282
- "create_embedding " ,
282
+ "get_dense_vector " ,
283
283
"Create semantic dense embedding" ,
284
284
) ;
285
285
sentry:: start_transaction ( ctx) . into ( )
@@ -323,31 +323,31 @@ pub async fn create_embeddings(
323
323
embedding_api_key. to_string ( )
324
324
} ;
325
325
326
- let thirty_message_groups = content_and_boosts. chunks ( 30 ) ;
326
+ let ( contents, distance_phrases) : ( Vec < _ > , Vec < _ > ) =
327
+ content_and_distances. clone ( ) . into_iter ( ) . unzip ( ) ;
328
+ let thirty_content_groups = contents. chunks ( 30 ) ;
327
329
328
- let vec_futures: Vec < _ > = thirty_message_groups
330
+ let filtered_distances_with_index = distance_phrases
331
+ . clone ( )
332
+ . iter ( )
329
333
. enumerate ( )
330
- . map ( |( i, combined_messages) | {
331
- let messages = combined_messages
332
- . iter ( )
333
- . map ( |( x, _) | x)
334
- . cloned ( )
335
- . collect :: < Vec < String > > ( ) ;
336
-
337
- let boost_phrase_and_index = combined_messages
338
- . iter ( )
339
- . enumerate ( )
340
- . filter_map ( |( i, ( _, y) ) | y. clone ( ) . map ( |phrase| ( i, phrase) ) )
341
- . collect :: < Vec < ( usize , DistancePhrase ) > > ( ) ;
334
+ . filter_map ( |( index, distance_phrase) | {
335
+ distance_phrase
336
+ . clone ( )
337
+ . map ( |distance_phrase| ( index, distance_phrase) )
338
+ } )
339
+ . collect :: < Vec < ( usize , DistancePhrase ) > > ( ) ;
340
+ let thirty_filterted_distances_with_indices = filtered_distances_with_index. chunks ( 30 ) ;
342
341
343
- let boost_phrases = combined_messages
342
+ let vec_distance_futures: Vec < _ > = thirty_filterted_distances_with_indices
343
+ . map ( |thirty_distances| {
344
+ let distance_phrases = thirty_distances
344
345
. iter ( )
345
- . filter_map ( |( _, y ) | y . clone ( ) . map ( |x| x . phrase . clone ( ) ) )
346
+ . map ( |( _, x ) | x . phrase . clone ( ) )
346
347
. collect :: < Vec < String > > ( ) ;
347
348
348
- let clipped_messages = messages
349
+ let clipped_messages = distance_phrases
349
350
. iter ( )
350
- . chain ( boost_phrases. iter ( ) )
351
351
. map ( |message| {
352
352
if message. len ( ) > 5000 {
353
353
message. chars ( ) . take ( 12000 ) . collect ( )
@@ -406,65 +406,162 @@ pub async fn create_embeddings(
406
406
)
407
407
} ) ?;
408
408
409
- let mut vectors: Vec < Vec < f32 > > = embeddings
410
- . data
411
- . into_iter ( )
412
- . map ( |x| match x. embedding {
413
- EmbeddingOutput :: Float ( v) => v. iter ( ) . map ( |x| * x as f32 ) . collect ( ) ,
414
- EmbeddingOutput :: Base64 ( _) => {
415
- log:: error!( "Embedding server responded with Base64 and that is not currently supported for embeddings" ) ;
416
- vec ! [ ]
409
+ let vectors_and_boosts: Vec < ( Vec < f32 > , & ( usize , DistancePhrase ) ) > = embeddings
410
+ . data
411
+ . into_iter ( )
412
+ . map ( |x| match x. embedding {
413
+ EmbeddingOutput :: Float ( v) => v. iter ( ) . map ( |x| * x as f32 ) . collect ( ) ,
414
+ EmbeddingOutput :: Base64 ( _) => {
415
+ log:: error!( "Embedding server responded with Base64 and that is not currently supported for embeddings" ) ;
416
+ vec ! [ ]
417
+ }
418
+ } )
419
+ . zip ( thirty_distances)
420
+ . collect ( ) ;
421
+
422
+ if vectors_and_boosts. iter ( ) . any ( |x| x. 0 . is_empty ( ) ) {
423
+ return Err ( ServiceError :: InternalServerError (
424
+ "Embedding server responded with Base64 and that is not currently supported for embeddings" . to_owned ( ) ,
425
+ ) ) ;
417
426
}
418
- } )
419
- . collect ( ) ;
420
427
421
- if vectors. iter ( ) . any ( |x| x. is_empty ( ) ) {
422
- return Err ( ServiceError :: InternalServerError (
423
- "Embedding server responded with Base64 and that is not currently supported for embeddings" . to_owned ( ) ,
424
- ) ) ;
425
- }
428
+ Ok ( vectors_and_boosts)
429
+ } ;
426
430
427
- if !boost_phrase_and_index . is_empty ( ) {
428
- let boost_vectors = vectors
429
- . split_off ( messages . len ( ) ) . to_vec ( ) ;
431
+ vectors_resp
432
+ } )
433
+ . collect ( ) ;
430
434
431
- let mut vectors_sorted = vectors. clone ( ) ;
432
- for ( ( og_index, phrase) , boost_vector) in boost_phrase_and_index. iter ( ) . zip ( boost_vectors) {
433
- vectors_sorted[ * og_index] = vectors_sorted[ * og_index]
434
- . iter ( )
435
- . zip ( boost_vector)
436
- . map ( |( vector_elem, boost_vec_elem) | vector_elem + phrase. distance_factor * boost_vec_elem)
437
- . collect ( ) ;
438
- }
435
+ let vec_content_futures: Vec < _ > = thirty_content_groups
436
+ . map ( |messages| {
437
+ let clipped_messages = messages
438
+ . iter ( )
439
+ . map ( |message| {
440
+ if message. len ( ) > 5000 {
441
+ message. chars ( ) . take ( 12000 ) . collect ( )
442
+ } else {
443
+ message. clone ( )
444
+ }
445
+ } )
446
+ . collect :: < Vec < String > > ( ) ;
439
447
440
- return Ok ( ( i, vectors_sorted) ) ;
441
- }
448
+ let input = match embed_type {
449
+ "doc" => EmbeddingInput :: StringArray ( clipped_messages) ,
450
+ "query" => EmbeddingInput :: String (
451
+ format ! (
452
+ "{}{}" ,
453
+ dataset_config. EMBEDDING_QUERY_PREFIX , & clipped_messages[ 0 ]
454
+ )
455
+ . to_string ( ) ,
456
+ ) ,
457
+ _ => EmbeddingInput :: StringArray ( clipped_messages) ,
458
+ } ;
442
459
443
- Ok ( ( i, vectors) )
460
+ let parameters = EmbeddingParameters {
461
+ model : dataset_config. EMBEDDING_MODEL_NAME . to_string ( ) ,
462
+ input,
463
+ truncate : true ,
444
464
} ;
445
465
446
- vectors_resp
447
- } )
466
+ let cur_client = reqwest_client. clone ( ) ;
467
+ let url = embedding_base_url. clone ( ) ;
468
+
469
+ let embedding_api_key = embedding_api_key. clone ( ) ;
470
+
471
+ let vectors_resp = async move {
472
+ let embeddings_resp = cur_client
473
+ . post ( & format ! ( "{}/embeddings?api-version=2023-05-15" , url) )
474
+ . header ( "Authorization" , & format ! ( "Bearer {}" , & embedding_api_key. clone( ) ) )
475
+ . header ( "api-key" , & embedding_api_key. clone ( ) )
476
+ . header ( "Content-Type" , "application/json" )
477
+ . json ( & parameters)
478
+ . send ( )
479
+ . await
480
+ . map_err ( |_| {
481
+ ServiceError :: BadRequest ( "Failed to send message to embedding server" . to_string ( ) )
482
+ } ) ?
483
+ . text ( )
484
+ . await
485
+ . map_err ( |_| {
486
+ ServiceError :: BadRequest ( "Failed to get text from embeddings" . to_string ( ) )
487
+ } ) ?;
488
+
489
+ let embeddings: EmbeddingResponse = format_response ( embeddings_resp. clone ( ) )
490
+ . map_err ( move |_e| {
491
+ log:: error!( "Failed to format response from embeddings server {:?}" , embeddings_resp) ;
492
+ ServiceError :: InternalServerError (
493
+ format ! ( "Failed to format response from embeddings server {:?}" , embeddings_resp)
494
+ )
495
+ } ) ?;
496
+
497
+ let vectors: Vec < Vec < f32 > > = embeddings
498
+ . data
499
+ . into_iter ( )
500
+ . map ( |x| match x. embedding {
501
+ EmbeddingOutput :: Float ( v) => v. iter ( ) . map ( |x| * x as f32 ) . collect ( ) ,
502
+ EmbeddingOutput :: Base64 ( _) => {
503
+ log:: error!( "Embedding server responded with Base64 and that is not currently supported for embeddings" ) ;
504
+ vec ! [ ]
505
+ }
506
+ } )
507
+ . collect ( ) ;
508
+
509
+ if vectors. iter ( ) . any ( |x| x. is_empty ( ) ) {
510
+ return Err ( ServiceError :: InternalServerError (
511
+ "Embedding server responded with Base64 and that is not currently supported for embeddings" . to_owned ( ) ,
512
+ ) ) ;
513
+ }
514
+ Ok ( vectors)
515
+ } ;
516
+
517
+ vectors_resp
518
+
519
+ } )
448
520
. collect ( ) ;
449
521
450
- let all_chunk_vectors : Vec < ( usize , Vec < Vec < f32 > > ) > = futures:: future:: join_all ( vec_futures )
522
+ let mut content_vectors : Vec < _ > = futures:: future:: join_all ( vec_content_futures )
451
523
. await
452
524
. into_iter ( )
453
- . collect :: < Result < Vec < ( usize , Vec < Vec < f32 > > ) > , ServiceError > > ( ) ?;
525
+ . collect :: < Result < Vec < _ > , ServiceError > > ( ) ?
526
+ . into_iter ( )
527
+ . flatten ( )
528
+ . collect ( ) ;
454
529
455
- let mut vectors_sorted = vec ! [ ] ;
456
- for index in 0 ..all_chunk_vectors . len ( ) {
457
- let ( _ , vectors_i ) = all_chunk_vectors . iter ( ) . find ( | ( i , _ ) | * i == index ) . ok_or (
458
- ServiceError :: InternalServerError (
459
- "Failed to get index i (this should never happen)" . to_string ( ) ,
460
- ) ,
461
- ) ? ;
530
+ let distance_vectors : Vec < _ > = futures :: future :: join_all ( vec_distance_futures )
531
+ . await
532
+ . into_iter ( )
533
+ . collect :: < Result < Vec < _ > , ServiceError > > ( ) ?
534
+ . into_iter ( )
535
+ . flatten ( )
536
+ . collect ( ) ;
462
537
463
- vectors_sorted. extend ( vectors_i. clone ( ) ) ;
538
+ if !distance_vectors. is_empty ( ) {
539
+ content_vectors = content_vectors
540
+ . into_iter ( )
541
+ . enumerate ( )
542
+ . map ( |( i, message) | {
543
+ let distance_vector = distance_vectors
544
+ . iter ( )
545
+ . find ( |( _, ( og_index, _) ) | * og_index == i) ;
546
+ match distance_vector {
547
+ Some ( ( distance_vec, ( _, distance_phrase) ) ) => {
548
+ let distance_factor = distance_phrase. distance_factor ;
549
+ message
550
+ . iter ( )
551
+ . zip ( distance_vec)
552
+ . map ( |( vec_elem, distance_elem) | {
553
+ vec_elem + distance_factor * distance_elem
554
+ } )
555
+ . collect ( )
556
+ }
557
+ None => message,
558
+ }
559
+ } )
560
+ . collect ( ) ;
464
561
}
465
562
466
563
transaction. finish ( ) ;
467
- Ok ( vectors_sorted )
564
+ Ok ( content_vectors )
468
565
}
469
566
470
567
#[ derive( Debug , Serialize , Deserialize ) ]
@@ -492,31 +589,31 @@ pub struct CustomSparseEmbedData {
492
589
493
590
#[ tracing:: instrument]
494
591
pub async fn get_sparse_vectors (
495
- messages : Vec < ( String , Option < BoostPhrase > ) > ,
592
+ content_and_boosts : Vec < ( String , Option < BoostPhrase > ) > ,
496
593
embed_type : & str ,
497
594
reqwest_client : reqwest:: Client ,
498
595
) -> Result < Vec < Vec < ( u32 , f32 ) > > , ServiceError > {
499
- if messages . is_empty ( ) {
596
+ if content_and_boosts . is_empty ( ) {
500
597
return Err ( ServiceError :: BadRequest (
501
598
"No messages to encode" . to_string ( ) ,
502
599
) ) ;
503
600
}
504
601
505
- let contents = messages
602
+ let contents = content_and_boosts
506
603
. clone ( )
507
604
. into_iter ( )
508
605
. map ( |( x, _) | x)
509
606
. collect :: < Vec < String > > ( ) ;
510
607
let thirty_content_groups = contents. chunks ( 30 ) ;
511
608
512
- let filtered_boosts_with_index = messages
609
+ let filtered_boosts_with_index = content_and_boosts
513
610
. into_iter ( )
514
611
. enumerate ( )
515
612
. filter_map ( |( i, ( _, y) ) | y. map ( |boost_phrase| ( i, boost_phrase) ) )
516
613
. collect :: < Vec < ( usize , BoostPhrase ) > > ( ) ;
517
- let filtered_boosts_with_index_groups = filtered_boosts_with_index. chunks ( 30 ) ;
614
+ let thirty_filtered_boosts_with_indices = filtered_boosts_with_index. chunks ( 30 ) ;
518
615
519
- let vec_boost_futures: Vec < _ > = filtered_boosts_with_index_groups
616
+ let vec_boost_futures: Vec < _ > = thirty_filtered_boosts_with_indices
520
617
. enumerate ( )
521
618
. map ( |( i, thirty_boosts) | {
522
619
let cur_client = reqwest_client. clone ( ) ;
0 commit comments