Skip to content

Commit

Permalink
Use init_chat_model() helper method to initialise models (#87)
Browse files Browse the repository at this point in the history
* Upgrade langchain version to 0.2.8

* Use new init_chat_model function to initialise any models. Remove models.py since its redundant

* Add migration file to convert all provider names in member table to new name format

* Update default provider and model names when creating new member

* Update frontend available model provider names

* Update test_members to use new provider name
  • Loading branch information
StreetLamb authored Jul 27, 2024
1 parent 890b4b4 commit 9409207
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Rename provider names in member table to new name
Revision ID: 6e7c33ddf30f
Revises: 0a354b5c6f6c
Create Date: 2024-07-27 04:29:51.886906
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '6e7c33ddf30f'
down_revision = '0a354b5c6f6c'
branch_labels = None
depends_on = None


def upgrade():
# Mapping of old provider names to new provider names
mapping = {
'ChatOpenAI': 'openai',
'ChatAnthropic': 'anthropic',
}

# Rename each provider name according to the mapping
for old_name, new_name in mapping.items():
op.execute(f"UPDATE member SET provider = '{new_name}' WHERE provider = '{old_name}'")


def downgrade():
# Mapping of new provider names back to old provider names
mapping = {
'openai': 'ChatOpenAI',
'anthropic': 'ChatAnthropic',
}

# Revert each provider name according to the mapping
for new_name, old_name in mapping.items():
op.execute(f"UPDATE member SET provider = '{old_name}' WHERE provider = '{new_name}'")
14 changes: 7 additions & 7 deletions backend/app/core/graph/members.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Annotated, Any

from langchain.chat_models import init_chat_model
from langchain.tools.retriever import create_retriever_tool
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
Expand All @@ -16,7 +17,6 @@
from pydantic import BaseModel, Field
from typing_extensions import NotRequired, TypedDict

from app.core.graph.models import all_models
from app.core.graph.rag.qdrant import QdrantStore
from app.core.graph.skills import managed_skills
from app.core.graph.skills.api_tool import dynamic_api_tool
Expand Down Expand Up @@ -147,12 +147,12 @@ class ReturnTeamState(TypedDict):

class BaseNode:
def __init__(self, provider: str, model: str, temperature: float):
self.model = all_models[provider](
model=model, temperature=temperature, streaming=True
) # type: ignore[call-arg]
self.final_answer_model = all_models[provider](
model=model, temperature=0, streaming=True
) # type: ignore[call-arg]
self.model = init_chat_model(
model, model_provider=provider, temperature=temperature, streaming=True
)
self.final_answer_model = init_chat_model(
model, model_provider=provider, temperature=0, streaming=True
)

def tag_with_name(self, ai_message: AIMessage, name: str) -> AIMessage:
"""Tag a name to the AI message"""
Expand Down
14 changes: 0 additions & 14 deletions backend/app/core/graph/models.py

This file was deleted.

4 changes: 2 additions & 2 deletions backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ class MemberBase(SQLModel):
position_x: float
position_y: float
source: int | None = None
provider: str = "ChatOpenAI"
model: str = "gpt-3.5-turbo"
provider: str = "openai"
model: str = "gpt-4o-mini"
temperature: float = 0.7
interrupt: bool = False

Expand Down
14 changes: 7 additions & 7 deletions backend/app/tests/api/routes/test_members.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_read_members(
"position_x": 0.0,
"position_y": 0.0,
"source": None,
"provider": "ChatOpenAI",
"provider": "openai",
"model": "gpt-3.5-turbo",
"temperature": 0.7,
"interrupt": False,
Expand Down Expand Up @@ -70,7 +70,7 @@ def test_read_member(
"position_x": 0.0,
"position_y": 0.0,
"source": None,
"provider": "ChatOpenAI",
"provider": "openai",
"model": "gpt-3.5-turbo",
"temperature": 0.7,
"interrupt": False,
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_create_member(
"position_x": 0.0,
"position_y": 0.0,
"source": None,
"provider": "ChatOpenAI",
"provider": "openai",
"model": "gpt-3.5-turbo",
"temperature": 0.7,
"interrupt": False,
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_create_member_duplicate_name(
"position_x": 0.0,
"position_y": 0.0,
"source": None,
"provider": "ChatOpenAI",
"provider": "openai",
"model": "gpt-3.5-turbo",
"temperature": 0.7,
"interrupt": False,
Expand All @@ -160,7 +160,7 @@ def test_create_member_duplicate_name(
"position_x": 0.0,
"position_y": 0.0,
"source": None,
"provider": "ChatOpenAI",
"provider": "openai",
"model": "gpt-3.5-turbo",
"temperature": 0.7,
"interrupt": False,
Expand All @@ -187,7 +187,7 @@ def test_update_member(
"position_x": 0.0,
"position_y": 0.0,
"source": None,
"provider": "ChatOpenAI",
"provider": "openai",
"model": "gpt-3.5-turbo",
"temperature": 0.7,
"interrupt": False,
Expand Down Expand Up @@ -226,7 +226,7 @@ def test_delete_member(
"position_x": 0.0,
"position_y": 0.0,
"source": None,
"provider": "ChatOpenAI",
"provider": "openai",
"model": "gpt-3.5-turbo",
"temperature": 0.7,
"interrupt": False,
Expand Down
10 changes: 5 additions & 5 deletions backend/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ langgraph = "0.1.9"
langserve = {extras = ["server"], version = "^0.0.51"}
langchain-openai = "0.1.17"
grandalf = "^0.8"
langchain = "0.2.7"
langchain = "0.2.8"
langchain-community = "0.2.7"
duckduckgo-search = "6.1.0"
wikipedia = "^1.4.0"
Expand Down
6 changes: 2 additions & 4 deletions frontend/src/components/Members/EditMember.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,12 @@ const customSelectOption = {

// TODO: Place this somewhere else.
const AVAILABLE_MODELS = {
ChatOpenAI: ["gpt-4o-mini", "gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"],
ChatAnthropic: [
openai: ["gpt-4o-mini", "gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"],
anthropic: [
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
],
// ChatCohere: ["command"],
// ChatGoogleGenerativeAI: ["gemini-pro"],
}

type ModelProvider = keyof typeof AVAILABLE_MODELS
Expand Down

0 comments on commit 9409207

Please sign in to comment.