Skip to content

Commit 9e27818

Browse files
Kaiming Cheng (kaiming-cheng)
authored and committed
fix diff issue
1 parent 10b0f5d commit 9e27818

File tree

1 file changed

+100
-111
lines changed

1 file changed

+100
-111
lines changed

kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py

Lines changed: 100 additions & 111 deletions
Original file line number | Diff line number | Diff line change
@@ -29,19 +29,95 @@
2929
- occupancy-limited
3030
- latency-bound
3131
32-
Metric definitions are in metric_schema.py and rendering logic is in section_renderers.py.
32+
Metric definitions are in metric_schema.py.
3333
"""
3434

35-
from typing import Any, Dict, Optional, Tuple
35+
from typing import Any, Callable, Dict, List, Optional, Tuple
3636

37-
from .section_renderers import (
38-
create_metric_getter,
39-
render_gpu_specs,
40-
render_kernel_code,
41-
render_ncu_metrics,
42-
render_problem_description,
43-
render_task_instructions,
44-
)
37+
from .metric_schema import GPU_MEMORY_FIELDS, GPU_SPEC_FIELDS, NCU_METRIC_SECTIONS
38+
39+
40+
41+
# =============================================================================
42+
# Section Renderers
43+
# =============================================================================
44+
45+
46+
def render_problem_description(problem_description: str) -> List[str]:
    """Render the problem description section.

    Args:
        problem_description: Free-form text describing the problem.

    Returns:
        Three prompt lines: the section heading, a blank line, and the
        description text unchanged.
    """
    return ["## Problem Description", "", problem_description]
49+
50+
51+
def render_kernel_code(kernel_code: str, language: str = "python") -> List[str]:
    """Render the kernel code section as a fenced Markdown code block.

    Args:
        kernel_code: Source text of the kernel, inserted verbatim.
        language: Info string placed after the opening fence.

    Returns:
        Prompt lines: a blank line, the section heading, a blank line, the
        opening fence, the code, and the closing fence.
    """
    # A fixed three-backtick fence would be closed early by any ``` inside
    # the kernel source (e.g. in a docstring), so lengthen the fence until
    # it no longer occurs in the code.
    fence = "```"
    while fence in kernel_code:
        fence += "`"
    return ["", "## Current Kernel Code", "", f"{fence}{language}", kernel_code, fence]
54+
55+
56+
def render_gpu_specs(gpu_specs: Dict[str, Any]) -> List[str]:
    """Render the GPU hardware specifications section.

    Iterates ``GPU_SPEC_FIELDS`` as ``(label, key, unit)`` entries and
    ``GPU_MEMORY_FIELDS`` as ``(label, size_key, type_key, size_unit)``
    entries, emitting one Markdown bullet per entry.

    Args:
        gpu_specs: Mapping from spec key to value. Keys that are absent or
            map to ``None`` are rendered as ``N/A``.

    Returns:
        Prompt lines starting with a blank line, the section heading and
        another blank line, followed by the bullets.
    """
    lines = ["", "## GPU Hardware Specifications", ""]

    for label, key, unit in GPU_SPEC_FIELDS:
        value = gpu_specs.get(key)
        # Do not glue a unit onto the placeholder ("N/A" rather than "N/A<unit>").
        rendered = "N/A" if value is None else f"{value}{unit}"
        lines.append(f"- **{label}:** {rendered}")

    for label, size_key, type_key, size_unit in GPU_MEMORY_FIELDS:
        size_value = gpu_specs.get(size_key)
        size_text = "N/A" if size_value is None else f"{size_value}{size_unit}"
        type_text = gpu_specs.get(type_key) or ""
        # rstrip drops the dangling separator when no memory type is known.
        lines.append(f"- **{label}:** {size_text} {type_text}".rstrip())

    return lines
70+
71+
72+
def render_ncu_metrics(
    ncu_metrics: Dict[str, Any],
    get_metric_fn: Callable[[str, str], str],
) -> List[str]:
    """Render the NCU profiling metrics section.

    Walks ``NCU_METRIC_SECTIONS`` (section name -> iterable of
    ``(label, key, unit)`` entries) and emits a ``###`` sub-heading per
    section followed by one Markdown bullet per metric.

    Args:
        ncu_metrics: Not read by this function; values are obtained only
            through ``get_metric_fn``. Kept for interface compatibility.
        get_metric_fn: Callable taking ``(key, default)`` and returning the
            already-formatted metric value.

    Returns:
        Prompt lines starting with a blank line and the section heading.
    """
    placeholder = "N/A"
    lines = ["", "## NCU Profiling Metrics"]

    for section_name, metrics in NCU_METRIC_SECTIONS.items():
        lines.append("")
        lines.append(f"### {section_name}")
        for label, key, unit in metrics:
            value = get_metric_fn(key, placeholder)
            # A unit after the placeholder (e.g. "N/A%") is misleading; omit it.
            suffix = "" if value == placeholder else unit
            lines.append(f"- **{label}:** {value}{suffix}")

    return lines
87+
88+
89+
def render_task_instructions() -> List[str]:
    """Render the task instructions section for dual-bottleneck analysis.

    Returns:
        A fixed list of ten prompt lines asking for exactly two bottlenecks,
        each backed by cited metrics and one recommended optimization.
    """
    return [
        "",
        "## Your Task",
        "",
        "Identify exactly TWO distinct bottlenecks from the NCU profiling metrics above:",
        "1. **Bottleneck 1 (Primary)**: The highest-impact performance issue",
        "2. **Bottleneck 2 (Secondary)**: A different category issue that also limits performance",
        "",
        # Parenthesised on purpose: these two literals form ONE list element,
        # not two elements with a forgotten comma.
        (
            "For each bottleneck, cite 3-4 specific metrics that reveal the issue, "
            "and recommend ONE actionable optimization."
        ),
        "",
        "**Be surgical and metrics-driven.** Return JSON in the format specified in the system prompt.",
    ]
104+
105+
106+
def create_metric_getter(kernel_metrics: Dict[str, Any]) -> Callable[[str, str], str]:
    """Create a metric getter bound to one kernel's metrics.

    Args:
        kernel_metrics: Mapping from metric key to raw value.

    Returns:
        A function ``get_metric(key, default="N/A")`` that returns the metric
        as a string: ints and floats are formatted with two decimals, other
        values are passed through ``str()``, and missing or ``None`` values
        yield ``default``.
    """

    def get_metric(key: str, default: str = "N/A") -> str:
        val = kernel_metrics.get(key)
        # A key that is present but null should read as "missing", not "None".
        if val is None:
            return str(default)
        # bool is a subclass of int; without this guard True renders as "1.00".
        if isinstance(val, bool):
            return str(val)
        if isinstance(val, (int, float)):
            return f"{val:.2f}"
        return str(val)

    return get_metric
116+
117+
118+
# =============================================================================
119+
# Bottleneck Analysis
120+
# =============================================================================
45121

46122

47123
# System prompt for the Judge LLM (Dual-Bottleneck NCU Analysis)
@@ -228,114 +304,27 @@ def extract_judge_response(response_text: str) -> Optional[Dict[str, Any]]:
228304

229305

230306
def validate_judge_response(analysis: Dict[str, Any]) -> bool:
    """Validate that a Judge response has the required dual-bottleneck fields.

    The response must be a dict containing both ``bottleneck_1`` and
    ``bottleneck_2``, and each of those entries must pass
    ``_validate_bottleneck_entry``. A response carrying only a single
    top-level ``bottleneck`` key is rejected.

    Args:
        analysis: Parsed JSON from the Judge response.

    Returns:
        True if the response is valid, False otherwise (including when
        ``analysis`` is ``None`` or not a dict).
    """
    # Upstream parsing may yield None or a non-object JSON value; a validator
    # should report that as invalid rather than raise.
    if not isinstance(analysis, dict):
        return False
    if "bottleneck_1" not in analysis or "bottleneck_2" not in analysis:
        return False
    return _validate_bottleneck_entry(
        analysis["bottleneck_1"]
    ) and _validate_bottleneck_entry(analysis["bottleneck_2"])
315+
VALID_CATEGORIES = frozenset(["memory-bound", "compute-bound", "occupancy-limited", "latency-bound"])
291316

292317

293318
def _validate_bottleneck_entry(bottleneck: Dict[str, Any]) -> bool:
294-
"""
295-
Validate a single bottleneck entry (bottleneck_1 or bottleneck_2).
296-
297-
Both bottlenecks use NCU hardware profiling categories:
298-
memory-bound, compute-bound, occupancy-limited, latency-bound
299-
300-
Args:
301-
bottleneck: Bottleneck dictionary to validate
302-
303-
Returns:
304-
True if valid, False otherwise
305-
"""
306-
required_fields = [
307-
"category",
308-
"root_cause",
309-
"suggestion",
310-
"priority_metrics",
311-
"expected_improvement",
312-
]
313-
314-
for field in required_fields:
315-
if field not in bottleneck:
316-
return False
317-
318-
# NCU hardware profiling categories only
319-
valid_categories = [
320-
"memory-bound",
321-
"compute-bound",
322-
"occupancy-limited",
323-
"latency-bound",
324-
]
325-
326-
if bottleneck["category"] not in valid_categories:
319+
"""Validate a single bottleneck entry."""
320+
required = ["category", "root_cause", "suggestion", "priority_metrics", "expected_improvement"]
321+
if not all(f in bottleneck for f in required):
322+
return False
323+
if bottleneck["category"] not in VALID_CATEGORIES:
327324
return False
328-
329325
if not isinstance(bottleneck["priority_metrics"], list):
330326
return False
331-
332-
for field in ["root_cause", "suggestion", "expected_improvement"]:
333-
if not isinstance(bottleneck[field], str) or len(bottleneck[field].strip()) < 5:
327+
for f in ["root_cause", "suggestion", "expected_improvement"]:
328+
if not isinstance(bottleneck[f], str) or len(bottleneck[f].strip()) < 5:
334329
return False
335-
336330
return True
337-
338-
339-
if __name__ == "__main__":
340-
print("Judge Prompts Module")
341-
print("=" * 60)

0 commit comments

Comments
 (0)