Skip to content

Commit e6cb7a8

Browse files
committedDec 9, 2024
testing
1 parent 99c6d80 commit e6cb7a8

File tree

8 files changed

+634
-4
lines changed

8 files changed

+634
-4
lines changed
 

‎.DS_Store

6 KB
Binary file not shown.

‎src/.DS_Store

6 KB
Binary file not shown.

‎src/functions.py

+179
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
from typing import Any, List, Optional

from loguru import logger
from pydantic import BaseModel, Field
from symai import Symbol
from symai.components import ExceptionWithUsage, LengthConstrainedFunction
6+
7+
8+
class ResultValidator(LengthConstrainedFunction):
    """Validate an LLM-produced result and retry with a remedy prompt on failure.

    Wraps ``LengthConstrainedFunction.forward``: after the initial call the
    result is checked by :meth:`validate`; while errors remain (up to
    ``validation_retry_count`` attempts), the original task is re-issued with
    the faulty output and the error list embedded in a remedy prompt.
    """

    def __init__(
        self,
        validation_retry_count: int = 5,
        *args,
        **kwargs,
    ):
        # No character constraints of our own; length checks come from the base class.
        super().__init__(character_constraints=[], *args, **kwargs)
        self.validation_retry_count = validation_retry_count

    def validate(self, result) -> List[str]:
        """Return a list of validation error messages (empty list == valid).

        NOTE(review): content-type-specific validation criteria previously
        lived here but were disabled; this implementation always reports
        success. Override in a subclass to implement real checks.
        """
        return []

    def forward(self, *args, **kwargs):
        """Run the wrapped function, retrying with remedy prompts until valid.

        Returns:
            A ``(result, usage)`` tuple; token usage from remedy attempts is
            accumulated into the initial call's usage object.

        Raises:
            ExceptionWithUsage: when all retries are exhausted and the result
                still violates the base class's constraints.
        """
        result, usage = super().forward(*args, **kwargs)

        if self.validation_retry_count > 0:
            # Keep the original task so each remedy prompt can restate it.
            original_task = args[0]

            # One distinct seed per retry, so the same failing input does not
            # deterministically reproduce the same failing remedy output.
            remedy_seeds = self.prepare_seeds(self.validation_retry_count, **kwargs)

            # Validate the result, remedying on failure.
            for i in range(self.validation_retry_count):
                validation_errors = self.validate(result)
                if not validation_errors:
                    break

                for violation in validation_errors:
                    logger.info(f"Validation error: {violation}")
                logger.debug(str(result))

                # Re-issue the task together with the faulty output and errors.
                remedy_task = self.wrap_task(
                    original_task, result.model_dump_json(), validation_errors
                )

                # Attempt to remedy the result with a fresh seed.
                kwargs["seed"] = remedy_seeds[i]
                result, remedy_usage = super().forward(remedy_task, *args[1:], **kwargs)

                # Accumulate token usage across remedy attempts.
                usage.prompt_tokens += remedy_usage.prompt_tokens
                usage.completion_tokens += remedy_usage.completion_tokens
                usage.total_tokens += remedy_usage.total_tokens

            validation_errors = self.check_constraints(result)
            # BUG FIX: the original compared ``i == self.validation_retry_count``,
            # a value range() never yields, so this exception was unreachable.
            if i == self.validation_retry_count - 1 and len(validation_errors) > 0:
                raise ExceptionWithUsage(
                    f"Failed to enforce constraints: {' | '.join(validation_errors)}",
                    usage,
                )

        return result, usage

    def wrap_task(self, task: str, result: str, validation_errors: List[str]) -> str:
        """Build a remedy prompt embedding the task, faulty output, and errors."""
        joined_validation_errors = "\n".join(validation_errors)

        # BUG FIX: corrected "origianl" -> "original" in the prompt text.
        remedy_task = f"""
You had the following task:

[Original Task]
{task}

[Original Output]
{result}

However, the output has the following validation errors:

[Validation Errors]
{joined_validation_errors}

[Task]
Follow the original task but fix the validation errors.
"""

        return remedy_task

    @property
    def static_context(self):
        # System-level instruction injected by the symai Function machinery.
        return (
            "You are an agent for validating 'JSON' schemas and fixing errors."
        )
