|
29 | 29 | - occupancy-limited |
30 | 30 | - latency-bound |
31 | 31 |
|
32 | | -Metric definitions are in metric_schema.py and rendering logic is in section_renderers.py. |
| 32 | +Metric definitions are in metric_schema.py. |
33 | 33 | """ |
34 | 34 |
|
35 | | -from typing import Any, Dict, Optional, Tuple |
| 35 | +from typing import Any, Callable, Dict, List, Optional, Tuple |
36 | 36 |
|
37 | | -from .section_renderers import ( |
38 | | - create_metric_getter, |
39 | | - render_gpu_specs, |
40 | | - render_kernel_code, |
41 | | - render_ncu_metrics, |
42 | | - render_problem_description, |
43 | | - render_task_instructions, |
44 | | -) |
| 37 | +from .metric_schema import GPU_MEMORY_FIELDS, GPU_SPEC_FIELDS, NCU_METRIC_SECTIONS |
| 38 | + |
| 39 | + |
| 40 | + |
| 41 | +# ============================================================================= |
| 42 | +# Section Renderers |
| 43 | +# ============================================================================= |
| 44 | + |
| 45 | + |
def render_problem_description(problem_description: str) -> List[str]:
    """Build the "Problem Description" section as a list of output lines.

    Args:
        problem_description: Text placed under the heading, unmodified.

    Returns:
        The section heading, one blank separator line, then the description.
    """
    heading = "## Problem Description"
    return [heading, "", problem_description]
| 49 | + |
| 50 | + |
def render_kernel_code(kernel_code: str, language: str = "python") -> List[str]:
    """Build the "Current Kernel Code" section as a fenced Markdown code block.

    Args:
        kernel_code: Source text placed verbatim inside the fence.
        language: Info string appended to the opening fence.

    Returns:
        A leading blank line, the heading, a blank line, the opening fence,
        the code, and the closing fence.
    """
    # A fenced block ends at the first run of backticks at least as long as
    # the opening fence, so code that itself contains ``` (for example inside
    # a docstring) would terminate the block early. Grow the fence until it
    # no longer occurs anywhere in the code.
    fence = "```"
    while fence in kernel_code:
        fence += "`"
    return ["", "## Current Kernel Code", "", f"{fence}{language}", kernel_code, fence]
| 54 | + |
| 55 | + |
def render_gpu_specs(gpu_specs: Dict[str, Any]) -> List[str]:
    """Build the "GPU Hardware Specifications" section as a list of lines.

    One bullet is emitted per entry of GPU_SPEC_FIELDS, followed by one
    bullet per entry of GPU_MEMORY_FIELDS.

    Args:
        gpu_specs: Mapping looked up with the keys named by the field tables.
            A missing value key renders as "N/A"; a missing memory-type key
            renders as an empty string.

    Returns:
        The section header lines followed by the bullets.
    """
    section = ["", "## GPU Hardware Specifications", ""]

    # Scalar specs: "<label>: <value><unit>".
    section.extend(
        f"- **{name}:** {gpu_specs.get(field, 'N/A')}{suffix}"
        for name, field, suffix in GPU_SPEC_FIELDS
    )

    # Memory specs carry a size plus a separate type string after a space.
    for name, size_field, type_field, suffix in GPU_MEMORY_FIELDS:
        size = gpu_specs.get(size_field, "N/A")
        kind = gpu_specs.get(type_field, "")
        section.append(f"- **{name}:** {size}{suffix} {kind}")

    return section
| 70 | + |
| 71 | + |
def render_ncu_metrics(
    ncu_metrics: Dict[str, Any],
    get_metric_fn: Callable[[str, str], str],
) -> List[str]:
    """Build the "NCU Profiling Metrics" section as a list of lines.

    Each entry of NCU_METRIC_SECTIONS becomes a "###" sub-heading followed by
    one bullet per metric in that section.

    Args:
        ncu_metrics: Accepted for interface compatibility; this function does
            not read it. Every value is obtained through ``get_metric_fn``.
        get_metric_fn: Called as ``get_metric_fn(key, "N/A")`` for each metric
            key; its return value is placed in the bullet before the unit.

    Returns:
        The section header followed by all sub-sections.
    """
    out = ["", "## NCU Profiling Metrics"]

    for title, entries in NCU_METRIC_SECTIONS.items():
        out += ["", f"### {title}"]
        out += [
            f"- **{name}:** {get_metric_fn(field, 'N/A')}{suffix}"
            for name, field, suffix in entries
        ]

    return out
| 87 | + |
| 88 | + |
def render_task_instructions() -> List[str]:
    """Build the fixed "Your Task" section for dual-bottleneck analysis.

    Returns:
        Ten lines: a leading blank, the heading, and the instruction text
        with blank lines between paragraphs.
    """
    blank = ""
    heading = "## Your Task"
    ask = "Identify exactly TWO distinct bottlenecks from the NCU profiling metrics above:"
    primary = "1. **Bottleneck 1 (Primary)**: The highest-impact performance issue"
    secondary = (
        "2. **Bottleneck 2 (Secondary)**: "
        "A different category issue that also limits performance"
    )
    evidence = (
        "For each bottleneck, cite 3-4 specific metrics that reveal the issue, "
        "and recommend ONE actionable optimization."
    )
    closing = (
        "**Be surgical and metrics-driven.** "
        "Return JSON in the format specified in the system prompt."
    )
    return [
        blank,
        heading,
        blank,
        ask,
        primary,
        secondary,
        blank,
        evidence,
        blank,
        closing,
    ]
| 104 | + |
| 105 | + |
def create_metric_getter(kernel_metrics: Dict[str, Any]) -> Callable[[str, str], str]:
    """Create a lookup function that formats one kernel's metrics as strings.

    Args:
        kernel_metrics: Mapping from metric key to its raw value.

    Returns:
        A function ``get_metric(key, default="N/A")`` that returns the metric
        as text: numbers are formatted with two decimals, anything else goes
        through ``str``. A key that is absent, or present with value ``None``,
        yields the default.
    """

    def get_metric(key: str, default: str = "N/A") -> str:
        val = kernel_metrics.get(key, default)
        # A key stored with None means "no measurement"; show the default
        # rather than the literal text "None".
        if val is None:
            val = default
        # bool is a subclass of int, so it must be excluded explicitly or
        # True/False would be rendered as "1.00"/"0.00".
        if isinstance(val, (int, float)) and not isinstance(val, bool):
            return f"{val:.2f}"
        return str(val)

    return get_metric
| 116 | + |
| 117 | + |
| 118 | +# ============================================================================= |
| 119 | +# Bottleneck Analysis |
| 120 | +# ============================================================================= |
45 | 121 |
|
46 | 122 |
|
47 | 123 | # System prompt for the Judge LLM (Dual-Bottleneck NCU Analysis) |
@@ -228,114 +304,27 @@ def extract_judge_response(response_text: str) -> Optional[Dict[str, Any]]: |
228 | 304 |
|
229 | 305 |
|
def validate_judge_response(analysis: Dict[str, Any]) -> bool:
    """Check that a parsed Judge response has the dual-bottleneck shape.

    The response must be a dict containing both ``bottleneck_1`` and
    ``bottleneck_2``, and each of those entries must pass
    ``_validate_bottleneck_entry``.

    Args:
        analysis: Parsed JSON from the Judge response.

    Returns:
        True if both bottleneck entries are present and valid, False
        otherwise. Anything that is not a dict (including ``None``) is
        reported as invalid rather than raising.
    """
    # Guard first: ``"key" in None`` raises TypeError, and on a string it
    # would silently perform a substring match.
    if not isinstance(analysis, dict):
        return False
    if "bottleneck_1" not in analysis or "bottleneck_2" not in analysis:
        return False
    return _validate_bottleneck_entry(
        analysis["bottleneck_1"]
    ) and _validate_bottleneck_entry(analysis["bottleneck_2"])
# The only category labels accepted for a bottleneck entry.
VALID_CATEGORIES = frozenset(
    {"memory-bound", "compute-bound", "occupancy-limited", "latency-bound"}
)


def _validate_bottleneck_entry(bottleneck: Dict[str, Any]) -> bool:
    """Validate a single bottleneck entry (``bottleneck_1`` or ``bottleneck_2``).

    An entry is valid when it is a dict that has every required field, its
    ``category`` is one of VALID_CATEGORIES, ``priority_metrics`` is a list,
    and each of ``root_cause``, ``suggestion`` and ``expected_improvement``
    is a string of at least 5 characters after stripping whitespace.

    Args:
        bottleneck: The entry to validate.

    Returns:
        True if valid, False otherwise. Malformed input never raises.
    """
    # Parsed model output may put any JSON value here (string, list, null).
    if not isinstance(bottleneck, dict):
        return False

    required = (
        "category",
        "root_cause",
        "suggestion",
        "priority_metrics",
        "expected_improvement",
    )
    if any(field not in bottleneck for field in required):
        return False

    category = bottleneck["category"]
    # Type-check before the membership test: an unhashable category (e.g. a
    # list) would make ``in VALID_CATEGORIES`` raise TypeError.
    if not isinstance(category, str) or category not in VALID_CATEGORIES:
        return False

    if not isinstance(bottleneck["priority_metrics"], list):
        return False

    for field in ("root_cause", "suggestion", "expected_improvement"):
        text = bottleneck[field]
        if not isinstance(text, str) or len(text.strip()) < 5:
            return False

    return True
337 | | - |
338 | | - |
339 | | -if __name__ == "__main__": |
340 | | - print("Judge Prompts Module") |
341 | | - print("=" * 60) |
0 commit comments