1
+ from typing import Any , List
2
+ from symai .components import ExceptionWithUsage , LengthConstrainedFunction
3
+ from loguru import logger
4
+ from symai import Symbol
5
+ from pydantic import BaseModel , Field
6
+
7
+
8
class ResultValidator(LengthConstrainedFunction):
    """Function wrapper that validates its result and retries with a remedy prompt.

    After the parent ``forward`` produces a result, :meth:`validate` is asked
    for a list of error messages. While errors are reported (and retries
    remain), a remedy task describing the errors is sent through the parent
    ``forward`` again. Token usage of all attempts is accumulated.
    """

    def __init__(
        self,
        validation_retry_count: int = 5,
        *args,
        **kwargs,
    ):
        """Initialise the validator.

        Args:
            validation_retry_count: Maximum number of remedy attempts made in
                :meth:`forward`. A value ``<= 0`` disables validation.
            *args: Forwarded positionally to the parent constructor.
            **kwargs: Forwarded as keywords to the parent constructor.
        """
        # This class passes an empty constraint list to the parent. Positional
        # unpacking is written first; the resulting call is identical to
        # placing it after the keyword, but reads in binding order.
        super().__init__(*args, character_constraints=[], **kwargs)
        self.validation_retry_count = validation_retry_count

    def validate(self, result) -> List[str]:
        """Return the validation errors found in ``result``.

        The base implementation accepts every result and returns an empty
        list. Override this hook to add real checks; each returned string is
        logged and embedded in the remedy prompt built by :meth:`wrap_task`.

        Args:
            result: The object returned by the parent ``forward``.

        Returns:
            A list of human-readable error messages; empty when valid.
        """
        # NOTE: disabled reference implementation, kept as an example of a
        # content-type specific check.
        #
        # validation_criteria = {
        #     "Interview": "Does this summary identify different speakers and their key discussion points?",
        #     "Keynote": "Does this summary include speaker details, their expertise, and key messages?",
        #     "Scientific Paper": "Does this summary include methodology details and research findings?",
        #     "Report": "Does this summary include specific numerical results and statistics?",
        #     "Book": "Does this summary include character descriptions and their relationships?",
        #     "Presentation Slides": "Does this summary include the core idea and value proposition?"
        # }
        #
        # # Get the validation prompt for this content type
        # if result.type in validation_criteria:
        #     validation_prompt = validation_criteria[result.type]
        # else:
        #     validation_prompt = "Is this a content summary?"
        #
        # validation = Symbol(f"{validation_prompt} Return yes or no.\n{result.summary}").interpret()
        # is_valid = "yes" in validation.lower() or "true" in validation.lower()
        #
        # print(f"Validation: {validation}")
        #
        # # Return a list of validation errors (empty list if valid)
        # return [] if is_valid else [f"Content type '{result.type}' validation failed"]

        return []

    def forward(self, *args, **kwargs):
        """Run the parent ``forward`` and remedy the result until it validates.

        Args:
            *args: Positional arguments for the parent ``forward``. ``args[0]``
                is treated as the original task text and is replaced by the
                remedy task on retries; the remaining positionals are reused.
            **kwargs: Keyword arguments for the parent ``forward``. ``seed`` is
                overwritten on every remedy attempt.

        Returns:
            A ``(result, usage)`` tuple where ``usage`` holds the token counts
            summed over the initial call and all remedy attempts.

        Raises:
            ExceptionWithUsage: If the result still fails :meth:`validate`
                after ``validation_retry_count`` remedy attempts.
        """
        result, usage = super().forward(*args, **kwargs)

        # Validation disabled: hand back the first result untouched.
        if self.validation_retry_count <= 0:
            return result, usage

        # save original task
        original_task = args[0]

        # get list of seeds for remedy (to avoid same remedy for same input)
        remedy_seeds = self.prepare_seeds(self.validation_retry_count, **kwargs)

        for attempt in range(self.validation_retry_count):
            validation_errors = self.validate(result)
            if not validation_errors:
                break

            for violation in validation_errors:
                logger.info(f"Validation error: {violation}")
            logger.debug(str(result))

            # build remedy task
            # NOTE(review): assumes result exposes model_dump_json() (i.e. is
            # a pydantic model) - confirm against the parent's return type.
            remedy_task = self.wrap_task(
                original_task, result.model_dump_json(), validation_errors
            )

            # attempt to remedy the result
            kwargs["seed"] = remedy_seeds[attempt]
            result, remedy_usage = super().forward(remedy_task, *args[1:], **kwargs)

            # update local usage
            usage.prompt_tokens += remedy_usage.prompt_tokens
            usage.completion_tokens += remedy_usage.completion_tokens
            usage.total_tokens += remedy_usage.total_tokens
        else:
            # Every retry was consumed without an early break, so the output of
            # the last remedy has not been validated yet. The previous guard
            # (`i == validation_retry_count`) could never be true inside
            # range(validation_retry_count), which made this failure silent.
            validation_errors = self.validate(result)
            if validation_errors:
                raise ExceptionWithUsage(
                    f"Failed to enforce constraints: {' | '.join(validation_errors)}",
                    usage,
                )

        return result, usage

    def wrap_task(self, task: str, result: str, validation_errors: List[str]) -> str:
        """Build the remedy prompt sent after a failed validation.

        Args:
            task: The original task text.
            result: The rejected output, serialised as a string.
            validation_errors: Messages returned by :meth:`validate`.

        Returns:
            A prompt containing the original task, the rejected output and the
            validation errors (one per line), followed by the fix instruction.
        """
        joined_validation_errors = "\n".join(validation_errors)

        # Template lines are flush-left so that no source indentation leaks
        # into the prompt.
        remedy_task = f"""
You had the following task:

[Original Task]
{task}

[Original Output]
{result}

However, the output has the following validation errors:

[Validation Errors]
{joined_validation_errors}

[Task]
Follow the original task but fix the validation errors.
"""

        return remedy_task

    @property
    def static_context(self):
        """Fixed context string describing the agent's role."""
        return (
            "You are an agent for validating 'JSON' schemas and fixing errors."
        )