117+
118+
class LLMDataModel(BaseModel):
    """
    A base class for Pydantic models that provides nicely formatted string output,
    suitable for LLM prompts, with support for nested models, lists, and optional
    section headers.
    """

    # Optional section header rendered as "[[HEADER]]" above top-level models.
    # FIX: annotated Optional[str]; the original declared ``str`` with a None
    # default, which only worked because Pydantic v2 skips default validation.
    section_header: Optional[str] = Field(
        default=None, exclude=True, frozen=True
    )

    def format_field(self, key: str, value: Any, indent: int = 0) -> str:
        """
        Formats a single field for output. Handles nested models, lists, and dictionaries.
        """
        indent_str = " " * indent
        if isinstance(value, LLMDataModel):
            # Nested model: recurse with a deeper indent.
            nested_str = value.__str__(indent + 2).strip()
            return f"{indent_str}{key}:\n{nested_str}" if key else nested_str
        elif isinstance(value, list):
            # List of items (handles nested models inside lists).
            formatted_items = "\n".join(
                f"{indent_str} - {self.format_field('', item, indent).strip()}"
                for item in value
            )
            return f"{indent_str}{key}:\n{formatted_items}" if key else formatted_items
        elif isinstance(value, dict):
            # Dictionary of key-value pairs.
            formatted_items = "\n".join(
                f"{indent_str} {k}: {self.format_field('', v, indent + 4).strip()}"
                for k, v in value.items()
            )
            return f"{indent_str}{key}:\n{formatted_items}" if key else formatted_items
        else:
            # Primitive types.
            return f"{indent_str}{key}: {value}" if key else f"{indent_str}{value}"

    def __str__(self, indent: int = 0) -> str:
        """
        Converts the model into a formatted string for LLM prompts.
        Handles indentation for nested models and includes an optional section header.
        """
        indent_str = " " * indent
        # FIX: access model_fields on the class, not the instance — instance
        # access is deprecated as of Pydantic 2.11.
        fields = "\n".join(
            self.format_field(name, getattr(self, name), indent + 2)
            for name, field in type(self).model_fields.items()
            if (
                getattr(self, name, None) is not None
                and not getattr(field, "exclude", False)
                and not name == "section_header"
            )  # Exclude None values and "exclude" fields
        )
        fields += "\n"  # add line break at the end to separate from the next section

        if self.section_header and indent == 0:
            header = f"{indent_str}[[{self.section_header}]]\n"
            return f"{header}{fields}"
        return fields

‎src/hierarchical.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77

88
from symai.components import FileReader, Function, ValidatedFunction
99
from symai.core_ext import bind
10+
from functions import LLMDataModel, ResultValidator
11+
12+
13+
class ChunkSummary(LLMDataModel):
    # Prompt-formatted summary of a single document chunk; string rendering
    # comes from LLMDataModel.
    summary: str  # free-text summary of the chunk
    facts: List[str]  # key facts extracted from the chunk
    # NOTE(review): annotated ``str`` but defaults to None — should be
    # Optional[str]; works today only because Pydantic v2 does not validate
    # defaults. Confirm Optional is importable here before changing.
    type: str = None
    section_header: str = "CHUNK SUMMARY"  # rendered as a [[CHUNK SUMMARY]] header
1018

1119

1220
class Summary(BaseModel):
@@ -43,6 +51,10 @@ def __init__(
4351
self.max_output_tokens = max_output_tokens
4452
self.content_types = content_types
4553
self.seed = seed
54+
self.result_validator = ResultValidator(
55+
data_model=Summary,
56+
validation_retry_count=3,
57+
)
4658

4759
file_content = None
4860
file_name = None
@@ -203,20 +215,36 @@ def summarize_chunks(self, chunks):
203215
chunk_summaries = []
204216
chunk_facts = []
205217

206-
for chunk in chunks:
218+
for i, chunk in enumerate(chunks):
219+
# Use ChunkSummary for individual chunk validation
207220
res, usage = super().forward(
208221
chunk,
209222
preview=False,
210223
response_format={"type": "json_object"},
211224
)
225+
226+
# Validate each chunk using LLMDataModel
227+
chunk_summary = ChunkSummary(
228+
summary=res.summary,
229+
facts=res.facts,
230+
type=self._content_type,
231+
)
232+
self.print_verbose(f"Chunk {i+1} Summary:\n{str(chunk_summary)}")
233+
212234
chunk_summaries.append(res.summary)
213235
chunk_facts.extend(res.facts)
214236

237+
# Create final summary
215238
res = Summary(
216239
summary="\n".join(chunk_summaries),
217240
facts=chunk_facts,
241+
type=self._content_type
218242
)
219-
return res, self.compute_required_tokens(res.summary, count_context=False)
243+
244+
# Validate entire summary using ResultValidator
245+
validated_res, validator_usage = self.result_validator(self.prompt)
246+
247+
return validated_res, self.compute_required_tokens(validated_res.summary, count_context=False)
220248

221249
def calculate_chunk_size(self, total_tokens):
222250
num_prompt_tokens = self.compute_required_tokens("", count_context=True)
@@ -328,8 +356,8 @@ def forward(self) -> Summary:
328356
res = Summary(
329357
summary=data,
330358
facts=facts,
359+
type=asset_type,
331360
)
332-
res.type = asset_type
333361
return res, self.get_usage()
334362
else:
335363
asset_type = self.get_asset_type(self.content)

0 commit comments

Comments
 (0)