
Commit

refactor approaches
Enea_Gore committed Oct 30, 2024
1 parent 95e8492 commit 2483905
Showing 17 changed files with 243 additions and 203 deletions.
41 changes: 29 additions & 12 deletions llm_core/llm_core/utils/predict_and_parse.py
@@ -12,7 +12,8 @@ async def predict_and_parse(
chat_prompt: ChatPromptTemplate,
prompt_input: dict,
pydantic_object: Type[T],
tags: Optional[List[str]]
tags: Optional[List[str]],
use_function_calling: bool = False
) -> Optional[T]:
"""Predicts an LLM completion using the model and parses the output using the provided Pydantic model
@@ -36,14 +37,30 @@
if experiment.run_id is not None:
tags.append(f"run-{experiment.run_id}")

structured_output_llm = model.with_structured_output(pydantic_object)
# chain = RunnableSequence(
# chat_prompt,
# structured_output_llm
# )
chain = chat_prompt | structured_output_llm

try:
return await chain.ainvoke(prompt_input, config={"tags": tags}) # type: ignore #
except ValidationError as e:
raise ValueError(f"Could not parse output: {e}") from e

    if use_function_calling:
structured_output_llm = model.with_structured_output(pydantic_object)
chain = chat_prompt | structured_output_llm

try:
result = await chain.ainvoke(prompt_input, config={"tags": tags})

if isinstance(result, pydantic_object):
return result
else:
raise ValueError("Parsed output does not match the expected Pydantic model.")

except ValidationError as e:
raise ValueError(f"Could not parse output: {e}") from e

else:
        structured_output_llm = model.with_structured_output(pydantic_object, method="json_mode")
chain = RunnableSequence(
chat_prompt,
structured_output_llm
)
try:
return await chain.ainvoke(prompt_input, config={"tags": tags})
except ValidationError as e:
raise ValueError(f"Could not parse output: {e}") from e

2 changes: 1 addition & 1 deletion modules/text/module_text_llm/module_text_llm/__main__.py
@@ -10,7 +10,7 @@
from module_text_llm.config import Configuration
from module_text_llm.evaluation import get_feedback_statistics, get_llm_statistics
from module_text_llm.generate_evaluation import generate_evaluation
from module_text_llm.approaches.approach_controller import generate_suggestions
from module_text_llm.approach_controller import generate_suggestions

@submissions_consumer
def receive_submissions(exercise: Exercise, submissions: List[Submission]):
@@ -9,8 +9,8 @@ class ApproachType(str, Enum):

class ApproachConfig(BaseModel, ABC):
max_input_tokens: int = Field(default=3000, description="Maximum number of tokens in the input prompt.")
model: ModelConfigType = Field(default=DefaultModelConfig()) # type: ignore
type: ApproachType = Field(..., description="The type of approach config")
model: ModelConfigType = Field(default=DefaultModelConfig())
type: str = Field(..., description="The type of approach config")

class Config:
use_enum_values = True
@@ -0,0 +1,16 @@

from typing import List
from athena.text import Exercise, Submission, Feedback
from module_text_llm.basic_approach import BasicApproachConfig
from module_text_llm.chain_of_thought_approach import ChainOfThoughtConfig
from module_text_llm.approach_config import ApproachConfig

from module_text_llm.basic_approach.generate_suggestions import generate_suggestions as generate_suggestions_basic
from module_text_llm.chain_of_thought_approach.generate_suggestions import generate_suggestions as generate_cot_suggestions

async def generate_suggestions(exercise: Exercise, submission: Submission, config: ApproachConfig, debug: bool) -> List[Feedback]:
    if isinstance(config, BasicApproachConfig):
        return await generate_suggestions_basic(exercise, submission, config, debug)
    elif isinstance(config, ChainOfThoughtConfig):
        return await generate_cot_suggestions(exercise, submission, config, debug)
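
A hedged sketch of how this dispatcher might be driven; the exercise and submission objects are supplied by the Athena runtime and are only placeholders here.

# Hypothetical call site (exercise and submission come from Athena, not defined here).
from module_text_llm.approach_controller import generate_suggestions
from module_text_llm.basic_approach import BasicApproachConfig

async def run(exercise, submission):
    config = BasicApproachConfig()  # routed to the basic-approach branch above
    return await generate_suggestions(exercise, submission, config, debug=False)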

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

@@ -0,0 +1,11 @@
from module_text_llm.approach_config import ApproachConfig
from pydantic import Field
from typing import Literal


from module_text_llm.basic_approach.prompt_generate_suggestions import GenerateSuggestionsPrompt

class BasicApproachConfig(ApproachConfig):
type: Literal['basic'] = 'basic'
generate_suggestions_prompt: GenerateSuggestionsPrompt = Field(default=GenerateSuggestionsPrompt())

@@ -1,5 +1,4 @@
from typing import List, Optional, Sequence
from pydantic import BaseModel, Field
from typing import List

from athena import emit_meta
from athena.text import Exercise, Submission, Feedback
@@ -9,29 +8,15 @@
check_prompt_length_and_omit_features_if_necessary,
num_tokens_from_prompt,
)
from athena.text import Exercise, Submission, Feedback
from llm_core.utils.predict_and_parse import predict_and_parse

from module_text_llm.config import BasicApproachConfig
from module_text_llm.helpers.utils import add_sentence_numbers, get_index_range_from_line_range, format_grading_instructions

class FeedbackModel(BaseModel):
title: str = Field(description="Very short title, i.e. feedback category or similar", example="Logic Error")
description: str = Field(description="Feedback description")
line_start: Optional[int] = Field(description="Referenced line number start, or empty if unreferenced")
line_end: Optional[int] = Field(description="Referenced line number end, or empty if unreferenced")
credits: float = Field(0.0, description="Number of points received/deducted")
grading_instruction_id: Optional[int] = Field(
description="ID of the grading instruction that was used to generate this feedback, or empty if no grading instruction was used"
)


class AssessmentModel(BaseModel):
"""Collection of feedbacks making up an assessment"""

feedbacks: List[FeedbackModel] = Field(description="Assessment feedbacks")
from module_text_llm.basic_approach.prompt_generate_suggestions import AssessmentModel

async def generate_suggestions(exercise: Exercise, submission: Submission, config: BasicApproachConfig, debug: bool) -> List[Feedback]:
model = config.model.get_model() # type: ignore[attr-defined]

prompt_input = {
"max_points": exercise.max_points,
"bonus_points": exercise.bonus_points,
@@ -74,7 +59,8 @@ async def generate_suggestions(exercise: Exercise, submission: Submission, config: BasicApproachConfig, debug: bool) -> List[Feedback]:
tags=[
f"exercise-{exercise.id}",
f"submission-{submission.id}",
]
],
use_function_calling=True
)

if debug:
@@ -0,0 +1,65 @@
from typing import List, Optional
from pydantic import BaseModel, Field

system_message = """\
You are an AI tutor for text assessment at a prestigious university.
# Task
Create graded feedback suggestions for a student\'s text submission that a human tutor would accept. \
Meaning, the feedback you provide should be applicable to the submission with little to no modification.
# Style
1. Constructive, 2. Specific, 3. Balanced, 4. Clear and Concise, 5. Actionable, 6. Educational, 7. Contextual
# Problem statement
{problem_statement}
# Example solution
{example_solution}
# Grading instructions
{grading_instructions}
Max points: {max_points}, bonus points: {bonus_points}\
Respond in json.
"""

human_message = """\
Student\'s submission to grade (with sentence numbers <number>: <sentence>):
Respond in json.
\"\"\"
{submission}
\"\"\"\
"""

# Input Prompt
class GenerateSuggestionsPrompt(BaseModel):
"""\
Features available: **{problem_statement}**, **{example_solution}**, **{grading_instructions}**, **{max_points}**, **{bonus_points}**, **{submission}**
_Note: **{problem_statement}**, **{example_solution}**, or **{grading_instructions}** might be omitted if the input is too long._\
"""
system_message: str = Field(default=system_message,
description="Message for priming AI behavior and instructing it what to do.")
human_message: str = Field(default=human_message,
description="Message from a human. The input on which the AI is supposed to act.")
# Output Object
class FeedbackModel(BaseModel):
title: str = Field(description="Very short title, i.e. feedback category or similar", example="Logic Error")
description: str = Field(description="Feedback description")
line_start: Optional[int] = Field(description="Referenced line number start, or empty if unreferenced")
line_end: Optional[int] = Field(description="Referenced line number end, or empty if unreferenced")
credits: float = Field(0.0, description="Number of points received/deducted")
grading_instruction_id: Optional[int] = Field(
description="ID of the grading instruction that was used to generate this feedback, or empty if no grading instruction was used"
)


class AssessmentModel(BaseModel):
"""Collection of feedbacks making up an assessment"""

feedbacks: List[FeedbackModel] = Field(description="Assessment feedbacks")
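
To illustrate how the prompt class and the output models fit together, here is an assumed wiring sketch using LangChain's ChatPromptTemplate; it mirrors what the basic approach's generate_suggestions does, but the names below are illustrative only.

# Assumed sketch: build a chat prompt from the defaults and pair it with
# AssessmentModel as the structured-output target for predict_and_parse.
from langchain_core.prompts import ChatPromptTemplate
from module_text_llm.basic_approach.prompt_generate_suggestions import (
    GenerateSuggestionsPrompt,
    AssessmentModel,
)

prompt = GenerateSuggestionsPrompt()
chat_prompt = ChatPromptTemplate.from_messages(
    [("system", prompt.system_message), ("human", prompt.human_message)]
)
# chat_prompt, a filled prompt_input dict, and AssessmentModel would then be
# handed to predict_and_parse (use_function_calling=True in the basic approach).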

@@ -0,0 +1,15 @@
from pydantic import BaseModel, Field
from typing import Literal
from llm_core.models import ModelConfigType, MiniModelConfig

from module_text_llm.approach_config import ApproachConfig
from module_text_llm.chain_of_thought_approach.prompt_generate_feedback import CoTGenerateSuggestionsPrompt
from module_text_llm.chain_of_thought_approach.prompt_thinking import ThinkingPrompt

class ChainOfThoughtConfig(ApproachConfig):
# Defaults to the cheaper mini 4o model
type: Literal['chain_of_thought'] = 'chain_of_thought'
    model: ModelConfigType = Field(default=MiniModelConfig())  # type: ignore
    thinking_prompt: ThinkingPrompt = Field(default=ThinkingPrompt())
generate_suggestions_prompt: CoTGenerateSuggestionsPrompt = Field(default=CoTGenerateSuggestionsPrompt())

