{#
Copyright (c) Meta Platforms, Inc. and affiliates.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}

TASK: Optimize the following Triton kernel, using the hardware profiling analysis below, to improve its performance.

{% if gpu_specs %}
TARGET GPU:
{% if gpu_specs.name %}- GPU: {{ gpu_specs.name }}
{% endif %}
{% if gpu_specs.architecture %}- Architecture: {{ gpu_specs.architecture }}
{% endif %}
{% if gpu_specs.peak_memory_bw_gbps %}- Peak Memory Bandwidth: {{ gpu_specs.peak_memory_bw_gbps }} GB/s
{% endif %}
{% if gpu_specs.peak_fp32_tflops %}- Peak FP32: {{ gpu_specs.peak_fp32_tflops }} TFLOPS
{% endif %}
{% if gpu_specs.peak_fp16_tflops %}- Peak FP16: {{ gpu_specs.peak_fp16_tflops }} TFLOPS
{% endif %}
{% if gpu_specs.peak_bf16_tflops %}- Peak BF16: {{ gpu_specs.peak_bf16_tflops }} TFLOPS
{% endif %}
{% if gpu_specs.sm_count %}- SM Count: {{ gpu_specs.sm_count }}
{% endif %}
{% if gpu_specs.max_threads_per_sm %}- Max Threads per SM: {{ gpu_specs.max_threads_per_sm }}
{% endif %}
{% if gpu_specs.l1_cache_kb %}- L1 Cache per SM: {{ gpu_specs.l1_cache_kb }} KB
{% endif %}
{% if gpu_specs.l2_cache_mb %}- L2 Cache (Total): {{ gpu_specs.l2_cache_mb }} MB
{% endif %}
{% if gpu_specs.memory_gb %}- Memory: {{ gpu_specs.memory_gb }} GB {{ gpu_specs.memory_type | default('') }}
{% endif %}

{% endif %}
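{#
Illustrative only (not rendered): given a hypothetical gpu_specs dict such as
{"name": "NVIDIA A100-SXM4-80GB", "sm_count": 108, "l2_cache_mb": 40}, the
block above renders only the keys that are present, roughly as:

TARGET GPU:
- GPU: NVIDIA A100-SXM4-80GB
- SM Count: 108
- L2 Cache (Total): 40 MB
#}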
PROBLEM DESCRIPTION:
{{ problem_description }}
{% if pytorch_baseline_ms %}
PyTorch Eager baseline: {{ "%.4f"|format(pytorch_baseline_ms) }} ms
{% endif %}

CURRENT KERNEL IMPLEMENTATION:
```python
{{ kernel_code }}
```

OPTIMIZATION STRATEGY ({{ bottleneck_label }}):
Hardware profiling with NVIDIA Nsight Compute (NCU) identified the following bottleneck:
- Category: {{ bottleneck.category | default('unknown') }}
- Root Cause: {{ bottleneck.root_cause | default('N/A') }}
- Suggested Optimization: {{ bottleneck.suggestion | default('N/A') }}
- Expected Improvement: {{ bottleneck.expected_improvement | default('N/A') }}

{% if error_feedback %}
PREVIOUS ATTEMPT FAILED:
{{ error_feedback }}

{% endif %}
PERFORMANCE TARGET:
{% if target_ms %}
- Achieve at least a 1.25x speedup over PyTorch Eager (target: <= {{ "%.4f"|format(target_ms) }} ms)
{% else %}
- Achieve a 20-100% performance improvement over the baseline
{% endif %}
- Maintain numerical correctness (torch.allclose with atol=1e-4, rtol=1e-4)
- Preserve the public API (same inputs/outputs, shapes, dtypes)

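{#
Illustrative only (not rendered): a minimal sketch of the correctness check the
target above implies, assuming hypothetical `example_inputs` and a `reference_fn`
eager implementation; neither name is part of the real harness.

import torch

out = kernel_function(*example_inputs)  # optimized Triton path
ref = reference_fn(*example_inputs)     # PyTorch Eager reference
assert torch.allclose(out, ref, atol=1e-4, rtol=1e-4)
#}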
CRITICAL REQUIREMENTS:
1. Apply the optimization strategy described above to address the identified bottleneck
2. The implementation must be a complete, valid Python file
3. The entry point must be a function named 'kernel_function' that wraps the actual Triton kernel
4. Focus on the specific optimization while maintaining correctness
5. Keep the wrapper free of PyTorch compute primitives

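{#
Illustrative only (not rendered): a minimal sketch of the file shape the
requirements above describe, using a hypothetical elementwise-add kernel. The
kernel name `_add_kernel` and its argument list are placeholders, not part of
the real task.

import torch
import triton
import triton.language as tl

@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide contiguous slice.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def kernel_function(x, y):
    # Wrapper allocates the output and launches the kernel; no PyTorch
    # compute primitives beyond allocation.
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out
#}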
OUTPUT FORMAT:
1. Output complete optimized kernel code in ```python blocks
2. Include only: imports, Triton kernel (@triton.jit), wrapper function (kernel_function)
3. No testing code, benchmarks, or explanatory comments

Generate the complete optimized kernel implementation:
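{#
Illustrative only (not rendered): a minimal sketch of how this template might be
rendered, assuming it is saved as "optimize_kernel.j2" under a templates/
directory (both names are hypothetical) and that the caller supplies the
variables used above.

from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("templates"))
template = env.get_template("optimize_kernel.j2")
prompt = template.render(
    gpu_specs={"name": "NVIDIA A100-SXM4-80GB", "sm_count": 108},
    problem_description="Elementwise add over large float32 tensors",
    pytorch_baseline_ms=1.2345,
    kernel_code="...",  # current Triton source, elided here
    bottleneck_label="memory-bound",
    bottleneck={
        "category": "memory",
        "root_cause": "uncoalesced global loads",
        "suggestion": "use contiguous block-wide accesses",
        "expected_improvement": "~1.3x",
    },
    error_feedback=None,
    target_ms=0.9876,
)
#}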