Skip to content

Commit

Permalink
reintroduce function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixTJDietrich committed Oct 11, 2024
1 parent 5f3eaa0 commit 4574e86
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions llm_core/llm_core/utils/llm_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Type, TypeVar, List
from pydantic import BaseModel
import tiktoken
from langchain.chat_models import ChatOpenAI
from langchain.base_language import BaseLanguageModel
from langchain.prompts import (
ChatPromptTemplate,
Expand Down Expand Up @@ -65,6 +66,18 @@ def check_prompt_length_and_omit_features_if_necessary(prompt: ChatPromptTemplat
return prompt_input, False


def supports_function_calling(model: BaseLanguageModel) -> bool:
    """Check whether a language model supports OpenAI-style function calling.

    Args:
        model (BaseLanguageModel): The model to check.

    Returns:
        bool: True if the model supports function calling, False otherwise.
    """
    # Function calling is currently only available for OpenAI chat models;
    # keep the supported classes in a tuple so new backends can be appended.
    function_calling_model_types = (ChatOpenAI,)
    return isinstance(model, function_calling_model_types)


def get_chat_prompt_with_formatting_instructions(
model: BaseLanguageModel,
system_message: str,
Expand All @@ -84,9 +97,14 @@ def get_chat_prompt_with_formatting_instructions(
Returns:
ChatPromptTemplate: ChatPromptTemplate with formatting instructions (if necessary)
"""
if supports_function_calling(model):
system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)
return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

output_parser = PydanticOutputParser(pydantic_object=pydantic_object)
system_message_prompt = SystemMessagePromptTemplate.from_template(system_message + "\n\n{format_instructions}")
system_message_prompt = SystemMessagePromptTemplate.from_template(system_message + "\n{format_instructions}")
system_message_prompt.prompt.partial_variables = {"format_instructions": output_parser.get_format_instructions()}
system_message_prompt.prompt.input_variables.remove("format_instructions")
human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_message + "\n\nJSON response following the provided schema:")
return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

0 comments on commit 4574e86

Please sign in to comment.