Commit e49940b: refactor predict and parse

Enea_Gore committed Oct 9, 2024
1 parent b92e316 commit e49940b
Showing 10 changed files with 65 additions and 55 deletions.
51 changes: 4 additions & 47 deletions modules/llm_core/llm_core/utils/llm_utils.py
@@ -1,5 +1,5 @@
-from typing import Optional, Type, TypeVar, List
-from pydantic import BaseModel, ValidationError
+from typing import Type, TypeVar, List
+from pydantic import BaseModel
 import tiktoken
 from langchain.chat_models import ChatOpenAI
 from langchain.base_language import BaseLanguageModel
@@ -9,9 +9,7 @@
     HumanMessagePromptTemplate,
 )
 from langchain.output_parsers import PydanticOutputParser
-from langchain_core.runnables import RunnableSequence

-from athena import emit_meta, get_experiment_environment
+from athena import emit_meta

 T = TypeVar("T", bound=BaseModel)
@@ -109,45 +107,4 @@ def get_chat_prompt_with_formatting_instructions(
     system_message_prompt.prompt.partial_variables = {"format_instructions": output_parser.get_format_instructions()}
     system_message_prompt.prompt.input_variables.remove("format_instructions")
     human_message_prompt = HumanMessagePromptTemplate.from_template(human_message + "\n\nJSON response following the provided schema:")
-    return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
-
-
-async def predict_and_parse(
-    model: BaseLanguageModel,
-    chat_prompt: ChatPromptTemplate,
-    prompt_input: dict,
-    pydantic_object: Type[T],
-    tags: Optional[List[str]]
-) -> Optional[T]:
-    """Predicts an LLM completion using the model and parses the output using the provided Pydantic model
-    Args:
-        model (BaseLanguageModel): The model to predict with
-        chat_prompt (ChatPromptTemplate): Prompt to use
-        prompt_input (dict): Input parameters to use for the prompt
-        pydantic_object (Type[T]): Pydantic model to parse the output
-        tags (Optional[List[str]]): List of tags to tag the prediction with
-    Returns:
-        Optional[T]: Parsed output, or None if it could not be parsed
-    """
-    experiment = get_experiment_environment()
-
-    tags = tags or []
-    if experiment.experiment_id is not None:
-        tags.append(f"experiment-{experiment.experiment_id}")
-    if experiment.module_configuration_id is not None:
-        tags.append(f"module-configuration-{experiment.module_configuration_id}")
-    if experiment.run_id is not None:
-        tags.append(f"run-{experiment.run_id}")
-
-    structured_output_llm = model.with_structured_output(pydantic_object, method="json_mode")
-    chain = RunnableSequence(
-        chat_prompt,
-        structured_output_llm
-    )
-
-    try:
-        return await chain.ainvoke(prompt_input, config={"tags": tags})
-    except ValidationError as e:
-        raise ValueError(f"Could not parse output: {e}") from e
+    return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
48 changes: 48 additions & 0 deletions modules/llm_core/llm_core/utils/predict_and_parse.py
@@ -0,0 +1,48 @@
+from typing import Optional, Type, TypeVar, List
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.pydantic_v1 import BaseModel, ValidationError
+from langchain_core.runnables import RunnableSequence
+from athena import get_experiment_environment
+
+T = TypeVar("T", bound=BaseModel)
+
+async def predict_and_parse(
+    model: BaseLanguageModel,
+    chat_prompt: ChatPromptTemplate,
+    prompt_input: dict,
+    pydantic_object: Type[T],
+    tags: Optional[List[str]]
+) -> Optional[T]:
+    """Predicts an LLM completion using the model and parses the output using the provided Pydantic model
+    Args:
+        model (BaseLanguageModel): The model to predict with
+        chat_prompt (ChatPromptTemplate): Prompt to use
+        prompt_input (dict): Input parameters to use for the prompt
+        pydantic_object (Type[T]): Pydantic model to parse the output
+        tags (Optional[List[str]]): List of tags to tag the prediction with
+    Returns:
+        Optional[T]: Parsed output, or None if it could not be parsed
+    """
+    experiment = get_experiment_environment()
+
+    tags = tags or []
+    if experiment.experiment_id is not None:
+        tags.append(f"experiment-{experiment.experiment_id}")
+    if experiment.module_configuration_id is not None:
+        tags.append(f"module-configuration-{experiment.module_configuration_id}")
+    if experiment.run_id is not None:
+        tags.append(f"run-{experiment.run_id}")
+
+    structured_output_llm = model.with_structured_output(pydantic_object, method="json_mode")
+    chain = RunnableSequence(
+        chat_prompt,
+        structured_output_llm
+    )
+
+    try:
+        return await chain.ainvoke(prompt_input, config={"tags": tags})
+    except ValidationError as e:
+        raise ValueError(f"Could not parse output: {e}") from e
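
For reference, a minimal usage sketch of the relocated helper; the schema, prompt, and model below are illustrative assumptions, not part of this commit. The model must be one that supports with_structured_output(..., method="json_mode"), such as an OpenAI chat model:

import asyncio
from langchain.chat_models import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from llm_core.utils.predict_and_parse import predict_and_parse

class Assessment(BaseModel):  # hypothetical example schema
    score: float
    comment: str

async def main() -> None:
    # Assumed: a JSON-mode-capable chat model configured for the deployment
    model = ChatOpenAI()
    prompt = ChatPromptTemplate.from_messages([
        ("system", "Grade the submission. Respond as JSON with 'score' and 'comment'."),
        ("human", "{submission}"),
    ])
    result = await predict_and_parse(
        model=model,
        chat_prompt=prompt,
        prompt_input={"submission": "def add(a, b): return a + b"},
        pydantic_object=Assessment,
        tags=["example"],
    )
    # predict_and_parse raises ValueError if the output cannot be parsed
    if result is not None:
        print(result.score, result.comment)

asyncio.run(main())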
@@ -3,7 +3,7 @@

 from athena import emit_meta
 from module_modeling_llm.config import BasicApproachConfig
-from llm_core.utils.llm_utils import predict_and_parse
+from llm_core.utils.predict_and_parse import predict_and_parse
 from module_modeling_llm.models.assessment_model import AssessmentModel
 from module_modeling_llm.models.exercise_model import ExerciseModel
 from module_modeling_llm.prompts.filter_feedback_prompt import FilterFeedbackInputs
@@ -6,7 +6,7 @@
 from module_modeling_llm.config import BasicApproachConfig
 from module_modeling_llm.models.assessment_model import AssessmentModel
 from module_modeling_llm.prompts.apollon_format_description import apollon_format_description
-from llm_core.utils.llm_utils import predict_and_parse
+from llm_core.utils.predict_and_parse import predict_and_parse
 from module_modeling_llm.prompts.graded_feedback_prompt import GradedFeedbackInputs
 from module_modeling_llm.models.exercise_model import ExerciseModel

@@ -4,7 +4,7 @@
 from langchain_core.prompts import ChatPromptTemplate

 from athena.schemas.grading_criterion import GradingCriterion, StructuredGradingCriterion
-from llm_core.utils.llm_utils import predict_and_parse
+from llm_core.utils.predict_and_parse import predict_and_parse
 from module_modeling_llm.config import BasicApproachConfig
 from module_modeling_llm.models.exercise_model import ExerciseModel
 from module_modeling_llm.prompts.structured_grading_instructions_prompt import StructuredGradingInstructionsInputs
@@ -17,8 +17,9 @@
     check_prompt_length_and_omit_features_if_necessary,
     get_chat_prompt_with_formatting_instructions,
     num_tokens_from_string,
-    predict_and_parse,
 )
+from llm_core.utils.predict_and_parse import predict_and_parse
+
 from module_programming_llm.helpers.utils import (
     get_diff,
     load_files_from_repo,
@@ -15,8 +15,9 @@
     check_prompt_length_and_omit_features_if_necessary,
     get_chat_prompt_with_formatting_instructions,
     num_tokens_from_string,
-    predict_and_parse,
 )
+from llm_core.utils.predict_and_parse import predict_and_parse
+
 from module_programming_llm.helpers.utils import (
     get_diff,
     load_files_from_repo,
@@ -12,8 +12,9 @@
 from llm_core.utils.llm_utils import (
     get_chat_prompt_with_formatting_instructions,
     num_tokens_from_prompt,
-    predict_and_parse,
 )
+from llm_core.utils.predict_and_parse import predict_and_parse
+
 from module_programming_llm.helpers.utils import (
     get_diff,
     load_files_from_repo,
@@ -12,8 +12,9 @@
     get_chat_prompt_with_formatting_instructions,
     num_tokens_from_string,
     num_tokens_from_prompt,
-    predict_and_parse
 )
+from llm_core.utils.predict_and_parse import predict_and_parse
+
 from module_programming_llm.helpers.utils import format_grading_instructions, get_diff


@@ -12,8 +12,9 @@
     get_chat_prompt_with_formatting_instructions,
     num_tokens_from_string,
     num_tokens_from_prompt,
-    predict_and_parse
 )
+from llm_core.utils.predict_and_parse import predict_and_parse
+
 from module_programming_llm.helpers.utils import get_diff


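Taken together, the change in all eight consumer files is the same one-line import swap; this is a summary of the diff above, not additional code in the commit:

# Before (removed in this commit):
from llm_core.utils.llm_utils import predict_and_parse

# After:
from llm_core.utils.predict_and_parse import predict_and_parse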