117
+
118
class LLMDataModel(BaseModel):
    """
    A base class for Pydantic models that provides nicely formatted string output,
    suitable for LLM prompts, with support for nested models, lists, and optional section headers.
    """

    # Optional section header for top-level models. Excluded from
    # serialisation and frozen after construction.
    # NOTE(review): annotated as ``str`` although the default is ``None``;
    # left unchanged here because widening it would alter the model's schema.
    section_header: str = Field(
        default=None, exclude=True, frozen=True
    )

    def format_field(self, key: str, value: Any, indent: int = 0) -> str:
        """
        Formats a single field for output. Handles nested models, lists, and dictionaries.

        Args:
            key: Field name to print before the value; an empty string
                suppresses the ``key:`` prefix (used for list/dict items).
            value: The field value. ``LLMDataModel`` instances, lists and
                dicts are expanded recursively; anything else is rendered
                with its default string conversion.
            indent: Number of spaces to prefix each emitted line with.

        Returns:
            The formatted text for this field.
        """
        indent_str = " " * indent

        if isinstance(value, LLMDataModel):
            # Nested model: render it two levels deeper, below the key.
            nested_str = value.__str__(indent + 2).strip()
            return f"{indent_str}{key}:\n{nested_str}" if key else nested_str

        if isinstance(value, list):
            # List of items (handle nested models inside lists)
            formatted_items = "\n".join(
                f"{indent_str}- {self.format_field('', item, indent).strip()}"
                for item in value
            )
            return f"{indent_str}{key}:\n{formatted_items}" if key else formatted_items

        if isinstance(value, dict):
            # Dictionary of key-value pairs
            formatted_items = "\n".join(
                f"{indent_str}{k}: {self.format_field('', v, indent + 4).strip()}"
                for k, v in value.items()
            )
            return f"{indent_str}{key}:\n{formatted_items}" if key else formatted_items

        # Primitive types
        return f"{indent_str}{key}: {value}" if key else f"{indent_str}{value}"

    def __str__(self, indent: int = 0) -> str:
        """
        Converts the model into a formatted string for LLM prompts.
        Handles indentation for nested models and includes an optional section header.

        Args:
            indent: Base indentation in spaces. Fields are rendered at
                ``indent + 2``; the header is only emitted when this is ``0``.

        Returns:
            One line (or block) per included field, terminated by a newline,
            preceded by ``[[section_header]]`` for top-level models that set it.
        """
        indent_str = " " * indent

        # Read the field definitions from the class: accessing `model_fields`
        # on an instance is deprecated in recent Pydantic v2 releases.
        field_definitions = type(self).model_fields

        fields = "\n".join(
            self.format_field(name, getattr(self, name), indent + 2)
            for name, field in field_definitions.items()
            if (
                getattr(self, name, None) is not None
                and not getattr(field, "exclude", False)
                and name != "section_header"
            )  # Exclude None values and "exclude" fields
        )
        fields += "\n"  # add line break at the end to separate from the next section

        if self.section_header and indent == 0:
            header = f"{indent_str}[[{self.section_header}]]\n"
            return f"{header}{fields}"
        return fields
0 commit comments