
Commit 2937faa

Author: Kaiming Cheng
optimization prompt

1 parent a4bff8d commit 2937faa

File tree: 2 files changed, +144 −0 lines

triton_kernel_agent/prompt_manager.py
Lines changed: 55 additions & 0 deletions

@@ -15,6 +15,7 @@
 """Prompt Manager for handling Jinja2 templates."""
 
 from pathlib import Path
+from typing import Any
 
 from triton_kernel_agent.platform_config import PlatformConfig, get_platform
 
@@ -88,6 +89,7 @@ def _load_templates(self):
             "test_generation": "test_generation.j2",
             "kernel_generation": "kernel_generation.j2",
             "kernel_refinement": "kernel_refinement.j2",
+            "kernel_optimization": "kernel_optimization.j2",
             "triton_guidelines": "triton_guidelines.j2",
         }
 
@@ -188,6 +190,59 @@ def render_kernel_refinement_prompt(
             kernel_guidance=self.target_platform.kernel_guidance,
         )
 
+    def render_kernel_optimization_prompt(
+        self,
+        kernel_code: str,
+        problem_description: str,
+        bottleneck_analysis: dict[str, Any],
+        bottleneck_id: int = 1,
+        gpu_specs: dict[str, Any] | None = None,
+        pytorch_baseline_ms: float | None = None,
+        error_feedback: str | None = None,
+    ) -> str:
+        """
+        Render the kernel optimization prompt based on bottleneck analysis.
+
+        Args:
+            kernel_code: Current kernel code to optimize
+            problem_description: Problem description
+            bottleneck_analysis: Dual-bottleneck analysis with bottleneck_1 and bottleneck_2
+            bottleneck_id: Which bottleneck to focus on (1 or 2)
+            gpu_specs: GPU specifications dict
+            pytorch_baseline_ms: PyTorch baseline time in ms
+            error_feedback: Error feedback from previous failed attempt
+
+        Returns:
+            Rendered prompt string
+        """
+        template = self.templates["kernel_optimization"]
+
+        # Select bottleneck
+        if bottleneck_id == 2:
+            bottleneck = bottleneck_analysis.get("bottleneck_2", {})
+            bottleneck_label = "Bottleneck 2 (Secondary)"
+        else:
+            bottleneck = bottleneck_analysis.get("bottleneck_1", {})
+            bottleneck_label = "Bottleneck 1 (Primary)"
+
+        # Calculate target time if baseline provided
+        target_ms = None
+        if pytorch_baseline_ms and pytorch_baseline_ms != float("inf"):
+            target_ms = pytorch_baseline_ms * 0.8
+
+        return template.render(
+            kernel_code=kernel_code,
+            problem_description=problem_description,
+            bottleneck=bottleneck,
+            bottleneck_label=bottleneck_label,
+            gpu_specs=gpu_specs,
+            pytorch_baseline_ms=pytorch_baseline_ms
+            if pytorch_baseline_ms != float("inf")
+            else None,
+            target_ms=target_ms,
+            error_feedback=error_feedback,
+        )
+
     def render_triton_guidelines(self) -> str:
         """
         Render the Triton guidelines.
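
For orientation, here is a hedged sketch of how the new method might be called. The diff does not show the manager class name or its constructor, so `PromptManager()` below is an assumption, and all sample values are hypothetical; the dict keys (`bottleneck_1`/`bottleneck_2`, each with `category`, `root_cause`, `suggestion`, `expected_improvement`) mirror the fields the new template reads.

```python
# Hedged sketch -- `PromptManager` and its no-argument constructor are assumed
# (not shown in this diff); all sample values are hypothetical.
from triton_kernel_agent.prompt_manager import PromptManager

# Shape mirrors the "dual-bottleneck" dict the docstring describes; the
# template reads bottleneck.category / root_cause / suggestion /
# expected_improvement from the selected entry.
bottleneck_analysis = {
    "bottleneck_1": {
        "category": "memory_bound",
        "root_cause": "Uncoalesced global memory loads",
        "suggestion": "Load contiguous BLOCK_SIZE chunks per program",
        "expected_improvement": "~1.3x",
    },
    # "bottleneck_2" may be absent: the method falls back to {} via .get(),
    # and the template's `| default(...)` filters cover the missing keys.
}

manager = PromptManager()  # assumed constructor
prompt = manager.render_kernel_optimization_prompt(
    kernel_code="import triton\n...",  # current kernel source (placeholder)
    problem_description="Vector add of two fp32 tensors",
    bottleneck_analysis=bottleneck_analysis,
    bottleneck_id=1,            # focus on the primary bottleneck
    pytorch_baseline_ms=0.42,   # target_ms becomes 0.42 * 0.8 = 0.336
)
```

Note the 0.8 factor: a target of `baseline * 0.8` ms is exactly the template's "at least 1.25x speedup" phrasing, since 1 / 0.8 = 1.25.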
kernel_optimization.j2 (new file; full path not shown in this view)
Lines changed: 89 additions & 0 deletions

{#
Copyright (c) Meta Platforms, Inc. and affiliates.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}

TASK: Optimize the following Triton kernel based on hardware profiling analysis to achieve better performance.

{% if gpu_specs %}
TARGET GPU:
{% if gpu_specs.name %}- GPU: {{ gpu_specs.name }}
{% endif %}
{% if gpu_specs.architecture %}- Architecture: {{ gpu_specs.architecture }}
{% endif %}
{% if gpu_specs.peak_memory_bw_gbps %}- Peak Memory Bandwidth: {{ gpu_specs.peak_memory_bw_gbps }} GB/s
{% endif %}
{% if gpu_specs.peak_fp32_tflops %}- Peak FP32: {{ gpu_specs.peak_fp32_tflops }} TFLOPS
{% endif %}
{% if gpu_specs.peak_fp16_tflops %}- Peak FP16: {{ gpu_specs.peak_fp16_tflops }} TFLOPS
{% endif %}
{% if gpu_specs.peak_bf16_tflops %}- Peak BF16: {{ gpu_specs.peak_bf16_tflops }} TFLOPS
{% endif %}
{% if gpu_specs.sm_count %}- SM Count: {{ gpu_specs.sm_count }}
{% endif %}
{% if gpu_specs.max_threads_per_sm %}- Max Threads per SM: {{ gpu_specs.max_threads_per_sm }}
{% endif %}
{% if gpu_specs.l1_cache_kb %}- L1 Cache per SM: {{ gpu_specs.l1_cache_kb }} KB
{% endif %}
{% if gpu_specs.l2_cache_mb %}- L2 Cache (Total): {{ gpu_specs.l2_cache_mb }} MB
{% endif %}
{% if gpu_specs.memory_gb %}- Memory: {{ gpu_specs.memory_gb }} GB {{ gpu_specs.memory_type | default('') }}
{% endif %}

{% endif %}
PROBLEM DESCRIPTION:
{{ problem_description }}
{% if pytorch_baseline_ms %}
PyTorch Eager baseline: {{ "%.4f"|format(pytorch_baseline_ms) }} ms
{% endif %}

CURRENT KERNEL IMPLEMENTATION:
```python
{{ kernel_code }}
```

OPTIMIZATION STRATEGY ({{ bottleneck_label }}):
The hardware profiling (NCU) analysis identified the following bottleneck:
- Category: {{ bottleneck.category | default('unknown') }}
- Root Cause: {{ bottleneck.root_cause | default('N/A') }}
- Suggested Optimization: {{ bottleneck.suggestion | default('N/A') }}
- Expected Improvement: {{ bottleneck.expected_improvement | default('N/A') }}

{% if error_feedback %}
PREVIOUS ATTEMPT FAILED:
{{ error_feedback }}

{% endif %}
PERFORMANCE TARGET:
{% if target_ms %}
- Achieve at least 1.25x speedup vs PyTorch Eager (target: <= {{ "%.4f"|format(target_ms) }} ms)
{% else %}
- Achieve 20-100% performance improvement over baseline
{% endif %}
- Maintain numerical correctness (atol=1e-4 or rtol=1e-4)
- Preserve public API (same inputs/outputs, shapes, dtypes)

CRITICAL REQUIREMENTS:
1. Apply the optimization strategy described above to address the identified bottleneck
2. The implementation must be a complete, valid Python file
3. The main function must be named 'kernel_function' that wraps the actual Triton kernel
4. Focus on the specific optimization while maintaining correctness
5. Keep the wrapper free of PyTorch compute primitives

OUTPUT FORMAT:
1. Output complete optimized kernel code in ```python blocks
2. Include only: imports, Triton kernel (@triton.jit), wrapper function (kernel_function)
3. No testing code, benchmarks, or explanatory comments

Generate the complete optimized kernel implementation:
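
Since the new file is plain Jinja2, it can also be previewed standalone, outside the manager. A minimal sketch, assuming the template is saved under a `prompts/` directory (the exact path is not shown in this commit) and that `jinja2` is installed; all rendered values are hypothetical.

```python
# Hedged sketch: render kernel_optimization.j2 directly with Jinja2 to
# preview the prompt text. The template directory is an assumption.
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("triton_kernel_agent/prompts"))  # path assumed
template = env.get_template("kernel_optimization.j2")

text = template.render(
    kernel_code="import triton\n...",         # placeholder kernel source
    problem_description="Row-wise softmax",   # hypothetical problem
    bottleneck={"category": "memory_bound"},  # missing keys fall back via | default(...)
    bottleneck_label="Bottleneck 1 (Primary)",
    gpu_specs=None,        # the entire TARGET GPU section is skipped
    pytorch_baseline_ms=0.5,
    target_ms=0.4,         # 0.5 ms * 0.8, i.e. the 1.25x speedup target
    error_feedback=None,   # PREVIOUS ATTEMPT FAILED block omitted
)
print(text)
```

Because every `bottleneck.*` access goes through `| default(...)`, even the empty-dict fallback from `render_kernel_optimization_prompt` still renders a well-formed prompt.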
