Skip to content

Commit

Permalink
PR feedback and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
leonbi100 committed Dec 20, 2024
1 parent 66fe9a7 commit 9166b5a
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from databricks_langchain.chat_models import ChatDatabricks
from databricks_langchain.embeddings import DatabricksEmbeddings
from databricks_langchain.genie import GenieAgent
from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool
from databricks_langchain.vectorstores import DatabricksVectorSearch
from databricks_langchain.vector_search import VectorSearchRetrieverTool

# Expose all integrations to users under databricks-langchain
__all__ = [
Expand Down
10 changes: 3 additions & 7 deletions integrations/langchain/src/databricks_langchain/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from typing import Any, List, Union
import json
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse

import numpy as np
from enum import Enum
import json

from typing import (
Dict,
Optional
)

def get_deployment_client(target_uri: str) -> Any:
if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type

from pydantic import BaseModel, Field, model_validator, PrivateAttr

from databricks_langchain import DatabricksVectorSearch
from databricks_langchain.utils import IndexDetails
from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, PrivateAttr, model_validator

from databricks_langchain.utils import IndexDetails
from databricks_langchain.vectorstores import DatabricksVectorSearch


class VectorSearchRetrieverToolInput(BaseModel):
    # Argument schema for the retriever tool: a single free-text query.
    # Documented with comments rather than a docstring so the schema that
    # pydantic generates from this model stays exactly as before.
    query: str = Field(
        description=(
            "The string used to query the index with and identify the most similar "
            "vectors and return the associated documents."
        ),
    )

class VectorSearchRetrieverTool(BaseTool):
"""
Expand All @@ -19,27 +23,25 @@ class VectorSearchRetrieverTool(BaseTool):
num_results: int = Field(10, description="The number of results to return.")
columns: Optional[List[str]] = Field(None, description="Columns to return when doing the search.")
filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.")
query_type: str = Field("ANN", description="The type of query to run.")
query_type: str = Field("ANN", description="The type of this query. Supported values are 'ANN' and 'HYBRID'.")
tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.")
tool_description: Optional[str] = Field(None, description="A description of the tool.")
# TODO: Confirm if we can add these two to the API to support direct-access indexes or delta-sync indexes with self-managed embeddings.
text_column: Optional[str] = Field(None, description="If using a direct-access index or delta-sync index, specify the text column.")
text_column: Optional[str] = Field(None, description="The name of the text column to use for the embeddings. "
"Required for direct-access index or delta-sync index with "
"self-managed embeddings.")
embedding: Optional[Embeddings] = Field(None, description="Embedding model for self-managed embeddings.")
# TODO: Confirm if we can add this endpoint field
endpoint: Optional[str] = Field(None, description="Endpoint for DatabricksVectorSearch.")

# The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs()
name: str = Field(default="", description="The name of the tool")
description: str = Field(default="", description="The description of the tool")
args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput

_vector_store: DatabricksVectorSearch = PrivateAttr()

@model_validator(mode='after')
def validate_tool_inputs(self):
# Construct the vector store using provided params
kwargs = {
"index_name": self.index_name,
"endpoint": self.endpoint,
"embedding": self.embedding,
"text_column": self.text_column,
"columns": self.columns,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore

from databricks_langchain.utils import maximal_marginal_relevance, IndexDetails
from databricks_langchain.utils import IndexDetails, maximal_marginal_relevance

logger = logging.getLogger(__name__)

Expand Down
11 changes: 6 additions & 5 deletions integrations/langchain/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Test chat model integration."""

import json
from typing import Generator
from unittest import mock

import mlflow # type: ignore # noqa: F401
import pytest
Expand All @@ -29,10 +27,13 @@
_convert_dict_to_message_chunk,
_convert_message_to_dict,
)
from tests.utils.chat_models import ( # noqa: F401
_MOCK_CHAT_RESPONSE,
_MOCK_STREAM_RESPONSE,
llm,
mock_client,
)

from databricks_langchain import ChatDatabricks

from tests.utils.chat_models import _MOCK_CHAT_RESPONSE, _MOCK_STREAM_RESPONSE, mock_client, llm

def test_dict(llm: ChatDatabricks) -> None:
d = llm.dict()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from typing import Any, Dict, Generator, List, Optional, Set
from typing import Any, Dict, List, Optional

import pytest
from databricks.vector_search.client import VectorSearchIndex # type: ignore

from databricks_langchain import VectorSearchRetrieverTool, ChatDatabricks
from tests.utils.vector_search import EMBEDDING_MODEL, DELTA_SYNC_INDEX, ALL_INDEX_NAMES, mock_vs_client, mock_workspace_client, mock_workspace_client
from tests.utils.chat_models import mock_client, llm
from langchain_core.tools import BaseTool
from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool

from databricks_langchain import ChatDatabricks, VectorSearchRetrieverTool
from tests.utils.chat_models import llm, mock_client # noqa: F401
from tests.utils.vector_search import ( # noqa: F401
ALL_INDEX_NAMES,
DELTA_SYNC_INDEX,
EMBEDDING_MODEL,
mock_vs_client,
mock_workspace_client,
)


def init_vector_search_tool(
index_name: str,
Expand Down
16 changes: 8 additions & 8 deletions integrations/langchain/tests/unit_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import uuid
from typing import Any, Dict, Generator, List, Optional, Set
from typing import Any, Dict, List, Optional, Set
from unittest.mock import MagicMock, patch

import pytest
from databricks.vector_search.client import VectorSearchIndex # type: ignore

from databricks_langchain.vectorstores import DatabricksVectorSearch

from tests.utils.vector_search import (
INPUT_TEXTS,
FakeEmbeddings,
ALL_INDEX_NAMES,
DELTA_SYNC_INDEX,
DIRECT_ACCESS_INDEX,
EMBEDDING_MODEL,
ENDPOINT_NAME,
DIRECT_ACCESS_INDEX,
DELTA_SYNC_INDEX,
ALL_INDEX_NAMES,
INDEX_DETAILS,
mock_vs_client,
INPUT_TEXTS,
FakeEmbeddings,
mock_vs_client, # noqa: F401
)


def init_vector_search(
index_name: str, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
Expand Down
5 changes: 3 additions & 2 deletions integrations/langchain/tests/utils/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest

from typing import Generator
from unittest import mock

import pytest

from databricks_langchain import ChatDatabricks

_MOCK_CHAT_RESPONSE = {
Expand Down
2 changes: 1 addition & 1 deletion integrations/langchain/tests/utils/vector_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from typing import Any, Dict, Generator, List, Optional, Set
from typing import Generator, List, Optional
from unittest import mock
from unittest.mock import MagicMock, patch

Expand Down

0 comments on commit 9166b5a

Please sign in to comment.