From e49940b0d88b049d3954d974acc1b3b463cbb387 Mon Sep 17 00:00:00 2001
From: Enea_Gore
Date: Wed, 9 Oct 2024 20:42:05 +0200
Subject: [PATCH] refactor predict and parse

---
 modules/llm_core/llm_core/utils/llm_utils.py  | 51 ++-----------------
 .../llm_core/utils/predict_and_parse.py       | 48 +++++++++++++++++
 .../core/filter_feedback.py                   |  2 +-
 .../core/generate_suggestions.py              |  2 +-
 .../get_structured_grading_instructions.py    |  2 +-
 .../generate_graded_suggestions_by_file.py    |  3 +-
 ...generate_non_graded_suggestions_by_file.py |  3 +-
 .../generate_summary_by_file.py               |  3 +-
 .../split_grading_instructions_by_file.py     |  3 +-
 .../split_problem_statement_by_file.py        |  3 +-
 10 files changed, 65 insertions(+), 55 deletions(-)
 create mode 100644 modules/llm_core/llm_core/utils/predict_and_parse.py

diff --git a/modules/llm_core/llm_core/utils/llm_utils.py b/modules/llm_core/llm_core/utils/llm_utils.py
index f779cb55..4637b855 100644
--- a/modules/llm_core/llm_core/utils/llm_utils.py
+++ b/modules/llm_core/llm_core/utils/llm_utils.py
@@ -1,5 +1,5 @@
-from typing import Optional, Type, TypeVar, List
-from pydantic import BaseModel, ValidationError
+from typing import Type, TypeVar, List
+from pydantic import BaseModel
 import tiktoken
 from langchain.chat_models import ChatOpenAI
 from langchain.base_language import BaseLanguageModel
@@ -9,9 +9,7 @@
     HumanMessagePromptTemplate,
 )
 from langchain.output_parsers import PydanticOutputParser
-from langchain_core.runnables import RunnableSequence
-
-from athena import emit_meta, get_experiment_environment
+from athena import emit_meta
 
 T = TypeVar("T", bound=BaseModel)
 
@@ -109,45 +107,4 @@ def get_chat_prompt_with_formatting_instructions(
     system_message_prompt.prompt.partial_variables = {"format_instructions": output_parser.get_format_instructions()}
     system_message_prompt.prompt.input_variables.remove("format_instructions")
     human_message_prompt = HumanMessagePromptTemplate.from_template(human_message + "\n\nJSON response following the provided schema:")
-    return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
-
-
-async def predict_and_parse(
-        model: BaseLanguageModel,
-        chat_prompt: ChatPromptTemplate,
-        prompt_input: dict,
-        pydantic_object: Type[T],
-        tags: Optional[List[str]]
-    ) -> Optional[T]:
-    """Predicts an LLM completion using the model and parses the output using the provided Pydantic model
-
-    Args:
-        model (BaseLanguageModel): The model to predict with
-        chat_prompt (ChatPromptTemplate): Prompt to use
-        prompt_input (dict): Input parameters to use for the prompt
-        pydantic_object (Type[T]): Pydantic model to parse the output
-        tags (Optional[List[str]]: List of tags to tag the prediction with
-
-    Returns:
-        Optional[T]: Parsed output, or None if it could not be parsed
-    """
-    experiment = get_experiment_environment()
-
-    tags = tags or []
-    if experiment.experiment_id is not None:
-        tags.append(f"experiment-{experiment.experiment_id}")
-    if experiment.module_configuration_id is not None:
-        tags.append(f"module-configuration-{experiment.module_configuration_id}")
-    if experiment.run_id is not None:
-        tags.append(f"run-{experiment.run_id}")
-
-    structured_output_llm = model.with_structured_output(pydantic_object, method="json_mode")
-    chain = RunnableSequence(
-        chat_prompt,
-        structured_output_llm
-    )
-
-    try:
-        return await chain.ainvoke(prompt_input, config={"tags": tags})
-    except ValidationError as e:
-        raise ValueError(f"Could not parse output: {e}") from e
+    return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
\ No newline at end of file
diff --git a/modules/llm_core/llm_core/utils/predict_and_parse.py b/modules/llm_core/llm_core/utils/predict_and_parse.py
new file mode 100644
index 00000000..d73748bd
--- /dev/null
+++ b/modules/llm_core/llm_core/utils/predict_and_parse.py
@@ -0,0 +1,48 @@
+from typing import Optional, Type, TypeVar, List
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.pydantic_v1 import BaseModel, ValidationError
+from langchain_core.runnables import RunnableSequence
+from athena import get_experiment_environment
+
+T = TypeVar("T", bound=BaseModel)
+
+async def predict_and_parse(
+        model: BaseLanguageModel,
+        chat_prompt: ChatPromptTemplate,
+        prompt_input: dict,
+        pydantic_object: Type[T],
+        tags: Optional[List[str]]
+    ) -> Optional[T]:
+    """Predicts an LLM completion using the model and parses the output using the provided Pydantic model
+
+    Args:
+        model (BaseLanguageModel): The model to predict with
+        chat_prompt (ChatPromptTemplate): Prompt to use
+        prompt_input (dict): Input parameters to use for the prompt
+        pydantic_object (Type[T]): Pydantic model to parse the output
+        tags (Optional[List[str]]): List of tags to tag the prediction with
+
+    Returns:
+        Optional[T]: Parsed output, or None if it could not be parsed
+    """
+    experiment = get_experiment_environment()
+
+    tags = tags or []
+    if experiment.experiment_id is not None:
+        tags.append(f"experiment-{experiment.experiment_id}")
+    if experiment.module_configuration_id is not None:
+        tags.append(f"module-configuration-{experiment.module_configuration_id}")
+    if experiment.run_id is not None:
+        tags.append(f"run-{experiment.run_id}")
+
+    structured_output_llm = model.with_structured_output(pydantic_object, method="json_mode")
+    chain = RunnableSequence(
+        chat_prompt,
+        structured_output_llm
+    )
+
+    try:
+        return await chain.ainvoke(prompt_input, config={"tags": tags})
+    except ValidationError as e:
+        raise ValueError(f"Could not parse output: {e}") from e
\ No newline at end of file
diff --git a/modules/modeling/module_modeling_llm/module_modeling_llm/core/filter_feedback.py b/modules/modeling/module_modeling_llm/module_modeling_llm/core/filter_feedback.py
index 0736efc1..efe2b187 100644
--- a/modules/modeling/module_modeling_llm/module_modeling_llm/core/filter_feedback.py
+++ b/modules/modeling/module_modeling_llm/module_modeling_llm/core/filter_feedback.py
@@ -3,7 +3,7 @@
 from athena import emit_meta
 from module_modeling_llm.config import BasicApproachConfig
-from llm_core.utils.llm_utils import predict_and_parse
+from llm_core.utils.predict_and_parse import predict_and_parse
 from module_modeling_llm.models.assessment_model import AssessmentModel
 from module_modeling_llm.models.exercise_model import ExerciseModel
 from module_modeling_llm.prompts.filter_feedback_prompt import FilterFeedbackInputs
 
diff --git a/modules/modeling/module_modeling_llm/module_modeling_llm/core/generate_suggestions.py b/modules/modeling/module_modeling_llm/module_modeling_llm/core/generate_suggestions.py
index 346e871a..6db69e7d 100644
--- a/modules/modeling/module_modeling_llm/module_modeling_llm/core/generate_suggestions.py
+++ b/modules/modeling/module_modeling_llm/module_modeling_llm/core/generate_suggestions.py
@@ -6,7 +6,7 @@
 from module_modeling_llm.config import BasicApproachConfig
 from module_modeling_llm.models.assessment_model import AssessmentModel
 from module_modeling_llm.prompts.apollon_format_description import apollon_format_description
-from llm_core.utils.llm_utils import predict_and_parse
+from llm_core.utils.predict_and_parse import predict_and_parse
 from module_modeling_llm.prompts.graded_feedback_prompt import GradedFeedbackInputs
 from module_modeling_llm.models.exercise_model import ExerciseModel
 
diff --git a/modules/modeling/module_modeling_llm/module_modeling_llm/core/get_structured_grading_instructions.py b/modules/modeling/module_modeling_llm/module_modeling_llm/core/get_structured_grading_instructions.py
index 1837bd31..ae84e0dd 100644
--- a/modules/modeling/module_modeling_llm/module_modeling_llm/core/get_structured_grading_instructions.py
+++ b/modules/modeling/module_modeling_llm/module_modeling_llm/core/get_structured_grading_instructions.py
@@ -4,7 +4,7 @@
 from langchain_core.prompts import ChatPromptTemplate
 
 from athena.schemas.grading_criterion import GradingCriterion, StructuredGradingCriterion
-from llm_core.utils.llm_utils import predict_and_parse
+from llm_core.utils.predict_and_parse import predict_and_parse
 from module_modeling_llm.config import BasicApproachConfig
 from module_modeling_llm.models.exercise_model import ExerciseModel
 from module_modeling_llm.prompts.structured_grading_instructions_prompt import StructuredGradingInstructionsInputs
diff --git a/modules/programming/module_programming_llm/module_programming_llm/generate_graded_suggestions_by_file.py b/modules/programming/module_programming_llm/module_programming_llm/generate_graded_suggestions_by_file.py
index 1d9d95f2..4a29b8f6 100644
--- a/modules/programming/module_programming_llm/module_programming_llm/generate_graded_suggestions_by_file.py
+++ b/modules/programming/module_programming_llm/module_programming_llm/generate_graded_suggestions_by_file.py
@@ -17,8 +17,9 @@
     check_prompt_length_and_omit_features_if_necessary,
     get_chat_prompt_with_formatting_instructions,
     num_tokens_from_string,
-    predict_and_parse,
 )
+from llm_core.utils.predict_and_parse import predict_and_parse
+
 from module_programming_llm.helpers.utils import (
     get_diff,
     load_files_from_repo,
diff --git a/modules/programming/module_programming_llm/module_programming_llm/generate_non_graded_suggestions_by_file.py b/modules/programming/module_programming_llm/module_programming_llm/generate_non_graded_suggestions_by_file.py
index f61401b8..de654a42 100644
--- a/modules/programming/module_programming_llm/module_programming_llm/generate_non_graded_suggestions_by_file.py
+++ b/modules/programming/module_programming_llm/module_programming_llm/generate_non_graded_suggestions_by_file.py
@@ -15,8 +15,9 @@
     check_prompt_length_and_omit_features_if_necessary,
     get_chat_prompt_with_formatting_instructions,
     num_tokens_from_string,
-    predict_and_parse,
 )
+from llm_core.utils.predict_and_parse import predict_and_parse
+
 from module_programming_llm.helpers.utils import (
     get_diff,
     load_files_from_repo,
diff --git a/modules/programming/module_programming_llm/module_programming_llm/generate_summary_by_file.py b/modules/programming/module_programming_llm/module_programming_llm/generate_summary_by_file.py
index e5e09613..e5ac6aad 100644
--- a/modules/programming/module_programming_llm/module_programming_llm/generate_summary_by_file.py
+++ b/modules/programming/module_programming_llm/module_programming_llm/generate_summary_by_file.py
@@ -12,8 +12,9 @@
 from llm_core.utils.llm_utils import (
     get_chat_prompt_with_formatting_instructions,
     num_tokens_from_prompt,
-    predict_and_parse,
 )
+from llm_core.utils.predict_and_parse import predict_and_parse
+
 from module_programming_llm.helpers.utils import (
     get_diff,
     load_files_from_repo,
diff --git a/modules/programming/module_programming_llm/module_programming_llm/split_grading_instructions_by_file.py b/modules/programming/module_programming_llm/module_programming_llm/split_grading_instructions_by_file.py
index 4e5d7952..08c08924 100644
--- a/modules/programming/module_programming_llm/module_programming_llm/split_grading_instructions_by_file.py
+++ b/modules/programming/module_programming_llm/module_programming_llm/split_grading_instructions_by_file.py
@@ -12,8 +12,9 @@
     get_chat_prompt_with_formatting_instructions,
     num_tokens_from_string,
     num_tokens_from_prompt,
-    predict_and_parse
 )
+from llm_core.utils.predict_and_parse import predict_and_parse
+
 from module_programming_llm.helpers.utils import format_grading_instructions, get_diff
 
 
diff --git a/modules/programming/module_programming_llm/module_programming_llm/split_problem_statement_by_file.py b/modules/programming/module_programming_llm/module_programming_llm/split_problem_statement_by_file.py
index 3ef7bd08..aecf516a 100644
--- a/modules/programming/module_programming_llm/module_programming_llm/split_problem_statement_by_file.py
+++ b/modules/programming/module_programming_llm/module_programming_llm/split_problem_statement_by_file.py
@@ -12,8 +12,9 @@
     get_chat_prompt_with_formatting_instructions,
     num_tokens_from_string,
     num_tokens_from_prompt,
-    predict_and_parse
 )
+from llm_core.utils.predict_and_parse import predict_and_parse
+
 from module_programming_llm.helpers.utils import get_diff
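
Below is a minimal usage sketch of the relocated helper, showing the new import path for callers. It is illustrative only: the `FeedbackDraft` schema, model name, and prompt text are hypothetical stand-ins, and `langchain_openai.ChatOpenAI` is assumed as the provider wrapper because `with_structured_output(..., method="json_mode")` needs a chat model that implements structured output. Only the `llm_core.utils.predict_and_parse` import path and the function signature come from this patch.

    import asyncio

    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.pydantic_v1 import BaseModel
    from langchain_openai import ChatOpenAI  # assumed provider wrapper

    from llm_core.utils.predict_and_parse import predict_and_parse


    class FeedbackDraft(BaseModel):
        # Hypothetical schema for illustration; not part of this patch.
        title: str
        description: str


    async def main() -> None:
        model = ChatOpenAI(model="gpt-4o-mini")  # assumed model name
        # json_mode requires the prompt itself to mention JSON.
        chat_prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a grading assistant. Respond in JSON."),
            ("human", "Draft one feedback item for this submission: {submission}"),
        ])
        # predict_and_parse merges these tags with experiment/module/run tags.
        result = await predict_and_parse(
            model=model,
            chat_prompt=chat_prompt,
            prompt_input={"submission": "class Foo: pass"},
            pydantic_object=FeedbackDraft,
            tags=["demo-call"],
        )
        if result is not None:
            print(result.title, "-", result.description)

    asyncio.run(main())

Note that the new module takes its langchain types from `langchain_core` (`BaseLanguageModel`, `ChatPromptTemplate`, `pydantic_v1`) rather than the legacy `langchain` paths still used in `llm_utils.py`, so callers of `predict_and_parse` no longer depend on the legacy import surface.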