Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡️ Speed up method AstraDBVectorStoreComponent._initialize_collection_options by 24% in PR #6236 (LFOSS-492) #6635

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 34 additions & 44 deletions src/backend/base/langflow/components/vectorstores/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,42 +442,33 @@

def get_keyspace(self):
keyspace = self.keyspace

if keyspace:
return keyspace.strip()

return None

def get_database_object(self, api_endpoint: str | None = None):
if self.database_cache:
return self.database_cache
try:
client = DataAPIClient(token=self.token, environment=self.environment)

return client.get_database(
self.database_cache = client.get_database(
api_endpoint=api_endpoint or self.get_api_endpoint(),
token=self.token,
keyspace=self.get_keyspace(),
)
return self.database_cache

Check failure on line 459 in src/backend/base/langflow/components/vectorstores/astradb.py

View workflow job for this annotation

GitHub Actions / Ruff Style Check (3.12)

Ruff (TRY300)

src/backend/base/langflow/components/vectorstores/astradb.py:459:13: TRY300 Consider moving this statement to an `else` block
except Exception as e:
msg = f"Error fetching database object: {e}"
raise ValueError(msg) from e

def collection_data(self, collection_name: str, database: Database | None = None):
try:
if not database:
client = DataAPIClient(token=self.token, environment=self.environment)

database = client.get_database(
api_endpoint=self.get_api_endpoint(),
token=self.token,
keyspace=self.get_keyspace(),
)

database = self.get_database_object()
collection = database.get_collection(collection_name, keyspace=self.get_keyspace())

return collection.estimated_document_count()
except Exception as e: # noqa: BLE001
except Exception as e:

Check failure on line 470 in src/backend/base/langflow/components/vectorstores/astradb.py

View workflow job for this annotation

GitHub Actions / Ruff Style Check (3.12)

Ruff (BLE001)

src/backend/base/langflow/components/vectorstores/astradb.py:470:16: BLE001 Do not catch blind exception: `Exception`
self.log(f"Error checking collection data: {e}")

return None

def _initialize_database_options(self):
Expand All @@ -497,40 +488,12 @@
raise ValueError(msg) from e

def _initialize_collection_options(self, api_endpoint: str | None = None):
# Nothing to generate if we don't have an API endpoint yet
api_endpoint = api_endpoint or self.get_api_endpoint()
if not api_endpoint:
return []

# Retrieve the database object
database = self.get_database_object(api_endpoint=api_endpoint)

# Get the list of collections
collection_list = list(database.list_collections(keyspace=self.get_keyspace()))

# Return the list of collections and metadata associated
return [
{
"name": col.name,
"records": self.collection_data(collection_name=col.name, database=database),
"provider": (
col.options.vector.service.provider if col.options.vector and col.options.vector.service else None
),
"icon": (
"vectorstores"
if not col.options.vector or not col.options.vector.service
else "NVIDIA"
if col.options.vector.service.provider == "nvidia"
else "OpenAI"
if col.options.vector.service.provider == "openai" # TODO: Add more icons
else col.options.vector.service.provider.title()
),
"model": (
col.options.vector.service.model_name if col.options.vector and col.options.vector.service else None
),
}
for col in collection_list
]
return [self._get_collection_metadata(col, database) for col in collection_list]

def reset_provider_options(self, build_config: dict):
# Get the list of vectorize providers
Expand Down Expand Up @@ -946,3 +909,30 @@
"search_type": self._map_search_type(),
"search_kwargs": search_args,
}

def __init__(self):
super().__init__()
self.database_cache = None

def _get_collection_metadata(self, collection, database):
options_vector_service = collection.options.vector.service if collection.options.vector else None
service_provider = options_vector_service.provider if options_vector_service else None
service_model = options_vector_service.model_name if options_vector_service else None
icon = self._determine_icon(service_provider)
return {
"name": collection.name,
"records": self.collection_data(collection_name=collection.name, database=database),
"provider": service_provider,
"icon": icon,
"model": service_model,
}

def _determine_icon(self, provider):
if provider is None:
return "vectorstores"
provider = provider.lower()
if provider == "nvidia":
return "NVIDIA"
if provider == "openai":
return "OpenAI"
return provider.title()
Loading