@@ -94,6 +94,7 @@ def __init__(
94
94
self ._setup_db_interceptor (kwargs .get ("db_name" ))
95
95
self ._setup_grpc_channel ()
96
96
self .callbacks = []
97
+ self .schema_cache = {}
97
98
98
99
def register_state_change_callback (self , callback : Callable ):
99
100
self .callbacks .append (callback )
@@ -161,6 +162,7 @@ def close(self):
161
162
self ._channel .close ()
162
163
163
164
def reset_db_name (self , db_name : str ):
165
+ self .schema_cache .clear ()
164
166
self ._setup_db_interceptor (db_name )
165
167
self ._setup_grpc_channel ()
166
168
self ._setup_identifier_interceptor (self ._user )
@@ -526,10 +528,28 @@ def insert_rows(
526
528
collection_name , entities , partition_name , schema , timeout , ** kwargs
527
529
)
528
530
resp = self ._stub .Insert (request = request , timeout = timeout )
531
+ if resp .status .error_code == common_pb2 .SchemaMismatch :
532
+ schema = self .update_schema (collection_name , timeout )
533
+ request = self ._prepare_row_insert_request (
534
+ collection_name , entities , partition_name , schema , timeout , ** kwargs
535
+ )
536
+ resp = self ._stub .Insert (request = request , timeout = timeout )
529
537
check_status (resp .status )
530
538
ts_utils .update_collection_ts (collection_name , resp .timestamp )
531
539
return MutationResult (resp )
532
540
541
def update_schema(self, collection_name: str, timeout: Optional[float] = None):
    """Re-fetch the schema of ``collection_name`` and replace its cache entry.

    Args:
        collection_name: Cache key, and the name passed to ``describe_collection``.
        timeout: Forwarded to ``describe_collection`` as the ``timeout`` keyword.

    Returns:
        The schema returned by ``describe_collection``.
    """
    # The old entry is dropped before the fetch, so if describe_collection
    # raises, the previous entry is already gone from the cache.
    self.schema_cache.pop(collection_name, None)

    fresh_schema = self.describe_collection(collection_name, timeout=timeout)
    cache_entry = {
        "schema": fresh_schema,
        # Falls back to 0 when the schema carries no "update_timestamp" key.
        "schema_timestamp": fresh_schema.get("update_timestamp", 0),
    }
    self.schema_cache[collection_name] = cache_entry

    return fresh_schema
552
+
533
553
def _prepare_row_insert_request (
534
554
self ,
535
555
collection_name : str ,
@@ -542,9 +562,9 @@ def _prepare_row_insert_request(
542
562
if isinstance (entity_rows , dict ):
543
563
entity_rows = [entity_rows ]
544
564
545
- if not isinstance ( schema , dict ):
546
- schema = self . describe_collection ( collection_name , timeout = timeout )
547
-
565
+ schema , schema_timestamp = self . _get_schema_from_cache_or_remote (
566
+ collection_name , schema , timeout
567
+ )
548
568
fields_info = schema .get ("fields" )
549
569
enable_dynamic = schema .get ("enable_dynamic_field" , False )
550
570
@@ -554,8 +574,33 @@ def _prepare_row_insert_request(
554
574
partition_name ,
555
575
fields_info ,
556
576
enable_dynamic = enable_dynamic ,
577
+ schema_timestamp = schema_timestamp ,
557
578
)
558
579
580
def _get_schema_from_cache_or_remote(
    self, collection_name: str, schema: Optional[dict] = None, timeout: Optional[float] = None
):
    """Return ``(schema, schema_timestamp)`` for a collection, caching on a miss.

    A cached entry always takes precedence, even when the caller passes
    ``schema``. On a cache miss the caller's ``schema`` is used if it is a
    dict; otherwise the schema is fetched with ``describe_collection``. In
    both miss cases the result is stored in ``self.schema_cache``.

    Args:
        collection_name: Cache key, and the name passed to ``describe_collection``.
        schema: Optional schema dict, consulted only on a cache miss.
        timeout: Forwarded to ``describe_collection`` when a fetch is needed.

    Returns:
        A tuple of the schema and its ``"update_timestamp"`` value
        (0 when the schema has no such key).
    """
    # One lookup instead of a membership test followed by two subscripts:
    # the entry cannot disappear between the check and the read.
    cached = self.schema_cache.get(collection_name)
    if cached is not None:
        return cached["schema"], cached["schema_timestamp"]

    # Cache miss: only go remote when the caller did not supply a dict.
    if not isinstance(schema, dict):
        schema = self.describe_collection(collection_name, timeout=timeout)
    schema_timestamp = schema.get("update_timestamp", 0)

    self.schema_cache[collection_name] = {
        "schema": schema,
        "schema_timestamp": schema_timestamp,
    }

    return schema, schema_timestamp
603
+
559
604
def _prepare_batch_insert_request (
560
605
self ,
561
606
collection_name : str ,
@@ -723,13 +768,18 @@ def _prepare_row_upsert_request(
723
768
if not isinstance (rows , list ):
724
769
raise ParamError (message = "'rows' must be a list, please provide valid row data." )
725
770
726
- fields_info , enable_dynamic = self ._get_info (collection_name , timeout , ** kwargs )
771
+ schema , schema_timestamp = self ._get_schema_from_cache_or_remote (
772
+ collection_name , timeout = timeout
773
+ )
774
+ fields_info = schema .get ("fields" )
775
+ enable_dynamic = schema .get ("enable_dynamic_field" , False )
727
776
return Prepare .row_upsert_param (
728
777
collection_name ,
729
778
rows ,
730
779
partition_name ,
731
780
fields_info ,
732
781
enable_dynamic = enable_dynamic ,
782
+ schema_timestamp = schema_timestamp ,
733
783
)
734
784
735
785
@retry_on_rpc_failure ()
@@ -748,6 +798,12 @@ def upsert_rows(
748
798
)
749
799
rf = self ._stub .Upsert .future (request , timeout = timeout )
750
800
response = rf .result ()
801
+ if response .status .error_code == common_pb2 .SchemaMismatch :
802
+ schema = self .update_schema (collection_name , timeout )
803
+ request = self ._prepare_row_insert_request (
804
+ collection_name , entities , partition_name , schema , timeout , ** kwargs
805
+ )
806
+ response = self ._stub .Insert (request = request , timeout = timeout )
751
807
check_status (response .status )
752
808
m = MutationResult (response )
753
809
ts_utils .update_collection_ts (collection_name , m .timestamp )
0 commit comments