diff --git a/src/backend/base/langflow/components/vectorstores/astradb.py b/src/backend/base/langflow/components/vectorstores/astradb.py index 4c124cefd808..a28178a9cfd4 100644 --- a/src/backend/base/langflow/components/vectorstores/astradb.py +++ b/src/backend/base/langflow/components/vectorstores/astradb.py @@ -1,4 +1,3 @@ -import os from collections import defaultdict from dataclasses import asdict, dataclass, field @@ -286,23 +285,21 @@ def create_collection_api( embedding_generation_provider: str | None = None, embedding_generation_model: str | None = None, ): - # Create the data API client + # Initialize the data API client once client = DataAPIClient(token=token) - # Get the database object - database = client.get_database(api_endpoint=api_endpoint, token=token) + # Get the database directly using the initialized client + database = client.get_database(api_endpoint=api_endpoint) - # Build vectorize options, if needed + # Set vectorize options only if dimension is not specified vectorize_options = None - if not dimension: + if dimension is None: vectorize_options = CollectionVectorServiceOptions( provider=embedding_generation_provider, model_name=embedding_generation_model, - authentication=None, - parameters=None, ) - # Create the collection + # Return the created collection return database.create_collection( name=new_collection_name, dimension=dimension, @@ -375,7 +372,7 @@ def get_api_endpoint(self): token=self.token, environment=self.environment, api_endpoint=self.api_endpoint, - database_name=self.database_name + database_name=self.database_name, ) def get_keyspace(self): @@ -493,8 +490,7 @@ def reset_collection_list(self, build_config: dict): ] # Reset the selected collection - if build_config["collection_name"]["value"] not in build_config["collection_name"]["options"]: - build_config["collection_name"]["value"] = "" + build_config["collection_name"]["value"] = "" return build_config @@ -509,9 +505,8 @@ def reset_database_list(self, build_config: dict): ] # Reset the selected database - if build_config["database_name"]["value"] not in build_config["database_name"]["options"]: - build_config["database_name"]["value"] = "" - build_config["api_endpoint"]["value"] = "" + build_config["database_name"]["value"] = "" + build_config["api_endpoint"]["value"] = "" return build_config @@ -571,6 +566,7 @@ def update_build_config(self, build_config: dict, field_value: str, field_name: # If this is the first execution of the component, reset and build database list if first_run or field_name in ["token", "environment"]: # Reset the build config to ensure we are starting fresh + build_config = self.reset_build_config(build_config) build_config = self.reset_database_list(build_config) # Get list of regions for a given cloud provider @@ -626,9 +622,9 @@ def update_build_config(self, build_config: dict, field_value: str, field_name: "embedding_generation_model" ]["options"] if not model_options: - embedding_provider = build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][ - "embedding_generation_provider" - ]["value"] + embedding_provider = build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"][ + "template" + ]["embedding_generation_provider"]["value"] build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][ "embedding_generation_model"