
Commit 72ac4d1

Author: Kaiming Cheng
Commit message: fix
1 parent: e7ba29a

File tree: 3 files changed (+116, -85 lines)

kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py

Lines changed: 12 additions & 66 deletions
@@ -22,71 +22,16 @@
 """

 import subprocess
-from typing import Any, Dict, Optional
-
-# GPU specifications database
-# Sources: NVIDIA official specifications, manufacturer datasheets
-GPU_SPECS_DATABASE = {
-    "NVIDIA A100": {
-        "name": "NVIDIA A100",
-        "architecture": "Ampere",
-        "peak_fp32_tflops": 19.5,
-        "peak_fp16_tflops": 312.0,
-        "peak_bf16_tflops": 312.0,
-        "peak_memory_bw_gbps": 1555,
-        "sm_count": 108,
-        "max_threads_per_sm": 2048,
-        "l1_cache_kb": 192,
-        "l2_cache_mb": 40,
-        "memory_gb": 40,
-        "memory_type": "HBM2e",
-    },
-    "NVIDIA H100": {
-        "name": "NVIDIA H100",
-        "architecture": "Hopper",
-        "peak_fp32_tflops": 51.0,
-        "peak_fp16_tflops": 989.0,
-        "peak_bf16_tflops": 989.0,
-        "peak_memory_bw_gbps": 3352,
-        "sm_count": 132,
-        "max_threads_per_sm": 2048,
-        "l1_cache_kb": 256,
-        "l2_cache_mb": 50,
-        "memory_gb": 80,
-        "memory_type": "HBM3",
-    },
-    "NVIDIA RTX 4090": {
-        "name": "NVIDIA RTX 4090",
-        "architecture": "Ada Lovelace",
-        "peak_fp32_tflops": 82.6,
-        "peak_fp16_tflops": 165.0,
-        "peak_bf16_tflops": 165.0,
-        "peak_memory_bw_gbps": 1008,
-        "sm_count": 128,
-        "max_threads_per_sm": 1536,
-        "l1_cache_kb": 128,
-        "l2_cache_mb": 72,
-        "memory_gb": 24,
-        "memory_type": "GDDR6X",
-    },
-    "NVIDIA RTX 5080": {
-        "name": "NVIDIA RTX 5080",
-        "architecture": "Blackwell",
-        "peak_fp32_tflops": 57.0,
-        "peak_fp16_tflops": 114.0,
-        "peak_bf16_tflops": 114.0,
-        "peak_memory_bw_gbps": 960,
-        "sm_count": 84,
-        "max_threads_per_sm": 1536,
-        "l1_cache_kb": 128,
-        "l2_cache_mb": 64,
-        "memory_gb": 16,
-        "memory_type": "GDDR7",
-    },
-}
-
-
-def query_gpu_name() -> Optional[str]:
+from typing import Any
+
+from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs_database import (
+    GPU_SPECS_DATABASE,
+)
+
+__all__ = ["GPU_SPECS_DATABASE", "query_gpu_name", "get_gpu_specs"]
+
+
+def query_gpu_name() -> str | None:
     """
     Query GPU name using nvidia-smi.

@@ -109,7 +54,7 @@ def query_gpu_name() -> Optional[str]:
     return None


-def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]:
+def get_gpu_specs(gpu_name: str | None = None) -> dict[str, Any]:
     """
     Get GPU specifications for bottleneck analysis.

@@ -179,6 +124,7 @@ def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]:
         print(f"\nDetected GPU: {detected_name}")
     else:
         print("\nNo GPU detected (nvidia-smi not available)")
+        exit()

     # Get specs
     specs = get_gpu_specs()
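
Because the table is re-exported and listed in `__all__`, existing `from ...gpu_specs import GPU_SPECS_DATABASE` callers keep working. As a minimal sketch of how these per-GPU peaks feed bottleneck classification, the roofline ridge-point arithmetic below is illustrative, not code from this commit:

```python
# Illustrative only: the ridge point is the arithmetic intensity (FLOP/byte)
# above which a kernel on this GPU is compute-bound rather than memory-bound.
from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs import (
    GPU_SPECS_DATABASE,  # still importable here via the re-export
)

h100 = GPU_SPECS_DATABASE["NVIDIA H100"]
peak_flops = h100["peak_fp16_tflops"] * 1e12    # FLOP/s
peak_bytes = h100["peak_memory_bw_gbps"] * 1e9  # bytes/s (table lists GB/s)
print(f"fp16 ridge point: {peak_flops / peak_bytes:.0f} FLOP/byte")  # ~295
```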
kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+GPU Specifications Database
+
+This module contains the GPU hardware specifications database used for
+performance analysis and bottleneck identification. Separated into its
+own file to allow easier module overriding.
+
+Sources: NVIDIA official specifications, manufacturer datasheets
+"""
+
+GPU_SPECS_DATABASE: dict[str, dict[str, object]] = {
+    "NVIDIA A100": {
+        "name": "NVIDIA A100",
+        "architecture": "Ampere",
+        "peak_fp32_tflops": 19.5,
+        "peak_fp16_tflops": 312.0,
+        "peak_bf16_tflops": 312.0,
+        "peak_memory_bw_gbps": 1555,
+        "sm_count": 108,
+        "max_threads_per_sm": 2048,
+        "l1_cache_kb": 192,
+        "l2_cache_mb": 40,
+        "memory_gb": 40,
+        "memory_type": "HBM2e",
+    },
+    "NVIDIA H100": {
+        "name": "NVIDIA H100",
+        "architecture": "Hopper",
+        "peak_fp32_tflops": 51.0,
+        "peak_fp16_tflops": 989.0,
+        "peak_bf16_tflops": 989.0,
+        "peak_memory_bw_gbps": 3352,
+        "sm_count": 132,
+        "max_threads_per_sm": 2048,
+        "l1_cache_kb": 256,
+        "l2_cache_mb": 50,
+        "memory_gb": 80,
+        "memory_type": "HBM3",
+    },
+    "NVIDIA RTX 4090": {
+        "name": "NVIDIA RTX 4090",
+        "architecture": "Ada Lovelace",
+        "peak_fp32_tflops": 82.6,
+        "peak_fp16_tflops": 165.0,
+        "peak_bf16_tflops": 165.0,
+        "peak_memory_bw_gbps": 1008,
+        "sm_count": 128,
+        "max_threads_per_sm": 1536,
+        "l1_cache_kb": 128,
+        "l2_cache_mb": 72,
+        "memory_gb": 24,
+        "memory_type": "GDDR6X",
+    },
+    "NVIDIA RTX 5080": {
+        "name": "NVIDIA RTX 5080",
+        "architecture": "Blackwell",
+        "peak_fp32_tflops": 57.0,
+        "peak_fp16_tflops": 114.0,
+        "peak_bf16_tflops": 114.0,
+        "peak_memory_bw_gbps": 960,
+        "sm_count": 84,
+        "max_threads_per_sm": 1536,
+        "l1_cache_kb": 128,
+        "l2_cache_mb": 64,
+        "memory_gb": 16,
+        "memory_type": "GDDR7",
+    },
+}
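
The docstring gives the reason for the split: easier module overriding. A sketch of what that could look like downstream (the A100 80GB entry and its 2039 GB/s figure are our assumption for illustration, not part of this commit):

```python
import copy

from kernel_perf_agent.kernel_opt.diagnose_prompt import gpu_specs_database

# Hypothetical override: clone the shipped A100 40GB entry and patch only the
# fields that differ for the 80GB part (values assumed, not from this diff).
a100_80 = copy.deepcopy(gpu_specs_database.GPU_SPECS_DATABASE["NVIDIA A100"])
a100_80["name"] = "NVIDIA A100 80GB"
a100_80["memory_gb"] = 80
a100_80["peak_memory_bw_gbps"] = 2039
gpu_specs_database.GPU_SPECS_DATABASE["NVIDIA A100 80GB"] = a100_80
```

Because gpu_specs.py imports the dict object itself, an in-place mutation like this is visible through both modules.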

kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py

Lines changed: 22 additions & 19 deletions
@@ -32,7 +32,7 @@
 Metric definitions are in metric_schema.py.
 """

-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable

 from .metric_schema import GPU_MEMORY_FIELDS, GPU_SPEC_FIELDS, NCU_METRIC_SECTIONS

@@ -42,23 +42,26 @@
 # =============================================================================


-def render_problem_description(problem_description: str) -> List[str]:
+def render_problem_description(problem_description: str) -> list[str]:
     """Render the problem description section."""
     return ["## Problem Description", "", problem_description]


-def render_kernel_code(kernel_code: str, language: str = "python") -> List[str]:
+def render_kernel_code(kernel_code: str, language: str = "python") -> list[str]:
     """Render the kernel code section with syntax highlighting."""
     return ["", "## Current Kernel Code", "", f"```{language}", kernel_code, "```"]


-def render_gpu_specs(gpu_specs: Dict[str, Any]) -> List[str]:
+def render_gpu_specs(gpu_specs: dict[str, Any]) -> list[str]:
     """Render the GPU hardware specifications section."""
     lines = ["", "## GPU Hardware Specifications", ""]

     for label, key, unit in GPU_SPEC_FIELDS:
         value = gpu_specs.get(key, "N/A")
-        lines.append(f"- **{label}:** {value}{unit}")
+        if value == "N/A":
+            lines.append(f"- **{label}:** N/A")
+        else:
+            lines.append(f"- **{label}:** {value}{unit}")

     for label, size_key, type_key, size_unit in GPU_MEMORY_FIELDS:
         size_value = gpu_specs.get(size_key, "N/A")

@@ -69,9 +72,9 @@ def render_gpu_specs(gpu_specs: Dict[str, Any]) -> List[str]:


 def render_ncu_metrics(
-    ncu_metrics: Dict[str, Any],
+    ncu_metrics: dict[str, Any],
     get_metric_fn: Callable[[str, str], str],
-) -> List[str]:
+) -> list[str]:
     """Render the NCU profiling metrics section."""
     lines = ["", "## NCU Profiling Metrics"]

@@ -85,7 +88,7 @@
     return lines


-def render_task_instructions() -> List[str]:
+def render_task_instructions() -> list[str]:
     """Render the task instructions section for dual-bottleneck analysis."""
     return [
         "",

@@ -102,7 +105,7 @@ def render_task_instructions() -> List[str]:
     ]


-def create_metric_getter(kernel_metrics: Dict[str, Any]) -> Callable[[str, str], str]:
+def create_metric_getter(kernel_metrics: dict[str, Any]) -> Callable[[str, str], str]:
     """Create a metric getter function for a specific kernel's metrics."""

     def get_metric(key: str, default: str = "N/A") -> str:

@@ -172,9 +175,9 @@ def get_metric(key: str, default: str = "N/A") -> str:
 def build_judge_optimization_prompt(
     kernel_code: str,
     problem_description: str,
-    ncu_metrics: Dict[str, Any],
-    gpu_specs: Dict[str, Any],
-) -> Tuple[str, str]:
+    ncu_metrics: dict[str, Any],
+    gpu_specs: dict[str, Any],
+) -> tuple[str, str]:
     """
     Build system and user prompts for Judge to analyze bottleneck.

@@ -209,7 +212,7 @@
         raise ValueError("NCU metrics are empty - cannot build judge prompt")

     # Extract first kernel's metrics for the metric getter
-    first_kernel = list(ncu_metrics.values())[0] if ncu_metrics else {}
+    first_kernel = list(ncu_metrics.values())[0]
     get_metric_fn = create_metric_getter(first_kernel)

     # Build user prompt using modular section renderers

@@ -226,7 +229,7 @@
     return JUDGE_SYSTEM_PROMPT, user_prompt


-def extract_judge_response(response_text: str) -> Optional[Dict[str, Any]]:
+def extract_judge_response(response_text: str) -> dict[str, Any] | None:
     """
     Extract and parse JSON from Judge LLM response.

@@ -302,7 +305,7 @@ def extract_judge_response(response_text: str) -> Optional[Dict[str, Any]]:
     return None


-def validate_judge_response(analysis: Dict[str, Any]) -> bool:
+def validate_judge_response(analysis: dict[str, Any]) -> bool:
     """Validate that Judge response contains required dual-bottleneck fields."""
     if "bottleneck_1" not in analysis or "bottleneck_2" not in analysis:
         return False

@@ -311,12 +314,12 @@
 ) and _validate_bottleneck_entry(analysis["bottleneck_2"])


-VALID_CATEGORIES = frozenset(
-    ["memory-bound", "compute-bound", "occupancy-limited", "latency-bound"]
-)
+VALID_CATEGORIES = {
+    "memory-bound", "compute-bound", "occupancy-limited", "latency-bound"
+}


-def _validate_bottleneck_entry(bottleneck: Dict[str, Any]) -> bool:
+def _validate_bottleneck_entry(bottleneck: dict[str, Any]) -> bool:
     """Validate a single bottleneck entry."""
     required = [
         "category",
