3232Metric definitions are in metric_schema.py.
3333"""
3434
35- from typing import Any , Callable , Dict , List , Optional , Tuple
35+ from typing import Any , Callable
3636
3737from .metric_schema import GPU_MEMORY_FIELDS , GPU_SPEC_FIELDS , NCU_METRIC_SECTIONS
3838
4242# =============================================================================
4343
4444
def render_problem_description(problem_description: str) -> list[str]:
    """Render the markdown lines for the problem-description section.

    Args:
        problem_description: Free-form text describing the kernel's task.

    Returns:
        Markdown lines: a section header, a blank spacer, then the text.
    """
    header = "## Problem Description"
    return [header, "", problem_description]
4949
def render_kernel_code(kernel_code: str, language: str = "python") -> list[str]:
    """Render the markdown lines for the current-kernel-code section.

    Args:
        kernel_code: Source code of the kernel under review.
        language: Fence language tag used for syntax highlighting.

    Returns:
        Markdown lines: blank spacer, section header, blank spacer, then
        the code wrapped in a fenced block tagged with ``language``.
    """
    opening_fence = f"```{language}"
    lines = ["", "## Current Kernel Code", ""]
    lines.extend([opening_fence, kernel_code, "```"])
    return lines
5454
55- def render_gpu_specs (gpu_specs : Dict [str , Any ]) -> List [str ]:
55+ def render_gpu_specs (gpu_specs : dict [str , Any ]) -> list [str ]:
5656 """Render the GPU hardware specifications section."""
5757 lines = ["" , "## GPU Hardware Specifications" , "" ]
5858
5959 for label , key , unit in GPU_SPEC_FIELDS :
6060 value = gpu_specs .get (key , "N/A" )
61- lines .append (f"- **{ label } :** { value } { unit } " )
61+ if value == "N/A" :
62+ lines .append (f"- **{ label } :** N/A" )
63+ else :
64+ lines .append (f"- **{ label } :** { value } { unit } " )
6265
6366 for label , size_key , type_key , size_unit in GPU_MEMORY_FIELDS :
6467 size_value = gpu_specs .get (size_key , "N/A" )
@@ -69,9 +72,9 @@ def render_gpu_specs(gpu_specs: Dict[str, Any]) -> List[str]:
6972
7073
7174def render_ncu_metrics (
72- ncu_metrics : Dict [str , Any ],
75+ ncu_metrics : dict [str , Any ],
7376 get_metric_fn : Callable [[str , str ], str ],
74- ) -> List [str ]:
77+ ) -> list [str ]:
7578 """Render the NCU profiling metrics section."""
7679 lines = ["" , "## NCU Profiling Metrics" ]
7780
@@ -85,7 +88,7 @@ def render_ncu_metrics(
8588 return lines
8689
8790
88- def render_task_instructions () -> List [str ]:
91+ def render_task_instructions () -> list [str ]:
8992 """Render the task instructions section for dual-bottleneck analysis."""
9093 return [
9194 "" ,
@@ -102,7 +105,7 @@ def render_task_instructions() -> List[str]:
102105 ]
103106
104107
105- def create_metric_getter (kernel_metrics : Dict [str , Any ]) -> Callable [[str , str ], str ]:
108+ def create_metric_getter (kernel_metrics : dict [str , Any ]) -> Callable [[str , str ], str ]:
106109 """Create a metric getter function for a specific kernel's metrics."""
107110
108111 def get_metric (key : str , default : str = "N/A" ) -> str :
@@ -172,9 +175,9 @@ def get_metric(key: str, default: str = "N/A") -> str:
172175def build_judge_optimization_prompt (
173176 kernel_code : str ,
174177 problem_description : str ,
175- ncu_metrics : Dict [str , Any ],
176- gpu_specs : Dict [str , Any ],
177- ) -> Tuple [str , str ]:
178+ ncu_metrics : dict [str , Any ],
179+ gpu_specs : dict [str , Any ],
180+ ) -> tuple [str , str ]:
178181 """
179182 Build system and user prompts for Judge to analyze bottleneck.
180183
@@ -209,7 +212,7 @@ def build_judge_optimization_prompt(
209212 raise ValueError ("NCU metrics are empty - cannot build judge prompt" )
210213
211214 # Extract first kernel's metrics for the metric getter
212- first_kernel = list (ncu_metrics .values ())[0 ] if ncu_metrics else {}
215+ first_kernel = list (ncu_metrics .values ())[0 ]
213216 get_metric_fn = create_metric_getter (first_kernel )
214217
215218 # Build user prompt using modular section renderers
@@ -226,7 +229,7 @@ def build_judge_optimization_prompt(
226229 return JUDGE_SYSTEM_PROMPT , user_prompt
227230
228231
229- def extract_judge_response (response_text : str ) -> Optional [ Dict [ str , Any ]] :
232+ def extract_judge_response (response_text : str ) -> dict [ str , Any ] | None :
230233 """
231234 Extract and parse JSON from Judge LLM response.
232235
@@ -302,7 +305,7 @@ def extract_judge_response(response_text: str) -> Optional[Dict[str, Any]]:
302305 return None
303306
304307
def validate_judge_response(analysis: dict[str, Any]) -> bool:
    """Validate that Judge response contains required dual-bottleneck fields.

    Args:
        analysis: Parsed Judge JSON, expected to carry ``bottleneck_1`` and
            ``bottleneck_2`` entries.

    Returns:
        True only when both keys are present and each entry passes
        ``_validate_bottleneck_entry``.
    """
    # Guard clause: both top-level keys must exist before inspecting entries.
    has_both = "bottleneck_1" in analysis and "bottleneck_2" in analysis
    if not has_both:
        return False
    # Short-circuit: the second entry is only checked when the first passes.
    first_ok = _validate_bottleneck_entry(analysis["bottleneck_1"])
    return first_ok and _validate_bottleneck_entry(analysis["bottleneck_2"])
312315
313316
# Bottleneck categories the Judge may assign. A frozenset keeps this
# module-level constant immutable (no accidental .add/.discard at runtime)
# while membership tests stay O(1), identical to a plain set.
VALID_CATEGORIES = frozenset(
    {"memory-bound", "compute-bound", "occupancy-limited", "latency-bound"}
)
317320
318321
319- def _validate_bottleneck_entry (bottleneck : Dict [str , Any ]) -> bool :
322+ def _validate_bottleneck_entry (bottleneck : dict [str , Any ]) -> bool :
320323 """Validate a single bottleneck entry."""
321324 required = [
322325 "category" ,
0 commit comments