|
59 | 59 | RagFile, |
60 | 60 | RagManagedDb, |
61 | 61 | RagManagedDbConfig, |
| 62 | + RagManagedVertexVectorSearch, |
62 | 63 | RagVectorDbConfig, |
63 | 64 | Basic, |
64 | 65 | Enterprise, |
@@ -176,6 +177,15 @@ def _check_vertex_vector_search(gapic_vector_db: GapicRagVectorDbConfig) -> bool |
176 | 177 | return gapic_vector_db.vertex_vector_search.ByteSize() > 0 |
177 | 178 |
|
178 | 179 |
|
| 180 | +def _check_rag_managed_vertex_vector_search( |
| 181 | + gapic_vector_db: GapicRagVectorDbConfig, |
| 182 | +) -> bool: |
| 183 | + try: |
| 184 | + return gapic_vector_db.__contains__("rag_managed_vertex_vector_search") |
| 185 | + except AttributeError: |
| 186 | + return gapic_vector_db.rag_managed_vertex_vector_search.ByteSize() > 0 |
| 187 | + |
| 188 | + |
179 | 189 | def _check_rag_embedding_model_config( |
180 | 190 | gapic_vector_db: GapicRagVectorDbConfig, |
181 | 191 | ) -> bool: |
@@ -240,6 +250,10 @@ def convert_gapic_to_vector_db( |
240 | 250 | index_name=gapic_vector_db.pinecone.index_name, |
241 | 251 | api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version, |
242 | 252 | ) |
| 253 | + elif _check_rag_managed_vertex_vector_search(gapic_vector_db): |
| 254 | + return RagManagedVertexVectorSearch( |
| 255 | + collection_name=gapic_vector_db.rag_managed_vertex_vector_search.collection_name, |
| 256 | + ) |
243 | 257 | elif _check_vertex_vector_search(gapic_vector_db): |
244 | 258 | return VertexVectorSearch( |
245 | 259 | index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint, |
@@ -299,6 +313,10 @@ def convert_gapic_to_backend_config( |
299 | 313 | index_name=gapic_vector_db.pinecone.index_name, |
300 | 314 | api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version, |
301 | 315 | ) |
| 316 | + elif _check_rag_managed_vertex_vector_search(gapic_vector_db): |
| 317 | + vector_config.vector_db = RagManagedVertexVectorSearch( |
| 318 | + collection_name=gapic_vector_db.rag_managed_vertex_vector_search.collection_name, |
| 319 | + ) |
302 | 320 | elif _check_vertex_vector_search(gapic_vector_db): |
303 | 321 | vector_config.vector_db = VertexVectorSearch( |
304 | 322 | index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint, |
@@ -904,9 +922,14 @@ def set_vector_db( |
904 | 922 | ), |
905 | 923 | ), |
906 | 924 | ) |
| 925 | + elif isinstance(vector_db, RagManagedVertexVectorSearch): |
| 926 | + rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig( |
| 927 | + rag_managed_vertex_vector_search=GapicRagVectorDbConfig.RagManagedVertexVectorSearch(), |
| 928 | + ) |
| 929 | + |
907 | 930 | else: |
908 | 931 | raise TypeError( |
909 | | - "vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone." |
| 932 | + "vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, Pinecone, or RagManagedVertexVectorSearch." |
910 | 933 | ) |
911 | 934 |
|
912 | 935 |
|
@@ -973,10 +996,14 @@ def set_backend_config( |
973 | 996 | rag_corpus.vector_db_config.api_auth.api_key_config.api_key_secret_version = ( |
974 | 997 | api_key |
975 | 998 | ) |
| 999 | + elif isinstance(vector_config, RagManagedVertexVectorSearch): |
| 1000 | + rag_corpus.vector_db_config.rag_managed_vertex_vector_search.CopyFrom( |
| 1001 | + GapicRagVectorDbConfig.RagManagedVertexVectorSearch() |
| 1002 | + ) |
976 | 1003 | else: |
977 | 1004 | raise TypeError( |
978 | 1005 | "backend_config must be a VertexFeatureStore," |
979 | | - "RagManagedDb, or Pinecone." |
| 1006 | + "RagManagedDb, Pinecone, or RagManagedVertexVectorSearch." |
980 | 1007 | ) |
981 | 1008 | if backend_config.rag_embedding_model_config: |
982 | 1009 | set_embedding_model_config( |
|
0 commit comments