⚡️ Speed up method AstraDBVectorStoreComponent._initialize_collection_options by 24% in PR #6236 (LFOSS-492) #6635

Closed

@codeflash-ai codeflash-ai bot commented Feb 14, 2025

⚡️ This pull request contains optimizations for PR #6236

If you approve this dependent PR, these changes will be merged into the original PR branch LFOSS-492.

This PR will be automatically closed if the original PR is merged.


📄 24% (0.24x) speedup for AstraDBVectorStoreComponent._initialize_collection_options in src/backend/base/langflow/components/vectorstores/astradb.py

⏱️ Runtime: 596 milliseconds → 481 milliseconds (best of 5 runs)

📝 Explanation and details

The optimization is summarized below.

Changes Made.

  1. Cached the database object to avoid multiple calls to the get_database_object method.
  2. Moved repeated calls to client.get_database to a single call within get_database_object.
  3. Reduced redundant fetches from the API by reusing an existing database object when one is available.
  4. Removed redundant dictionary generation within list comprehension (used a helper function for clarity).
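
The caching described in items 1 to 3 can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: `CachedDatabaseHolder`, its `connect` callable, and `fake_connect` are hypothetical stand-ins for the component and for the `client.get_database` call, and keying the cache on the endpoint is an assumption.

```python
from typing import Any, Callable, Optional


class CachedDatabaseHolder:
    """Hypothetical stand-in for the component; not the Langflow class."""

    def __init__(self, connect: Callable[[str], Any]) -> None:
        # `connect` plays the role of a wrapper around client.get_database.
        self._connect = connect
        self._database_cache: Optional[Any] = None
        self._cached_endpoint: Optional[str] = None

    def get_database_object(self, api_endpoint: str) -> Any:
        # Reuse the cached object only if it was built for the same endpoint
        # (assumption: the endpoint is the only thing that invalidates it).
        if self._database_cache is None or self._cached_endpoint != api_endpoint:
            self._database_cache = self._connect(api_endpoint)
            self._cached_endpoint = api_endpoint
        return self._database_cache


# Record every "network" call so the effect of the cache is visible.
calls = []


def fake_connect(endpoint: str) -> dict:
    calls.append(endpoint)
    return {"endpoint": endpoint}


holder = CachedDatabaseHolder(fake_connect)
first = holder.get_database_object("https://example-endpoint")
second = holder.get_database_object("https://example-endpoint")
```

Repeated lookups for the same endpoint return the same object and call `fake_connect` once; a different endpoint triggers a fresh call, so a stale object is not reused.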

Explanation.

  1. Caching Database Object: The database_cache attribute avoids redundant calls to the database API.
  2. New __init__ Method: Initializes database_cache as None.
  3. Refactored _initialize_collection_options and added _get_collection_metadata: Simplifies list comprehension and centralizes metadata creation.
  4. New _determine_icon Method: Determines the icon from the provider, which reduces complexity in the metadata initialization.
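
The helper-based refactor in items 3 and 4 might look something like the sketch below. The function names mirror the methods named above, but the bodies, the metadata keys, and the icon rule are assumptions made for illustration; the attribute paths (`options.vector.service.provider`, `model_name`, `estimated_document_count`) are the ones the generated tests mock.

```python
from types import SimpleNamespace
from typing import Any, Optional
from unittest.mock import MagicMock


def determine_icon(provider: Optional[str]) -> str:
    # Hypothetical mapping; the real component's icon names may differ.
    if not provider:
        return "vectorstores"
    return "NVIDIA" if provider.lower() == "nvidia" else provider.title()


def get_collection_metadata(database: Any, collection: Any) -> dict:
    # Collections without a vector service yield None for provider and model.
    vector = collection.options.vector
    service = vector.service if vector else None
    provider = service.provider if service else None
    return {
        "name": collection.name,
        "records": database.get_collection(collection.name).estimated_document_count(),
        "provider": provider,
        "icon": determine_icon(provider),
        "model": service.model_name if service else None,
    }


def initialize_collection_options(database: Any) -> list:
    # The comprehension stays short because the dictionary is built in a helper.
    return [get_collection_metadata(database, c) for c in database.list_collections()]


# Stand-in data shaped like the mocks in the generated tests.
col1 = SimpleNamespace(
    name="col1",
    options=SimpleNamespace(
        vector=SimpleNamespace(
            service=SimpleNamespace(provider="nvidia", model_name="model1")
        )
    ),
)
col2 = SimpleNamespace(name="col2", options=SimpleNamespace(vector=None))

database = MagicMock()
database.list_collections.return_value = [col1, col2]
database.get_collection.return_value.estimated_document_count.return_value = 10

options = initialize_collection_options(database)
```

Splitting the dictionary construction out of the comprehension keeps the `None` handling for collections without a vector service in one place.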

The optimized code should run faster because it eliminates unnecessary repeated operations. Each method keeps the same functionality as before.

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 13 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | undefined |

🌀 Generated Regression Tests Details
from unittest.mock import MagicMock, patch

# imports
import pytest  # used for our unit tests
# function to test
from astrapy import DataAPIClient, Database
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.components.vectorstores.astradb import \
    AstraDBVectorStoreComponent


# unit tests
@pytest.fixture
def astra_component():
    component = AstraDBVectorStoreComponent()
    component.token = "test_token"
    component.environment = "test_env"
    component.api_endpoint = "test_endpoint"
    component.database_name = "test_db"
    component.keyspace = "test_keyspace"
    component.log = MagicMock()
    return component

def test_valid_api_endpoint_with_collections(astra_component):
    mock_database = MagicMock()
    mock_collection = MagicMock()
    mock_collection.name = "col1"
    mock_collection.options.vector.service.provider = "nvidia"
    mock_collection.options.vector.service.model_name = "model1"
    mock_database.list_collections.return_value = [mock_collection]
    mock_database.get_collection.return_value.estimated_document_count.return_value = 10

    with patch.object(astra_component, 'get_database_object', return_value=mock_database):
        codeflash_output = astra_component._initialize_collection_options(api_endpoint="valid_endpoint")

def test_valid_api_endpoint_with_no_collections(astra_component):
    mock_database = MagicMock()
    mock_database.list_collections.return_value = []

    with patch.object(astra_component, 'get_database_object', return_value=mock_database):
        codeflash_output = astra_component._initialize_collection_options(api_endpoint="valid_endpoint")

def test_no_api_endpoint_provided(astra_component):
    with patch.object(astra_component, 'get_api_endpoint', return_value=None):
        codeflash_output = astra_component._initialize_collection_options(api_endpoint=None)


def test_keyspace_is_none(astra_component):
    astra_component.keyspace = None
    mock_database = MagicMock()
    mock_database.list_collections.return_value = []

    with patch.object(astra_component, 'get_database_object', return_value=mock_database):
        codeflash_output = astra_component._initialize_collection_options(api_endpoint="valid_endpoint")

def test_database_object_retrieval_failure(astra_component):
    with patch.object(astra_component, 'get_database_object', side_effect=ValueError("Database error")):
        with pytest.raises(ValueError, match="Database error"):
            astra_component._initialize_collection_options(api_endpoint="valid_endpoint")


def test_collections_with_no_vector_service(astra_component):
    mock_database = MagicMock()
    mock_collection = MagicMock()
    mock_collection.name = "col1"
    mock_collection.options.vector = None
    mock_database.list_collections.return_value = [mock_collection]
    mock_database.get_collection.return_value.estimated_document_count.return_value = 10

    with patch.object(astra_component, 'get_database_object', return_value=mock_database):
        codeflash_output = astra_component._initialize_collection_options(api_endpoint="valid_endpoint")


def test_collections_with_missing_model_names(astra_component):
    mock_database = MagicMock()
    mock_collection = MagicMock()
    mock_collection.name = "col1"
    mock_collection.options.vector.service.provider = "nvidia"
    mock_collection.options.vector.service.model_name = None
    mock_database.list_collections.return_value = [mock_collection]
    mock_database.get_collection.return_value.estimated_document_count.return_value = 10

    with patch.object(astra_component, 'get_database_object', return_value=mock_database):
        codeflash_output = astra_component._initialize_collection_options(api_endpoint="valid_endpoint")


def test_large_collection_data(astra_component):
    mock_database = MagicMock()
    mock_collection = MagicMock()
    mock_collection.name = "col1"
    mock_collection.options.vector.service.provider = "nvidia"
    mock_collection.options.vector.service.model_name = "model1"
    mock_database.list_collections.return_value = [mock_collection]
    mock_database.get_collection.return_value.estimated_document_count.return_value = 1000000

    with patch.object(astra_component, 'get_database_object', return_value=mock_database):
        codeflash_output = astra_component._initialize_collection_options(api_endpoint="valid_endpoint")
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

from unittest.mock import MagicMock, patch

# imports
import pytest  # used for our unit tests
# function to test
from astrapy import DataAPIClient, Database
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.components.vectorstores.astradb import \
    AstraDBVectorStoreComponent

# unit tests

@pytest.fixture
def mock_component():
    component = AstraDBVectorStoreComponent()
    component.token = "test_token"
    component.environment = "test_env"
    component.api_endpoint = "http://test_api_endpoint"
    component.database_name = "test_db"
    component.keyspace = "test_keyspace"
    return component


def test_valid_api_endpoint_empty_collections(mock_component):
    # Mocking database with no collections
    mock_database = MagicMock()
    mock_database.list_collections.return_value = []
    mock_component.get_database_object = MagicMock(return_value=mock_database)
    
    codeflash_output = mock_component._initialize_collection_options()

def test_invalid_api_endpoint_none(mock_component):
    mock_component.get_api_endpoint = MagicMock(return_value=None)
    codeflash_output = mock_component._initialize_collection_options(api_endpoint=None)

def test_invalid_api_endpoint_empty_string(mock_component):
    mock_component.get_api_endpoint = MagicMock(return_value="")
    codeflash_output = mock_component._initialize_collection_options(api_endpoint="")

def test_database_client_initialization_failure(mock_component):
    with patch('astrapy.DataAPIClient', side_effect=Exception("Initialization error")):
        with pytest.raises(ValueError):
            mock_component._initialize_collection_options()

def test_keyspace_none(mock_component):
    mock_component.get_keyspace = MagicMock(return_value=None)
    mock_database = MagicMock()
    mock_database.list_collections.return_value = []
    mock_component.get_database_object = MagicMock(return_value=mock_database)
    
    codeflash_output = mock_component._initialize_collection_options()

@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Feb 14, 2025
@dosubot dosubot bot added size:M This PR changes 30-99 lines, ignoring generated files. enhancement New feature or request labels Feb 14, 2025
@erichare erichare closed this Feb 15, 2025