diff --git a/oink/README.md b/oink/README.md new file mode 100644 index 0000000..aeb0c09 --- /dev/null +++ b/oink/README.md @@ -0,0 +1,109 @@ +# KernelAgent-Oink + +KernelAgent-Oink is a small **CuTeDSL (CUTLASS DSL) kernel library** for +**NVIDIA Blackwell (SM100 / GB200 / B200-class)**, bundled as a lightweight +Python package that can be used standalone or as a **vLLM general plugin**. + +At the moment, the vLLM integration exposes the following `torch.library.custom_op` +entrypoints under the `oink::` namespace: + +- `torch.ops.oink.rmsnorm(x, weight, eps) -> Tensor` +- `torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps) -> None` (in-place) + +The package also includes additional SM100 kernels used by the benchmark suite: +LayerNorm, Softmax (fwd+bwd), and CrossEntropy (fwd+bwd). + +## Requirements + +- GPU: **SM100** for the fast CuTeDSL paths. On other GPUs, Oink falls back to + reference PyTorch implementations for correctness. +- Python dependencies: + - `nvidia-cutlass-dsl` (CuTeDSL) + - `cuda-python` + - `torch` (provided by your environment / vLLM) + +Recommended env vars: + +```bash +export CUTE_DSL_ARCH=sm_100a +export PYTORCH_ALLOC_CONF=expandable_segments:True +``` + +## Install (editable) + +From the `KernelAgent` repo root: + +```bash +pip install -e ./oink +``` + +For running the in-repo benchmark suite / plots: + +```bash +pip install -e "./oink[bench]" +``` + +## Usage + +### vLLM (general plugin) + +1) Enable the plugin: + +```bash +export VLLM_USE_OINK_RMSNORM=1 +``` + +2) Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` / CUDA graphs: + +```python +from vllm import LLM + +llm = LLM( + model=..., + tensor_parallel_size=..., + enforce_eager=False, + compilation_config={"custom_ops": ["none", "+rms_norm"]}, +) +``` + +Without `+rms_norm`, Inductor may fuse RMSNorm into larger kernels and neither +vLLM’s CUDA RMSNorm nor Oink will run. + +### Direct PyTorch usage (manual op registration) + +For standalone use (outside vLLM), register the custom ops once: + +```python +import kernelagent_oink +import torch + +kernelagent_oink.register(force=True) + +x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16) +w = torch.randn(4096, device="cuda", dtype=torch.bfloat16) +y = torch.ops.oink.rmsnorm(x, w, 1e-6) +``` + +## Benchmarks + +The repo includes a Quack-style benchmark suite (tables + SVG plots) to compare +Oink against Quack on SM100 and to reproduce the reported speedups. + +- How to run + methodology: `oink/benchmarks/README.md` +- Pre-generated plots: `oink/benchmarks/media/` + +
+ *SM100 BF16: Oink vs Quack (Quack-suite)*
+ *SM100 BF16: Oink vs Quack (DSv3-like shapes)*
+ +## Links + +| What | Link | +|---|---| +| Quack (expert baseline) | https://github.com/Dao-AILab/quack | +| KernelAgent (agentic framework) | https://github.com/meta-pytorch/KernelAgent | +| vLLM PR (Oink RMSNorm integration) | https://github.com/vllm-project/vllm/pull/31828 | diff --git a/oink/benchmarks/README.md b/oink/benchmarks/README.md new file mode 100644 index 0000000..a5c4676 --- /dev/null +++ b/oink/benchmarks/README.md @@ -0,0 +1,152 @@ +# SM100 Benchmarks (KernelAgent-Oink vs Quack) + +This folder contains SM100 (GB200 / Blackwell) microbenchmarks for the Oink +CuTeDSL kernels vendored into KernelAgent, comparing against Quack’s SM100 +kernels where Quack provides an equivalent API. + +## Prereqs + +- GPU: **SM100** (`torch.cuda.get_device_capability() == (10, 0)`). +- Python deps in your environment: + - `torch` + - `nvidia-cutlass-dsl` (CuTeDSL) + - `cuda-python` + - `triton` (only for `triton.testing.do_bench`) + - `quack` (optional; only needed for Oink-vs-Quack comparisons) + +Recommended env vars: + +```bash +export PYTORCH_ALLOC_CONF=expandable_segments:True +export CUTE_DSL_ARCH=sm_100a +``` + +## Shape suites + +- **Quack-suite**: `(batch, seq) ∈ {1,4,8,16,32} × {8192,16384,32768,65536,131072}`, + with `hidden = 4096` so `M = batch * seq`, `N = 4096`. +- **DeepSeek-V3-like (DSv3)** + - RMSNorm / LayerNorm / Softmax: `M ∈ {4096, 16384, 65536}`, `N ∈ {6144, 7168, 8192}` + - Cross-entropy: `M ∈ {4096, 16384, 65536}`, `N ∈ {3072, 6144, 8192, 12288}` + +## Correctness gates + +By default, each script runs a per-shape `torch.testing.assert_close` check +vs a **pure-PyTorch reference** **before** emitting timing numbers. When Quack +is available for that op/path, the script also validates Quack vs the *same* +reference (so speedups can’t come from looser numerics). + +Disable with `--skip-verify` only for quick smoke tests. 
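The gate follows the same structure in every script; here is a minimal sketch of the pattern (the callables `oink_fn`, `quack_fn`, and `ref_fn` are illustrative placeholders, not the scripts' actual entry points):

```python
import torch


def verify_against_reference(oink_fn, quack_fn, ref_fn, inputs, *, atol, rtol):
    """Per-shape correctness gate, run before any timing is recorded."""
    ref = ref_fn(*inputs)  # pure-PyTorch reference (float32 accumulation)
    torch.testing.assert_close(oink_fn(*inputs), ref, atol=atol, rtol=rtol)
    if quack_fn is not None:
        # Quack is held to the *same* reference and tolerances, so a speedup
        # cannot come from looser numerics.
        torch.testing.assert_close(quack_fn(*inputs), ref, atol=atol, rtol=rtol)
```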
+ +## Running benchmarks + +All scripts support: + +- `--quack-suite` or `--dsv3` (or `--configs MxN,...`) +- `--dtype {bf16,fp16,fp32}` +- `--iters ` and `--warmup-ms ` for kernel-only timing +- `--json ` and/or `--csv ` outputs (meta + rows) + +### One-command suite + +Run the full Quack-suite + DSv3 set (Oink vs Quack) and write all JSON artifacts +to a timestamped directory: + +```bash +python oink/benchmarks/readme/run_sm100_suite.py --dtype bf16 +``` + +Turn the JSON artifacts into Markdown tables (with geomean speedups): + +```bash +python oink/benchmarks/readme/summarize_results.py --in-dir /tmp/kernelagent_oink_sm100_suite_ \ + --out /tmp/kernelagent_oink_sm100_suite_summary.md +``` + +### Measured HBM roofline (STREAM-like) + +To contextualize the `*_tbps` numbers as a fraction of a *measured* bandwidth +ceiling (rather than a theoretical spec), run: + +```bash +CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op both --gb 2 \ + --json /tmp/hbm_roofline_sm100_bf16.json +``` + +### RMSNorm forward + +```bash +python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 200 --warmup-ms 25 \ + --json /tmp/oink_rmsnorm_fwd_quack_suite.json + +python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 200 --warmup-ms 25 \ + --json /tmp/oink_rmsnorm_fwd_dsv3.json + +# vLLM-style inference weights (weight dtype == activation dtype) +python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --quack-suite --iters 200 --warmup-ms 25 \ + --json /tmp/oink_rmsnorm_fwd_quack_suite_wsame.json +``` + +### Fused Add + RMSNorm (vLLM-style, in-place) + +This is a good "roofline case study" kernel (heavy read/write traffic, very little extra math): + +```bash +CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \ + --json /tmp/fused_add_rmsnorm_sm100_bf16.json +``` + +Note on the Quack baseline: Oink exposes an **in-place** fused op (updates `x` and `residual`). +Quack’s fused kernel produces `out` and `residual_out` out-of-place, so by default the benchmark +times `quack::_rmsnorm_fwd` **plus** two explicit copies (`x.copy_(out)`, `residual.copy_(residual_out)`) +to match the in-place semantics (integration-realistic). Use `--quack-baseline kernel` to time only +the Quack fused kernel with preallocated outputs. 
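For reference, the in-place semantics that both baselines are held to can be written as a short pure-PyTorch sketch (this mirrors the verification reference used by the script, not the CuTeDSL kernel itself):

```python
import torch


def fused_add_rms_norm_ref(x: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, eps: float = 1e-6) -> None:
    """In-place reference: residual <- x + residual, x <- RMSNorm(x + residual) * weight."""
    z = x + residual
    zf = z.float()  # accumulate in float32
    rstd = torch.rsqrt(zf.square().mean(dim=-1, keepdim=True) + eps)
    residual.copy_(z)
    x.copy_((zf * rstd * weight.float()).to(x.dtype))
```

The default `kernel_inplace` baseline charges Quack for the two `copy_` calls because that is the extra traffic an out-of-place kernel pays to satisfy these semantics in a vLLM-style integration.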
+ +### RMSNorm backward + +```bash +python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 100 --warmup-ms 25 \ + --csv /tmp/oink_rmsnorm_bwd_quack_suite.csv + +python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 100 --warmup-ms 25 \ + --csv /tmp/oink_rmsnorm_bwd_dsv3.csv +``` + +### Softmax (forward + backward) + +```bash +python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \ + --json /tmp/oink_softmax_fwd_bwd_quack_suite.json + +python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \ + --json /tmp/oink_softmax_fwd_bwd_dsv3.json +``` + +### Cross-entropy (forward + backward) + +```bash +python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \ + --json /tmp/oink_cross_entropy_fwd_bwd_quack_suite.json + +python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \ + --json /tmp/oink_cross_entropy_fwd_bwd_dsv3.json +``` + +### LayerNorm forward + +```bash +python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --quack-suite --iters 200 --warmup-ms 25 \ + --json /tmp/oink_layernorm_fwd_quack_suite.json + +python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --dsv3 --iters 200 --warmup-ms 25 \ + --json /tmp/oink_layernorm_fwd_dsv3.json +``` + +## Notes + +- These scripts intentionally avoid importing any external Oink checkout so the + results reflect the in-tree KernelAgent Oink kernels. +- For RMSNorm, the `rmsnorm_with_stage2` implementation is a **fallback** that + is only used when the pointer-based fast path cannot be used (e.g. when + `weight.dtype != x.dtype`, or when layouts/alignments are incompatible). You + can force it for A/B testing via `KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2=1`. diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py new file mode 100644 index 0000000..ef996ec --- /dev/null +++ b/oink/benchmarks/benchmark/bench_utils.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
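+# Shared helpers for the SM100 benchmark scripts: device metadata collection,
+# a triton.testing.do_bench wrapper, dtype/shape parsing, the Quack-suite config
+# grid, streamed error statistics, and CSV/JSON result writers.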
+ +from __future__ import annotations + +import csv +import json +import math +import os +import subprocess +import sys +from dataclasses import asdict, dataclass +from datetime import datetime +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple + +import torch +from triton.testing import do_bench as triton_do_bench + + +@dataclass(frozen=True) +class DeviceMeta: + device: str + capability: Tuple[int, int] + torch: str + cuda: str + cute_dsl_arch: str + git_sha: str + timestamp: str + + +def _try_git_sha() -> str: + here = os.path.dirname(os.path.abspath(__file__)) + repo_root = os.path.abspath(os.path.join(here, "..", "..")) + try: + out = subprocess.check_output( + ["git", "rev-parse", "HEAD"], + cwd=repo_root, + stderr=subprocess.DEVNULL, + text=True, + ) + return out.strip() + except Exception: + return "" + + +def collect_device_meta(device: Optional[torch.device] = None) -> DeviceMeta: + if device is None: + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + timestamp = datetime.now().isoformat(timespec="seconds") + return DeviceMeta( + device=str(props.name), + capability=(int(props.major), int(props.minor)), + torch=str(torch.__version__), + cuda=str(getattr(torch.version, "cuda", "unknown")), + cute_dsl_arch=os.environ.get("CUTE_DSL_ARCH", ""), + git_sha=_try_git_sha(), + timestamp=timestamp, + ) + + +def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float: + """Approximate HBM peak bandwidth in GB/s for roofline fractions.""" + if device is None: + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + if sm >= 100: + return 8000.0 + return 2000.0 + + +def do_bench_triton( + fn: Callable[[], Any], *, warmup_ms: int = 25, rep_ms: int = 100 +) -> float: + """Kernel-only timing consistent with the Oink benchmark harnesses.""" + return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean")) + + +def parse_dtype(s: str) -> torch.dtype: + s = s.lower() + if s == "bf16": + return torch.bfloat16 + if s == "fp16": + return torch.float16 + if s == "fp32": + return torch.float32 + raise ValueError(f"Unsupported dtype: {s}") + + +def parse_configs(s: str) -> List[Tuple[int, int]]: + out: List[Tuple[int, int]] = [] + for part in s.split(","): + m, n = part.lower().split("x") + out.append((int(m), int(n))) + return out + + +def quack_suite_configs() -> List[Tuple[int, int, int]]: + """Return (batch, seq, hidden) triples following Quack's common grid (hidden=4096).""" + batch_sizes = [1, 4, 8, 16, 32] + seq_lengths = [8192, 16384, 32768, 65536, 131072] + hidden = 4096 + cfgs: List[Tuple[int, int, int]] = [] + for bs in batch_sizes: + for sl in seq_lengths: + M = bs * sl + if M * hidden > (2**31): + continue + cfgs.append((bs, sl, hidden)) + return cfgs + + +def ensure_oink_src_on_path() -> None: + """Make the in-repo KernelAgent Oink package importable without an editable install.""" + here = os.path.dirname(os.path.abspath(__file__)) + oink_src = os.path.abspath(os.path.join(here, "..", "..", "src")) + if oink_src not in sys.path: + sys.path.insert(0, oink_src) + + +def write_csv(path: str, rows: Sequence[Dict[str, Any]]) -> None: + if not rows: + return + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + file_exists = os.path.exists(path) + with open(path, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=sorted(rows[0].keys())) + if not file_exists: + writer.writeheader() + for row in rows: + 
writer.writerow(row) + + +def write_json( + path: str, + meta: DeviceMeta, + rows: Sequence[Dict[str, Any]], + *, + extra: Dict[str, Any] | None = None, +) -> None: + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + payload: Dict[str, Any] = { + "meta": {**asdict(meta), **(extra or {})}, + "rows": list(rows), + } + with open(path, "w") as f: + json.dump(payload, f, indent=2) + + +def iter_row_blocks(M: int, block_rows: int) -> Iterable[Tuple[int, int]]: + """Yield (start, end) row index ranges for a 2D (M, N) matrix. + + The intent is to make correctness references for large tensors tractable + without materializing full float32 intermediates. + """ + if M < 0: + raise ValueError(f"M must be non-negative, got {M}") + if block_rows <= 0: + raise ValueError(f"block_rows must be > 0, got {block_rows}") + for start in range(0, M, block_rows): + yield start, min(M, start + block_rows) + + +@dataclass +class ErrorStats: + """Numerical error stats between an output and a reference. + + Notes: + - `max_abs` and `rel_l2` are computed exactly (streamed). + - `p99_abs` is computed over a deterministic strided sample of abs error + values (to keep very large tensors tractable). + """ + + max_abs: float + p99_abs: float + rel_l2: float + p99_sample_elems: int + p99_sample_stride: int + + +class ErrorStatsAccumulator: + """Stream error stats over (output_block, ref_block) pairs. + + This is intended for large 2D tensors where we compute reference results + block-by-block to avoid materializing full float32 intermediates. + """ + + def __init__(self, *, total_elems: int, p99_target_samples: int = 1_000_000): + if total_elems <= 0: + raise ValueError(f"total_elems must be > 0, got {total_elems}") + if p99_target_samples <= 0: + raise ValueError( + f"p99_target_samples must be > 0, got {p99_target_samples}" + ) + self.total_elems = int(total_elems) + self.p99_target_samples = int(p99_target_samples) + # Deterministic strided sampling across the flattened tensor order. + self.sample_stride = max(1, self.total_elems // self.p99_target_samples) + self._global_offset = 0 + + self._max_abs = 0.0 + self._err_sq_sum = 0.0 + self._ref_sq_sum = 0.0 + self._abs_err_samples: List[torch.Tensor] = [] + + def update(self, out: torch.Tensor, ref: torch.Tensor) -> None: + if out.shape != ref.shape: + raise ValueError( + f"shape mismatch: out={tuple(out.shape)} ref={tuple(ref.shape)}" + ) + + # Compute error in float32 for stable reductions. + err_f32 = (out - ref).to(torch.float32) + abs_err = err_f32.abs() + + # Exact reductions. + self._max_abs = max(self._max_abs, float(abs_err.max().item())) + self._err_sq_sum += float((err_f32 * err_f32).sum(dtype=torch.float64).item()) + ref_f32 = ref.to(torch.float32) + self._ref_sq_sum += float((ref_f32 * ref_f32).sum(dtype=torch.float64).item()) + + # Deterministic strided sample for p99_abs. + flat = abs_err.flatten() + block_elems = int(flat.numel()) + if block_elems <= 0: + return + + stride = int(self.sample_stride) + first = (-int(self._global_offset)) % stride + if first < block_elems: + idx = torch.arange( + first, block_elems, step=stride, device=flat.device, dtype=torch.int64 + ) + # Gather a modest number of values (≈ block_elems/stride). 
+ vals = ( + flat.index_select(0, idx).detach().to(device="cpu", dtype=torch.float32) + ) + self._abs_err_samples.append(vals) + + self._global_offset += block_elems + + def finalize(self) -> ErrorStats: + if self._abs_err_samples: + samples = torch.cat(self._abs_err_samples, dim=0) + if samples.numel() > self.p99_target_samples: + samples = samples[: self.p99_target_samples] + p99 = ( + float(torch.quantile(samples, 0.99).item()) + if samples.numel() > 0 + else 0.0 + ) + sample_elems = int(samples.numel()) + else: + p99 = 0.0 + sample_elems = 0 + + denom = math.sqrt(self._ref_sq_sum) if self._ref_sq_sum > 0 else 0.0 + rel_l2 = (math.sqrt(self._err_sq_sum) / denom) if denom > 0 else 0.0 + + return ErrorStats( + max_abs=float(self._max_abs), + p99_abs=float(p99), + rel_l2=float(rel_l2), + p99_sample_elems=int(sample_elems), + p99_sample_stride=int(self.sample_stride), + ) + + +def error_stats_to_row(prefix: str, stats: ErrorStats) -> Dict[str, Any]: + """Flatten ErrorStats into JSON-friendly row fields.""" + return { + f"{prefix}_max_abs": float(stats.max_abs), + f"{prefix}_p99_abs": float(stats.p99_abs), + f"{prefix}_rel_l2": float(stats.rel_l2), + f"{prefix}_p99_sample_elems": int(stats.p99_sample_elems), + f"{prefix}_p99_sample_stride": int(stats.p99_sample_stride), + } diff --git a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py new file mode 100644 index 0000000..3c8bf44 --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py @@ -0,0 +1,498 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_configs, + parse_dtype, + quack_suite_configs, + write_csv, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import cross_entropy as oink_ce # noqa: E402 + +try: + from quack.cross_entropy import cross_entropy_bwd as quack_ce_bwd # type: ignore + from quack.cross_entropy import cross_entropy_fwd as quack_ce_fwd # type: ignore +except Exception: + quack_ce_fwd = None + quack_ce_bwd = None + + +# Match Quack's unit-test defaults (tests/test_cross_entropy.py). +_VERIFY_TOL_LOSS = dict(atol=5e-5, rtol=1e-5) # float32 outputs (loss/lse) +_VERIFY_TOL_DX = { + torch.float32: dict(atol=5e-5, rtol=1e-5), + # FP16 `dx` is low-precision; allow ~1 ulp at typical magnitudes. 
+ torch.float16: dict(atol=1e-3, rtol=1e-3), + # BF16 `dx` is low-precision; allow ~1 ulp at typical magnitudes. + torch.bfloat16: dict(atol=1e-2, rtol=1e-2), +} + + +def bytes_io_model_ce( + M: int, + N: int, + dtype: torch.dtype, + *, + target_dtype: torch.dtype = torch.int64, + mode: str, +) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + t_elem = torch.tensor(0, dtype=target_dtype).element_size() + # Forward: + # read logits (M*N) + read target (M) + write loss (M fp32) + write lse (M fp32) + fwd = M * N * elem + M * t_elem + 2 * M * 4 + # Backward (reduction="none" path): + # read logits (M*N) + read target (M) + read dloss (M fp32) + read lse (M fp32) + write dx (M*N) + bwd = 2 * M * N * elem + M * t_elem + 2 * M * 4 + + if mode == "fwd": + return int(fwd) + if mode == "bwd": + return int(bwd) + if mode == "fwd_bwd": + # Logical IO for dx given (logits, target, dloss): read logits + read target + # + read dloss + write dx. (Intermediate lse/loss are implementation details.) + return int(2 * M * N * elem + M * t_elem + M * 4) + raise ValueError(f"Unsupported mode: {mode}") + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [3072, 6144, 8192, 12288] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity( + logits: torch.Tensor, target: torch.Tensor, *, ignore_index: int +) -> dict[str, object]: + dtype = logits.dtype + ref_block_rows = 512 + dloss = torch.randn( + logits.size(0), device=logits.device, dtype=torch.float32 + ) # upstream grad + + with torch.no_grad(): + loss_o, lse_o = oink_ce.cross_entropy_forward( + logits, target, ignore_index=ignore_index, reduction="none" + ) + dx_o = oink_ce.cross_entropy_backward( + dloss, logits, target, lse_o, ignore_index=ignore_index + ) + dx_fused_o = oink_ce.cross_entropy_fwd_bwd( + dloss, + logits, + target, + ignore_index=ignore_index, + ) + + loss_q = None + lse_q = None + dx_q = None + if quack_ce_fwd is not None and quack_ce_bwd is not None: + loss_q, lse_q = quack_ce_fwd( + logits, + target, + target_logit=None, + ignore_index=ignore_index, + return_lse=True, + return_dx=False, + inplace_backward=False, + ) + dx_q = quack_ce_bwd( + logits, + target, + dloss, + lse_q, + ignore_index=ignore_index, + inplace_backward=False, + ) + + M = int(logits.shape[0]) + N = int(logits.shape[1]) + loss_acc_ours = ErrorStatsAccumulator( + total_elems=M, p99_target_samples=min(M, 1_000_000) + ) + lse_acc_ours = ErrorStatsAccumulator( + total_elems=M, p99_target_samples=min(M, 1_000_000) + ) + dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + dx_fused_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + loss_acc_quack = ( + ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000)) + if (quack_ce_fwd is not None and quack_ce_bwd is not None) + else None + ) + lse_acc_quack = ( + ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000)) + if (quack_ce_fwd is not None and quack_ce_bwd is not None) + else None + ) + dx_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if (quack_ce_fwd is not None and quack_ce_bwd is not None) + else None + ) + + # Match Quack tests: compare to a PyTorch reference computed on float32 logits. + # Chunk over rows so we don't materialize a full (M, N) float32 tensor. 
+ for start, end in iter_row_blocks(M, ref_block_rows): + logits_f32 = logits[start:end].float().requires_grad_(True) + target_blk = target[start:end] + dloss_blk = dloss[start:end] + + loss_ref = torch.nn.functional.cross_entropy( + logits_f32, + target_blk, + reduction="none", + ignore_index=ignore_index, + ) + lse_ref = torch.logsumexp(logits_f32, dim=-1) + (dx_ref_f32,) = torch.autograd.grad( + loss_ref, logits_f32, grad_outputs=dloss_blk + ) + dx_ref = dx_ref_f32.to(dtype) + + torch.testing.assert_close( + loss_o[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS + ) + torch.testing.assert_close( + lse_o[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS + ) + torch.testing.assert_close(dx_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype]) + torch.testing.assert_close( + dx_fused_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype] + ) + loss_acc_ours.update(loss_o[start:end], loss_ref.detach()) + lse_acc_ours.update(lse_o[start:end], lse_ref.detach()) + dx_acc_ours.update(dx_o[start:end], dx_ref) + dx_fused_acc_ours.update(dx_fused_o[start:end], dx_ref) + + if loss_q is not None and lse_q is not None and dx_q is not None: + torch.testing.assert_close( + loss_q[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS + ) + torch.testing.assert_close( + lse_q[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS + ) + torch.testing.assert_close(dx_q[start:end], dx_ref, **_VERIFY_TOL_DX[dtype]) + assert ( + loss_acc_quack is not None + and lse_acc_quack is not None + and dx_acc_quack is not None + ) + loss_acc_quack.update(loss_q[start:end], loss_ref.detach()) + lse_acc_quack.update(lse_q[start:end], lse_ref.detach()) + dx_acc_quack.update(dx_q[start:end], dx_ref) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_loss", loss_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_lse", lse_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_dx_fused", dx_fused_acc_ours.finalize())) + if ( + loss_acc_quack is not None + and lse_acc_quack is not None + and dx_acc_quack is not None + ): + stats.update(error_stats_to_row("quack_err_loss", loss_acc_quack.finalize())) + stats.update(error_stats_to_row("quack_err_lse", lse_acc_quack.finalize())) + stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize())) + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + *, + warmup_ms: int, + iters_ms: int, + mode: str, + verify: bool, + ignore_index: int, +) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]: + device = torch.device("cuda") + logits = 0.1 * torch.randn(M, N, device=device, dtype=dtype) + target = torch.randint(0, N, (M,), device=device, dtype=torch.int64) + # Sprinkle some ignore_index entries for robustness (and to match reduction semantics). 
+ if ignore_index is not None: + mask = torch.rand(M, device=device) < 0.01 + target[mask] = int(ignore_index) + dloss = torch.randn(M, device=device, dtype=torch.float32) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(logits, target, ignore_index=int(ignore_index)) + + bytes_io = bytes_io_model_ce(M, N, dtype, target_dtype=target.dtype, mode=mode) + + if mode == "fwd": + + def fn_oink(): + return oink_ce.cross_entropy_forward( + logits, target, ignore_index=int(ignore_index), reduction="none" + ) + + fn_quack = None + if quack_ce_fwd is not None: + + def fn_quack(): + return quack_ce_fwd( + logits, + target, + target_logit=None, + ignore_index=int(ignore_index), + return_lse=True, + return_dx=False, + inplace_backward=False, + ) + + elif mode == "bwd": + with torch.no_grad(): + _loss_o, lse_o = oink_ce.cross_entropy_forward( + logits, target, ignore_index=int(ignore_index), reduction="none" + ) + if quack_ce_fwd is not None: + _loss_q, lse_q = quack_ce_fwd( + logits, + target, + target_logit=None, + ignore_index=int(ignore_index), + return_lse=True, + return_dx=False, + inplace_backward=False, + ) + else: + lse_q = None + + def fn_oink(): + return oink_ce.cross_entropy_backward( + dloss, logits, target, lse_o, ignore_index=int(ignore_index) + ) + + fn_quack = None + if quack_ce_bwd is not None and lse_q is not None: + + def fn_quack(): + return quack_ce_bwd( + logits, + target, + dloss, + lse_q, + ignore_index=int(ignore_index), + inplace_backward=False, + ) + + elif mode == "fwd_bwd": + + def fn_oink(): + return oink_ce.cross_entropy_fwd_bwd( + dloss, + logits, + target, + ignore_index=int(ignore_index), + ) + + fn_quack = None + if quack_ce_fwd is not None and quack_ce_bwd is not None: + + def fn_quack(): + _loss_q, lse_q = quack_ce_fwd( + logits, + target, + target_logit=None, + ignore_index=int(ignore_index), + return_lse=True, + return_dx=False, + inplace_backward=False, + ) + return quack_ce_bwd( + logits, + target, + dloss, + lse_q, + ignore_index=int(ignore_index), + inplace_backward=False, + ) + + else: + raise ValueError(f"Unsupported mode: {mode}") + + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + + if fn_quack is None: + return (ms_oink, gbps_oink), None, stats + + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 + return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) + p.add_argument( + "--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"] + ) + p.add_argument("--ignore-index", type=int, default=-100) + p.add_argument( + "--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only)." 
+ ) + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + p.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument( + "--quack-suite", + action="store_true", + help="Run Quack-style batch/seq grid (vocab=4096)", + ) + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {3072,6144,8192,12288}", + ) + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs PyTorch float32-logits cross entropy)", + ) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + meta = collect_device_meta(device) + + rows_out: List[Dict[str, Any]] = [] + for M, N in cfgs: + print( + f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", + flush=True, + ) + (ms_oink, gbps_oink), quack, stats = bench_single( + M=M, + N=N, + dtype=dtype, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + mode=str(args.mode), + verify=not args.skip_verify, + ignore_index=int(args.ignore_index), + ) + row: Dict[str, Any] = { + "M": M, + "N": N, + "dtype": args.dtype, + "mode": args.mode, + "ignore_index": int(args.ignore_index), + "ours_ms": ms_oink, + "ours_gbps": gbps_oink, + "ours_tbps": gbps_oink / 1000.0, + "ours_hbm_frac": gbps_oink / hbm_peak, + } + if quack is not None: + ms_q, gbps_q = quack + row.update( + { + "quack_ms": ms_q, + "quack_gbps": gbps_q, + "quack_tbps": gbps_q / 1000.0, + "speedup_vs_quack": ms_q / ms_oink, + } + ) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + write_csv(args.csv, rows_out) + if args.json is not None: + write_json( + args.json, + meta, + rows_out, + extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": "mode-dependent; see bytes_io_model_ce in script", + }, + ) + + headers = ["M", "N", "mode", "ours_ms", "ours_tbps"] + if quack_ce_fwd is not None and quack_ce_bwd is not None: + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows_out: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py new file mode 100644 index 0000000..1787d7d --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py @@ -0,0 +1,376 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark fused_add_rmsnorm (in-place) on SM100. + +This matches vLLM's fused_add_rms_norm semantics: + z = x + residual (stored into residual) + y = RMSNorm(z, w) (stored into x) + +Why this exists: +- It is a common inference hot path (vLLM). +- It is strongly memory-bound (reads/writes two MxN tensors), making it a good + roofline case study for Blackwell. + +Example: + CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \\ + --json /tmp/fused_add_rmsnorm_sm100_bf16.json + +DSv3 suite (Oink vs Quack, multi-shape): + CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --dsv3 \\ + --json /tmp/kernelagent_oink_sm100_suite_bf16/fused_add_rmsnorm_dsv3.json + +Quack baseline note: +- Oink exposes an **in-place** fused op (writes `x` and `residual` in-place). +- Quack provides an equivalent fused kernel, but typically returns `out` and + `residual_out` (out-of-place) and does not expose a public "update my input + buffers in-place" API. +- For integration realism (vLLM-style semantics) we default to timing: + Quack fused kernel + 2 explicit copies to apply the in-place updates + so the benchmark covers the full semantic cost. +""" + +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Tuple + +import torch + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_dtype, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402 + +_VERIFY_TOL = { + # Align with Quack's RMSNorm unit-test defaults (tests/test_rmsnorm.py). + torch.float32: dict(atol=1e-4, rtol=1e-3), + torch.float16: dict(atol=1e-2, rtol=1e-3), + torch.bfloat16: dict(atol=1e-1, rtol=1e-2), +} + +try: + # Use the low-level mutating custom op to avoid per-iteration allocations + # (critical for fair comparisons on small/medium M). 
+ from quack.rmsnorm import _rmsnorm_fwd as quack_rmsnorm_fwd_mut # type: ignore +except Exception: + quack_rmsnorm_fwd_mut = None + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def bytes_io_model_fused_add_rmsnorm_inplace(M: int, N: int, dtype: torch.dtype) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + # Read x + read residual + write x + write residual + read weight + return int((4 * M * N + N) * elem) + + +def _verify_parity( + *, + x: torch.Tensor, + residual: torch.Tensor, + w: torch.Tensor, + eps: float, +) -> dict[str, object]: + tol = _VERIFY_TOL[x.dtype] + ref_block_rows = 4096 + M = int(x.shape[0]) + N = int(x.shape[1]) + + y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + z_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + y_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if quack_rmsnorm_fwd_mut is not None + else None + ) + z_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if quack_rmsnorm_fwd_mut is not None + else None + ) + + x_o = x.clone() + r_o = residual.clone() + out_q = None + res_out_q = None + with torch.no_grad(): + oink_rmsnorm.fused_add_rmsnorm_inplace_(x_o, r_o, w, eps=eps) + + if quack_rmsnorm_fwd_mut is not None: + out_q = torch.empty_like(x) + res_out_q = torch.empty_like(residual) + quack_rmsnorm_fwd_mut( + x, + w, + out_q, + None, # bias + None, # rstd + None, # mean + residual, + res_out_q, + eps, + False, # is_layernorm + ) + + # Pure-PyTorch reference (float32 accumulation), chunked over rows. + M = int(x.shape[0]) + w_f32 = w.float() + for start, end in iter_row_blocks(M, ref_block_rows): + z = x[start:end] + residual[start:end] + zf = z.float() + rstd = torch.rsqrt(zf.square().mean(dim=-1, keepdim=True) + eps) + y_ref = ((zf * rstd) * w_f32).to(x.dtype) + + torch.testing.assert_close(x_o[start:end], y_ref, **tol) + torch.testing.assert_close(r_o[start:end], z, **tol) + y_acc_ours.update(x_o[start:end], y_ref) + z_acc_ours.update(r_o[start:end], z) + if out_q is not None and res_out_q is not None: + torch.testing.assert_close(out_q[start:end], y_ref, **tol) + torch.testing.assert_close(res_out_q[start:end], z, **tol) + assert y_acc_quack is not None and z_acc_quack is not None + y_acc_quack.update(out_q[start:end], y_ref) + z_acc_quack.update(res_out_q[start:end], z) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_residual_out", z_acc_ours.finalize())) + if y_acc_quack is not None and z_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize())) + stats.update( + error_stats_to_row("quack_err_residual_out", z_acc_quack.finalize()) + ) + return stats + + +def bench_one( + *, + M: int, + N: int, + dtype: torch.dtype, + warmup_ms: int, + iters_ms: int, + verify: bool, + quack_baseline: str, +) -> Dict[str, Any]: + device = torch.device("cuda") + x = torch.randn((M, N), device=device, dtype=dtype) + residual = torch.randn_like(x) + w = torch.randn((N,), device=device, dtype=dtype) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(x=x, residual=residual, w=w, eps=1e-6) + + bytes_io = bytes_io_model_fused_add_rmsnorm_inplace(M, N, dtype) + + def fn(): + oink_rmsnorm.fused_add_rmsnorm_inplace_(x, residual, w, eps=1e-6) + + ms = do_bench_triton(fn, warmup_ms=warmup_ms, rep_ms=iters_ms) + + gbps = bytes_io / (ms * 1e-3) / 1e9 + tbps = gbps / 1000.0 + 
hbm_frac = gbps / detect_hbm_peak_gbps(device) + + row: Dict[str, Any] = dict( + M=int(M), + N=int(N), + dtype="bf16" + if dtype is torch.bfloat16 + else ("fp16" if dtype is torch.float16 else "fp32"), + ours_ms=float(ms), + ours_gbps=float(gbps), + ours_tbps=float(tbps), + ours_hbm_frac=float(hbm_frac), + ) + row.update(stats) + + if quack_rmsnorm_fwd_mut is not None: + x_q = x.clone() + residual_q = residual.clone() + out_q = torch.empty_like(x_q) + res_out_q = torch.empty_like(residual_q) + + def fn_q_kernel(): + quack_rmsnorm_fwd_mut( + x_q, + w, + out_q, + None, # bias + None, # rstd + None, # mean + residual_q, + res_out_q, + 1e-6, + False, # is_layernorm + ) + + if quack_baseline == "kernel": + fn_q = fn_q_kernel + elif quack_baseline == "kernel_inplace": + + def fn_q(): + fn_q_kernel() + # Apply the same in-place semantics as vLLM expects: + # - x is overwritten with y + # - residual is overwritten with z = x + residual + x_q.copy_(out_q) + residual_q.copy_(res_out_q) + + else: + raise ValueError(f"Unknown quack_baseline: {quack_baseline}") + + ms_q = do_bench_triton(fn_q, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_q = bytes_io / (ms_q * 1e-3) / 1e9 + row.update( + dict( + quack_ms=float(ms_q), + quack_gbps=float(gbps_q), + quack_tbps=float(gbps_q / 1000.0), + speedup_vs_quack=float(ms_q / ms), + ) + ) + + return row + + +def _dtype_label(dtype: torch.dtype) -> str: + if dtype is torch.bfloat16: + return "bf16" + if dtype is torch.float16: + return "fp16" + return "fp32" + + +def _print_table(rows: List[Dict[str, Any]]) -> None: + if not rows: + return + headers = ["M", "N", "ours_ms", "ours_tbps"] + has_quack = any("quack_ms" in r for r in rows) + if has_quack: + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument( + "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"] + ) + p.add_argument("--M", type=int, default=65536) + p.add_argument("--N", type=int, default=4096) + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", + ) + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument( + "--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)" + ) + p.add_argument( + "--quack-baseline", + type=str, + default="kernel_inplace", + choices=["kernel", "kernel_inplace"], + help=( + "How to time Quack for the in-place fused op.\n" + "- kernel: Quack fused kernel only (preallocated out/residual_out).\n" + "- kernel_inplace: Quack fused kernel + 2 explicit copies to apply " + "in-place semantics (integration-realistic)." 
+ ), + ) + p.add_argument("--skip-verify", action="store_true") + p.add_argument("--json", type=str, default=None) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + meta = collect_device_meta(torch.device("cuda")) + + cfgs = dsv3_configs() if bool(args.dsv3) else [(int(args.M), int(args.N))] + rows: List[Dict[str, Any]] = [] + for M, N in cfgs: + print( + f"bench M={M:<8d} N={N:<6d} dtype={_dtype_label(dtype)} fused_add_rmsnorm ...", + flush=True, + ) + rows.append( + bench_one( + M=int(M), + N=int(N), + dtype=dtype, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + verify=not bool(args.skip_verify), + quack_baseline=str(args.quack_baseline), + ) + ) + + _print_table(rows) + + if args.json: + write_json( + args.json, + meta, + rows, + extra=dict( + io_model_bytes="(4*M*N + N)*elem_size", + warmup_ms=int(args.warmup_ms), + rep_ms=int(args.iters), + method="triton.testing.do_bench(mean)", + note=( + "Oink fused_add_rmsnorm_inplace_ vs Quack baseline " + f"({args.quack_baseline}) when available" + ), + ), + ) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py new file mode 100644 index 0000000..22fb48d --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +HBM roofline microbenchmark for SM100 (GB200 / Blackwell). + +This script measures a STREAM-like bandwidth ceiling using a simple Triton kernel +that performs a large contiguous copy (read + write) and/or triad (read + read + write) +over a large buffer. + +Why this exists: +- The benchmark harnesses for Oink ops report an "ours_tbps" derived from an IO model. +- For roofline discussions, comparing against a *measured* device bandwidth ceiling + is often more meaningful than quoting a marketing/theoretical spec. + +Example: + CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op copy --gb 2 + CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype fp16 --op triad --gb 2 +""" + +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Tuple + +import torch +import triton +import triton.language as tl + +# Reduce fragmentation pressure on busy GPUs. 
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +from bench_utils import ( # noqa: E402 + collect_device_meta, + do_bench_triton, + parse_dtype, + write_json, +) + + +@triton.jit +def _copy_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask, other=0) + tl.store(y_ptr + offsets, x, mask=mask) + + +@triton.jit +def _triad_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask, other=0) + y = tl.load(y_ptr + offsets, mask=mask, other=0) + tl.store(y_ptr + offsets, x + y, mask=mask) + + +def _bytes_moved(n_elements: int, elem_size: int, *, op: str) -> int: + if op == "copy": + return int(2 * n_elements * elem_size) # read x + write y + if op == "triad": + return int(3 * n_elements * elem_size) # read x + read y + write y + raise ValueError(f"Unsupported op: {op}") + + +def bench_one( + *, + n_elements: int, + dtype: torch.dtype, + op: str, + block: int, + num_warps: int, + warmup_ms: int, + iters_ms: int, +) -> Tuple[float, float]: + device = torch.device("cuda") + x = torch.empty((n_elements,), device=device, dtype=dtype) + y = torch.empty_like(x) + # Avoid pathological compression-friendly patterns (e.g. all-zeros) that can + # artificially inflate apparent bandwidth on some GPUs. Random-ish data is + # a closer match to ML workloads. + x.uniform_(-1, 1) + y.uniform_(-1, 1) + + grid = (triton.cdiv(n_elements, block),) + + if op == "copy": + + def launch(): + _copy_kernel[grid]( + x, + y, + n_elements, + BLOCK=block, + num_warps=num_warps, + num_stages=4, + ) + + elif op == "triad": + + def launch(): + _triad_kernel[grid]( + x, + y, + n_elements, + BLOCK=block, + num_warps=num_warps, + num_stages=4, + ) + + else: + raise ValueError(f"Unsupported op: {op}") + + # Force compilation out of the timed region. 
+ launch() + torch.cuda.synchronize() + + ms = do_bench_triton(launch, warmup_ms=warmup_ms, rep_ms=iters_ms) + moved = _bytes_moved(n_elements, x.element_size(), op=op) + tbps = moved / (ms * 1e-3) / 1e12 + return ms, tbps + + +def _print_summary(rows: List[Dict[str, Any]]) -> None: + if not rows: + return + best = max(rows, key=lambda r: float(r["tbps"])) + print("\nSummary (STREAM-like):") + print( + f"- best_tbps: {best['tbps']:.3f} TB/s ({best['op']}, BLOCK={best['block']}, warps={best['num_warps']})" + ) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument( + "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"] + ) + p.add_argument("--op", type=str, default="copy", choices=["copy", "triad", "both"]) + p.add_argument( + "--gb", type=float, default=2.0, help="Size per tensor in GB (default: 2)" + ) + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument( + "--iters", type=int, default=100, help="rep_ms for do_bench (default: 100)" + ) + p.add_argument( + "--json", type=str, default=None, help="Write JSON results to this path" + ) + p.add_argument( + "--no-sweep", + action="store_true", + help="Disable tuning sweep; run a single config", + ) + p.add_argument( + "--block", type=int, default=2048, help="BLOCK size when --no-sweep is set" + ) + p.add_argument( + "--warps", type=int, default=8, help="num_warps when --no-sweep is set" + ) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + cap = (int(props.major), int(props.minor)) + if cap != (10, 0): + raise RuntimeError(f"Expected SM100 (10,0), got {cap} ({props.name})") + + elem_size = torch.tensor(0, dtype=dtype).element_size() + bytes_per_tensor = int(args.gb * (1024**3)) + n_elements = max(1, bytes_per_tensor // elem_size) + + ops: List[str] + if args.op == "both": + ops = ["copy", "triad"] + else: + ops = [args.op] + + if args.no_sweep: + sweep: List[Tuple[int, int]] = [(int(args.block), int(args.warps))] + else: + # A tiny hand-tuned sweep that keeps compile overhead reasonable. + sweep = [ + (1024, 4), + (1024, 8), + (2048, 4), + (2048, 8), + (4096, 8), + ] + + print(f"Running on {props.name} (SM{props.major}{props.minor})") + print(f"- dtype: {args.dtype} (elem={elem_size}B)") + print( + f"- n_elements: {n_elements:,} (~{(n_elements * elem_size) / (1024**3):.2f} GiB per tensor)" + ) + print(f"- ops: {ops}") + print(f"- sweep: {sweep}") + + meta = collect_device_meta(device) + rows: List[Dict[str, Any]] = [] + for op in ops: + for block, warps in sweep: + ms, tbps = bench_one( + n_elements=n_elements, + dtype=dtype, + op=op, + block=block, + num_warps=warps, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + ) + rows.append( + dict( + op=op, + dtype=str(args.dtype), + n_elements=int(n_elements), + elem_size_B=int(elem_size), + block=int(block), + num_warps=int(warps), + warmup_ms=int(args.warmup_ms), + rep_ms=int(args.iters), + ms=float(ms), + tbps=float(tbps), + ) + ) + print( + f"- {op:5s} BLOCK={block:4d} warps={warps}: {tbps:.3f} TB/s ({ms:.4f} ms)" + ) + + _print_summary(rows) + + if args.json: + # Write meta + detailed rows for reproducibility. 
+ extra = dict( + bytes_model="copy:2*N*elem, triad:3*N*elem", + bytes_per_tensor=int(bytes_per_tensor), + gb_per_tensor=float(args.gb), + ) + write_json(args.json, meta, rows, extra=extra) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py new file mode 100644 index 0000000..20895b7 --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py @@ -0,0 +1,451 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_configs, + parse_dtype, + quack_suite_configs, + write_csv, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import layernorm as oink_ln # noqa: E402 + +try: + # Quack exposes LayerNorm through the RMSNorm module (is_layernorm=True path). + from quack.rmsnorm import layernorm_fwd as quack_layernorm # type: ignore +except Exception: + quack_layernorm = None + +_VERIFY_TOL_Y = { + # Match Quack's unit-test defaults (tests/test_layernorm.py). + torch.float32: dict(atol=1e-4, rtol=1e-4), + torch.float16: dict(atol=1e-3, rtol=1e-3), + torch.bfloat16: dict(atol=1e-2, rtol=1e-2), +} + +# Quack checks rstd/mean (fp32) with a tighter fixed tolerance. 
+_VERIFY_TOL_STATS = dict(atol=6e-4, rtol=6e-4) + + +def bytes_io_model_layernorm( + M: int, + N: int, + dtype: torch.dtype, + *, + has_bias: bool, + return_rstd: bool, + return_mean: bool, + weight_dtype: torch.dtype = torch.float32, +) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + w_elem = torch.tensor(0, dtype=weight_dtype).element_size() + total = 0 + # Read x + write y + total += 2 * M * N * elem + # Read weight (+ optional bias) along feature dim + total += N * w_elem + if has_bias: + total += N * w_elem + # Optional per-row stats (fp32) + if return_rstd: + total += M * 4 + if return_mean: + total += M * 4 + return int(total) + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity( + x: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor | None, + *, + eps: float, + return_rstd: bool, + return_mean: bool, +) -> dict[str, object]: + tol_y = _VERIFY_TOL_Y[x.dtype] + ref_block_rows = 4096 + M = int(x.shape[0]) + N = int(x.shape[1]) + + y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + y_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if (quack_layernorm is not None and b is None) + else None + ) + with torch.no_grad(): + ours = oink_ln.layernorm( + x, + w, + bias=b, + eps=eps, + return_rstd=return_rstd, + return_mean=return_mean, + ) + quack = None + if quack_layernorm is not None and b is None: + quack = quack_layernorm( + x, + w, + eps=eps, + return_rstd=return_rstd, + return_mean=return_mean, + ) + torch.cuda.synchronize() + + def _unpack(out): + if return_rstd and return_mean: + y, rstd, mean = out + elif return_rstd and not return_mean: + y, rstd = out + mean = None + elif return_mean and not return_rstd: + y, mean = out + rstd = None + else: + y, rstd, mean = out, None, None + return y, rstd, mean + + y_o, rstd_o, mean_o = _unpack(ours) + y_q, rstd_q, mean_q = _unpack(quack) if quack is not None else (None, None, None) + + # Pure-PyTorch reference (float32 accumulation), matching Quack's unit tests: + # - compute ref output via F.layer_norm on float32 + # - compute mean/rstd from float32 input + rstd_ref_all = ( + torch.empty((M,), device=x.device, dtype=torch.float32) if return_rstd else None + ) + mean_ref_all = ( + torch.empty((M,), device=x.device, dtype=torch.float32) if return_mean else None + ) + + for start, end in iter_row_blocks(M, ref_block_rows): + x_f32 = x[start:end].float() + y_ref_f32 = torch.nn.functional.layer_norm(x_f32, w.shape, w, b, eps) + y_ref = y_ref_f32.to(x.dtype) + torch.testing.assert_close(y_o[start:end], y_ref, **tol_y) + y_acc_ours.update(y_o[start:end], y_ref) + if y_q is not None: + torch.testing.assert_close(y_q[start:end], y_ref, **tol_y) + assert y_acc_quack is not None + y_acc_quack.update(y_q[start:end], y_ref) + + # Per-row stats in fp32, as in Quack's tests. 
+ if return_rstd or return_mean: + mean_f32 = x_f32.mean(dim=-1) + if return_mean: + assert mean_ref_all is not None + mean_ref_all[start:end] = mean_f32 + if return_rstd: + var_f32 = ((x_f32 - mean_f32.unsqueeze(1)) ** 2).mean(dim=-1) + rstd_ref = 1.0 / torch.sqrt(var_f32 + eps) + assert rstd_ref_all is not None + rstd_ref_all[start:end] = rstd_ref + + assert rstd_o is not None + torch.testing.assert_close( + rstd_o[start:end], rstd_ref, **_VERIFY_TOL_STATS + ) + if rstd_q is not None: + torch.testing.assert_close( + rstd_q[start:end], rstd_ref, **_VERIFY_TOL_STATS + ) + + if return_mean: + mean_ref = mean_f32 + assert mean_o is not None + torch.testing.assert_close( + mean_o[start:end], mean_ref, **_VERIFY_TOL_STATS + ) + if mean_q is not None: + torch.testing.assert_close( + mean_q[start:end], mean_ref, **_VERIFY_TOL_STATS + ) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize())) + if y_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize())) + + if return_rstd: + assert rstd_o is not None and rstd_ref_all is not None + rstd_acc_ours = ErrorStatsAccumulator( + total_elems=int(rstd_ref_all.numel()), + p99_target_samples=int(rstd_ref_all.numel()), + ) + rstd_acc_ours.update(rstd_o, rstd_ref_all) + stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize())) + if rstd_q is not None: + rstd_acc_quack = ErrorStatsAccumulator( + total_elems=int(rstd_ref_all.numel()), + p99_target_samples=int(rstd_ref_all.numel()), + ) + rstd_acc_quack.update(rstd_q, rstd_ref_all) + stats.update( + error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize()) + ) + + if return_mean: + assert mean_o is not None and mean_ref_all is not None + mean_acc_ours = ErrorStatsAccumulator( + total_elems=int(mean_ref_all.numel()), + p99_target_samples=int(mean_ref_all.numel()), + ) + mean_acc_ours.update(mean_o, mean_ref_all) + stats.update(error_stats_to_row("ours_err_mean", mean_acc_ours.finalize())) + if mean_q is not None: + mean_acc_quack = ErrorStatsAccumulator( + total_elems=int(mean_ref_all.numel()), + p99_target_samples=int(mean_ref_all.numel()), + ) + mean_acc_quack.update(mean_q, mean_ref_all) + stats.update( + error_stats_to_row("quack_err_mean", mean_acc_quack.finalize()) + ) + + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + *, + eps: float, + warmup_ms: int, + iters_ms: int, + verify: bool, + return_rstd: bool, + return_mean: bool, + has_bias: bool, +) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]: + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + w = torch.randn(N, device=device, dtype=torch.float32) + b = torch.randn(N, device=device, dtype=torch.float32) if has_bias else None + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity( + x, w, b, eps=eps, return_rstd=return_rstd, return_mean=return_mean + ) + + bytes_io = bytes_io_model_layernorm( + M, + N, + dtype, + has_bias=has_bias, + return_rstd=return_rstd, + return_mean=return_mean, + weight_dtype=w.dtype, + ) + + def fn_oink(): + return oink_ln.layernorm( + x, + w, + bias=b, + eps=eps, + return_rstd=return_rstd, + return_mean=return_mean, + ) + + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + + if quack_layernorm is None or has_bias: + return (ms_oink, gbps_oink), None, stats + + def fn_quack(): + return quack_layernorm( + x, + w, + eps=eps, + 
return_rstd=return_rstd, + return_mean=return_mean, + ) + + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 + return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) + p.add_argument("--eps", type=float, default=1e-6) + p.add_argument("--return-rstd", action="store_true") + p.add_argument("--return-mean", action="store_true") + p.add_argument( + "--with-bias", + action="store_true", + help="Benchmark bias path (Quack compare skipped)", + ) + p.add_argument( + "--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only)." + ) + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + p.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument( + "--quack-suite", + action="store_true", + help="Run Quack-style batch/seq grid (hidden=4096)", + ) + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", + ) + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference; Quack compare skipped when bias is enabled)", + ) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + eps = float(args.eps) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + meta = collect_device_meta(device) + + rows_out: List[Dict[str, Any]] = [] + for M, N in cfgs: + print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True) + (ms_oink, gbps_oink), quack, stats = bench_single( + M=M, + N=N, + dtype=dtype, + eps=eps, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + verify=not args.skip_verify, + return_rstd=bool(args.return_rstd), + return_mean=bool(args.return_mean), + has_bias=bool(args.with_bias), + ) + row: Dict[str, Any] = { + "M": M, + "N": N, + "dtype": args.dtype, + "eps": eps, + "return_rstd": bool(args.return_rstd), + "return_mean": bool(args.return_mean), + "with_bias": bool(args.with_bias), + "ours_ms": ms_oink, + "ours_gbps": gbps_oink, + "ours_tbps": gbps_oink / 1000.0, + "ours_hbm_frac": gbps_oink / hbm_peak, + } + if quack is not None: + ms_q, gbps_q = quack + row.update( + { + "quack_ms": ms_q, + "quack_gbps": gbps_q, + "quack_tbps": gbps_q / 1000.0, + "speedup_vs_quack": ms_q / ms_oink, + } + ) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + write_csv(args.csv, rows_out) + if args.json is not None: + write_json( + args.json, + meta, + rows_out, + extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": "see bytes_io_model_layernorm in script", + }, + ) + + headers = ["M", "N", "ours_ms", 
"ours_tbps"] + if quack_layernorm is not None and (not args.with_bias): + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows_out: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py new file mode 100644 index 0000000..50ecb2e --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py @@ -0,0 +1,464 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import csv +import os +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +from triton.testing import do_bench as triton_do_bench + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + ensure_oink_src_on_path, + error_stats_to_row, + iter_row_blocks, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402 + +try: + from quack.rmsnorm import rmsnorm_bwd as quack_rmsnorm_bwd # type: ignore +except Exception: + quack_rmsnorm_bwd = None + +_VERIFY_TOL_DX = { + # Match Quack's unit-test defaults (tests/test_rmsnorm.py). + torch.float32: dict(atol=1e-4, rtol=1e-3), + torch.float16: dict(atol=1e-2, rtol=1e-3), + torch.bfloat16: dict(atol=1e-1, rtol=1e-2), +} + + +def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float: + """Approximate HBM peak bandwidth in GB/s for roofline fractions.""" + if device is None: + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + if sm >= 100: + return 8000.0 + return 2000.0 + + +@dataclass +class Result: + ms: float + gbps: float + + +def do_bench_triton(fn, warmup_ms: int = 25, rep_ms: int = 100) -> float: + # Kernel-only timing consistent with the existing Oink forward harness. + return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean")) + + +def bytes_io_model_bwd( + M: int, N: int, dtype: torch.dtype, *, weight_dtype: torch.dtype = torch.float32 +) -> int: + """A simple IO model for RMSNorm backward. + + This intentionally ignores partial-reduction scratch buffers (`dw_partial` / + `db_partial`) since those are highly implementation-specific and depend on + sm_count; we still report speedups and times regardless. 
+ """ + elem = torch.tensor(0, dtype=dtype).element_size() + w_elem = torch.tensor(0, dtype=weight_dtype).element_size() + # Read x + dout + write dx + total = 3 * M * N * elem + # Read weight + write dw + total += 2 * N * w_elem + # Read rstd (fp32 per row) + total += M * 4 + return int(total) + + +def parse_dtype(s: str) -> torch.dtype: + s = s.lower() + if s == "bf16": + return torch.bfloat16 + if s == "fp16": + return torch.float16 + if s == "fp32": + return torch.float32 + raise ValueError(f"Unsupported dtype: {s}") + + +def parse_configs(s: str) -> List[Tuple[int, int]]: + out: List[Tuple[int, int]] = [] + for part in s.split(","): + m, n = part.lower().split("x") + out.append((int(m), int(n))) + return out + + +def quack_suite_configs() -> List[Tuple[int, int, int]]: + """Return (batch, seq, hidden) triples following Quack's grid (hidden=4096).""" + batch_sizes = [1, 4, 8, 16, 32] + seq_lengths = [8192, 16384, 32768, 65536, 131072] + hidden = 4096 + cfgs: List[Tuple[int, int, int]] = [] + for bs in batch_sizes: + for sl in seq_lengths: + M = bs * sl + if M * hidden > (2**31): + continue + cfgs.append((bs, sl, hidden)) + return cfgs + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity( + x: torch.Tensor, + w: torch.Tensor, + dout: torch.Tensor, + rstd: torch.Tensor, + *, + has_bias: bool, + has_residual: bool, +) -> dict[str, object]: + tol_dx = _VERIFY_TOL_DX[x.dtype] + ref_block_rows = 1024 + M, N = int(x.shape[0]), int(x.shape[1]) + + dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + dx_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if quack_rmsnorm_bwd is not None + else None + ) + + with torch.no_grad(): + dx_oink, dw_oink, db_oink, dres_oink = oink_rmsnorm.rmsnorm_backward( + x, + w, + dout, + rstd, + dresidual_out=None, + has_bias=has_bias, + has_residual=has_residual, + ) + + dx_quack = None + dw_quack = None + db_quack = None + dres_quack = None + if quack_rmsnorm_bwd is not None: + dx_quack, dw_quack, db_quack, dres_quack = quack_rmsnorm_bwd( + x, + w, + dout, + rstd, + dresidual_out=None, + has_bias=has_bias, + has_residual=has_residual, + ) + torch.cuda.synchronize() + + # Pure-PyTorch reference, matching Quack's rmsnorm_bwd_ref (float32 math for x_hat). + # Chunk over rows to avoid materializing an (M, N) float32 tensor for large shapes. + dw_accum = torch.zeros((N,), device=x.device, dtype=torch.float32) + w_f32 = w.float() + for start, end in iter_row_blocks(M, ref_block_rows): + x_f32 = x[start:end].float() + rstd_blk = rstd[start:end] + x_hat = x_f32 * rstd_blk.unsqueeze(1) + # Match Quack/PyTorch reference behavior: gradient math uses float32 + # intermediates even when (x, w, dout) are bf16/fp16. 
+ dout_f32 = dout[start:end].float() + wdy = dout_f32 * w_f32 + c1 = (x_hat * wdy).mean(dim=-1, keepdim=True) + dx_ref = ((wdy - x_hat * c1) * rstd_blk.unsqueeze(1)).to(x.dtype) + + torch.testing.assert_close(dx_oink[start:end], dx_ref, **tol_dx) + dx_acc_ours.update(dx_oink[start:end], dx_ref) + if dx_quack is not None: + torch.testing.assert_close(dx_quack[start:end], dx_ref, **tol_dx) + assert dx_acc_quack is not None + dx_acc_quack.update(dx_quack[start:end], dx_ref) + + if dw_oink is not None: + dw_accum += (dout_f32 * x_hat).sum(dim=0) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize())) + if dx_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize())) + + if dw_oink is not None: + dw_ref = dw_accum.to(w.dtype) + if w.dtype == torch.float32: + # Weight grad is sensitive to reduction order; use a slightly larger + # absolute tolerance in the suite harness (Quack's unit tests use + # smaller M, where dw is typically tighter). + dw_tol = dict(atol=2e-3, rtol=1e-3) + else: + # For fp16/bf16 weights, `dw` is low-precision and grows with M; use an + # ulp/magnitude-aware tolerance rather than a fixed epsilon. + dw_ref_f32 = dw_ref.to(torch.float32) + dw_oink_f32 = dw_oink.to(torch.float32) + scale = float(dw_ref_f32.abs().max().item()) + dw_atol = max(2.0 * torch.finfo(w.dtype).eps * scale, 1e-3) + dw_tol = dict(atol=dw_atol, rtol=1e-3) + torch.testing.assert_close(dw_oink_f32, dw_ref_f32, **dw_tol) + if dw_quack is not None: + torch.testing.assert_close( + dw_quack.to(torch.float32), dw_ref_f32, **dw_tol + ) + dw_tol = None # handled above + if dw_tol is not None: + torch.testing.assert_close(dw_oink, dw_ref, **dw_tol) + if dw_quack is not None: + torch.testing.assert_close(dw_quack, dw_ref, **dw_tol) + + # Record weight-grad error stats (small, so exact p99 over the full vector). + dw_acc_ours = ErrorStatsAccumulator( + total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel()) + ) + dw_acc_ours.update(dw_oink, dw_ref) + stats.update(error_stats_to_row("ours_err_dw", dw_acc_ours.finalize())) + if dw_quack is not None: + dw_acc_quack = ErrorStatsAccumulator( + total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel()) + ) + dw_acc_quack.update(dw_quack, dw_ref) + stats.update(error_stats_to_row("quack_err_dw", dw_acc_quack.finalize())) + + assert db_oink is None and db_quack is None + assert dres_oink is None and dres_quack is None + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + weight_dtype: torch.dtype, + iters_ms: int, + eps: float, + warmup_ms: int, + verify: bool, +) -> Tuple[Result, Result | None, dict[str, object]]: + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + w = torch.randn(N, device=device, dtype=weight_dtype) + dout = torch.randn(M, N, device=device, dtype=dtype) + # rstd is fp32 per row; compute once outside the timed region. 
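+ # Reference definition: rstd[i] = 1 / sqrt(mean(x[i, :]**2) + eps), i.e. what an RMSNorm forward pass with store_rstd=True would produce.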
+ with torch.no_grad(): + xf = x.float() + rstd = torch.rsqrt(xf.square().mean(dim=-1) + eps).to(torch.float32) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(x, w, dout, rstd, has_bias=False, has_residual=False) + + def fn_oink(): + return oink_rmsnorm.rmsnorm_backward( + x, + w, + dout, + rstd, + dresidual_out=None, + has_bias=False, + has_residual=False, + ) + + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + bytes_io = bytes_io_model_bwd(M, N, dtype, weight_dtype=w.dtype) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + ours = Result(ms=ms_oink, gbps=gbps_oink) + + if quack_rmsnorm_bwd is None: + return ours, None, stats + + def fn_quack(): + return quack_rmsnorm_bwd( + x, + w, + dout, + rstd, + dresidual_out=None, + has_bias=False, + has_residual=False, + ) + + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 + return ours, Result(ms=ms_quack, gbps=gbps_quack), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) + p.add_argument( + "--weight-dtype", + type=str, + default="fp32", + choices=["same", "fp16", "bf16", "fp32"], + help="RMSNorm weight dtype. `same` matches activation dtype.", + ) + p.add_argument("--eps", type=float, default=1e-6) + p.add_argument( + "--iters", + type=int, + default=100, + help="Triton do_bench rep_ms (kernel-only).", + ) + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + p.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument( + "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid" + ) + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", + ) + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs a pure-PyTorch RMSNorm backward reference)", + ) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + if args.weight_dtype == "same": + weight_dtype = dtype + else: + weight_dtype = parse_dtype(args.weight_dtype) + eps = float(args.eps) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + + rows_out: list[dict[str, object]] = [] + + for M, N in cfgs: + print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True) + ours, quack, stats = bench_single( + M=M, + N=N, + dtype=dtype, + weight_dtype=weight_dtype, + iters_ms=int(args.iters), + eps=eps, + warmup_ms=int(args.warmup_ms), + verify=not args.skip_verify, + ) + + row: dict[str, object] = { + "M": M, + "N": N, + "dtype": args.dtype, + "weight_dtype": args.weight_dtype, + "ours_ms": ours.ms, + "ours_gbps": ours.gbps, + "ours_tbps": ours.gbps / 1000.0, + "ours_hbm_frac": ours.gbps / hbm_peak, + } + if quack is not None: 
+ row.update( + { + "quack_ms": quack.ms, + "quack_gbps": quack.gbps, + "quack_tbps": quack.gbps / 1000.0, + "speedup_vs_quack": quack.ms / ours.ms, + } + ) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + file_exists = os.path.exists(args.csv) + with open(args.csv, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=sorted(row.keys())) + if not file_exists: + writer.writeheader() + writer.writerow(row) + + if args.json is not None: + meta = collect_device_meta(device) + write_json( + args.json, + meta, + rows_out, + extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": "see bytes_io_model_bwd in script", + "weight_dtype": str(args.weight_dtype), + }, + ) + + # Print a small summary table. + headers = ["M", "N", "dtype", "ours_ms", "ours_tbps", "ours_hbm_frac"] + if quack_rmsnorm_bwd is not None: + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows_out: + parts: list[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py new file mode 100644 index 0000000..39e6cd7 --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py @@ -0,0 +1,384 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_configs, + parse_dtype, + quack_suite_configs, + write_csv, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402 + +try: + from quack.rmsnorm import rmsnorm_fwd as quack_rmsnorm_fwd # type: ignore +except Exception: + quack_rmsnorm_fwd = None + +_VERIFY_TOL_Y = { + # Match Quack's unit-test defaults (tests/test_rmsnorm.py). + torch.float32: dict(atol=1e-4, rtol=1e-3), + torch.float16: dict(atol=1e-2, rtol=1e-3), + # NOTE: bf16 ulp grows with magnitude; a slightly larger rtol is more robust + # for the large-M suite shapes (and fused paths that can see larger values). 
+ torch.bfloat16: dict(atol=1e-1, rtol=1e-2), +} + +_VERIFY_TOL_RSTD = { + torch.float32: dict(atol=1e-5, rtol=1e-5), + torch.float16: dict(atol=1e-3, rtol=1e-3), + torch.bfloat16: dict(atol=1e-3, rtol=1e-3), +} + + +def bytes_io_model_fwd( + M: int, N: int, dtype: torch.dtype, *, weight_dtype: torch.dtype = torch.float32 +) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + w_elem = torch.tensor(0, dtype=weight_dtype).element_size() + # Read x + write y + total = 2 * M * N * elem + # Read weight + total += N * w_elem + return int(total) + + +def dsv3_configs() -> List[Tuple[int, int]]: + # DSv3-ish hidden sizes used throughout the Oink/Quack SM100 suite tables. + Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity( + x: torch.Tensor, + w: torch.Tensor, + *, + eps: float, + store_rstd: bool, +) -> dict[str, object]: + tol_y = _VERIFY_TOL_Y[x.dtype] + tol_rstd = _VERIFY_TOL_RSTD[x.dtype] + ref_block_rows = 4096 + M = int(x.shape[0]) + N = int(x.shape[1]) + + y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + y_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if quack_rmsnorm_fwd is not None + else None + ) + + with torch.no_grad(): + y_o, rstd_o, res_o = oink_rmsnorm.rmsnorm_forward( + x, + weight=w, + bias=None, + residual=None, + eps=eps, + store_rstd=store_rstd, + ) + y_q = None + rstd_q = None + if quack_rmsnorm_fwd is not None: + # Quack returns (out, residual_out, rstd). + y_q, res_q, rstd_q = quack_rmsnorm_fwd( + x, + w, + bias=None, + residual=None, + out_dtype=None, + residual_dtype=None, + eps=eps, + store_rstd=store_rstd, + ) + + # Pure-PyTorch reference (float32 accumulation), chunked over rows to avoid + # materializing an (M, N) float32 tensor for large Quack-suite shapes. + w_f32 = w.float() + rstd_ref = torch.empty((M,), device=x.device, dtype=torch.float32) + for start, end in iter_row_blocks(M, ref_block_rows): + x_f32 = x[start:end].float() + rstd_blk = torch.rsqrt(x_f32.square().mean(dim=-1) + eps) + rstd_ref[start:end] = rstd_blk + + y_ref_blk_f32 = (x_f32 * rstd_blk.unsqueeze(1)) * w_f32 + y_ref_blk = y_ref_blk_f32.to(x.dtype) + torch.testing.assert_close(y_o[start:end], y_ref_blk, **tol_y) + y_acc_ours.update(y_o[start:end], y_ref_blk) + if y_q is not None: + torch.testing.assert_close(y_q[start:end], y_ref_blk, **tol_y) + assert y_acc_quack is not None + y_acc_quack.update(y_q[start:end], y_ref_blk) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize())) + if y_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize())) + + if store_rstd: + assert rstd_o is not None + torch.testing.assert_close(rstd_o, rstd_ref, **tol_rstd) + if y_q is not None: + assert rstd_q is not None + torch.testing.assert_close(rstd_q, rstd_ref, **tol_rstd) + # Stats for rstd are cheap (M elements); compute exact p99 over all rows. 
+ rstd_acc_ours = ErrorStatsAccumulator( + total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel()) + ) + rstd_acc_ours.update(rstd_o, rstd_ref) + stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize())) + if rstd_q is not None: + rstd_acc_quack = ErrorStatsAccumulator( + total_elems=int(rstd_ref.numel()), + p99_target_samples=int(rstd_ref.numel()), + ) + rstd_acc_quack.update(rstd_q, rstd_ref) + stats.update( + error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize()) + ) + # Residual output semantics differ slightly across implementations: + # - Oink returns `None` when residual is None. + # - Quack returns `x` as a safe alias in that case. + # + # For parity we focus on `y` (and optional `rstd`) for the residual=None path. + assert res_o is None + if quack_rmsnorm_fwd is not None: + assert res_q is x + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + *, + weight_dtype: torch.dtype, + eps: float, + warmup_ms: int, + iters_ms: int, + verify: bool, + store_rstd: bool, +) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]: + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + w = torch.randn(N, device=device, dtype=weight_dtype) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(x, w, eps=eps, store_rstd=store_rstd) + + bytes_io = bytes_io_model_fwd(M, N, dtype, weight_dtype=w.dtype) + + def fn_oink(): + return oink_rmsnorm.rmsnorm_forward( + x, + weight=w, + bias=None, + residual=None, + eps=eps, + store_rstd=store_rstd, + ) + + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + + if quack_rmsnorm_fwd is None: + return (ms_oink, gbps_oink), None, stats + + def fn_quack(): + return quack_rmsnorm_fwd( + x, + w, + bias=None, + residual=None, + out_dtype=None, + residual_dtype=None, + eps=eps, + store_rstd=store_rstd, + ) + + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 + return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) + p.add_argument( + "--weight-dtype", + type=str, + default="fp32", + choices=["same", "fp16", "bf16", "fp32"], + help="RMSNorm weight dtype. `same` matches activation dtype (vLLM-style inference).", + ) + p.add_argument("--eps", type=float, default=1e-6) + p.add_argument( + "--store-rstd", action="store_true", help="Also write rstd (fp32 per row)" + ) + p.add_argument( + "--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only)." 
+ ) + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + p.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument( + "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid" + ) + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", + ) + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference)", + ) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + if args.weight_dtype == "same": + weight_dtype = dtype + else: + weight_dtype = parse_dtype(args.weight_dtype) + eps = float(args.eps) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + meta = collect_device_meta(device) + + rows_out: List[Dict[str, Any]] = [] + for M, N in cfgs: + print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True) + (ms_oink, gbps_oink), quack, stats = bench_single( + M=M, + N=N, + dtype=dtype, + weight_dtype=weight_dtype, + eps=eps, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + verify=not args.skip_verify, + store_rstd=bool(args.store_rstd), + ) + row: Dict[str, Any] = { + "M": M, + "N": N, + "dtype": args.dtype, + "weight_dtype": args.weight_dtype, + "eps": eps, + "store_rstd": bool(args.store_rstd), + "ours_ms": ms_oink, + "ours_gbps": gbps_oink, + "ours_tbps": gbps_oink / 1000.0, + "ours_hbm_frac": gbps_oink / hbm_peak, + } + if quack is not None: + ms_q, gbps_q = quack + row.update( + { + "quack_ms": ms_q, + "quack_gbps": gbps_q, + "quack_tbps": gbps_q / 1000.0, + "speedup_vs_quack": ms_q / ms_oink, + } + ) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + write_csv(args.csv, rows_out) + if args.json is not None: + write_json( + args.json, + meta, + rows_out, + extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": "(2*M*N)*elem_size + N*weight_elem_size", + "store_rstd": bool(args.store_rstd), + "weight_dtype": str(args.weight_dtype), + }, + ) + + # Print a compact summary table. + headers = ["M", "N", "ours_ms", "ours_tbps"] + if quack_rmsnorm_fwd is not None: + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows_out: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py new file mode 100644 index 0000000..a5b2b3c --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py @@ -0,0 +1,341 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_configs, + parse_dtype, + quack_suite_configs, + write_csv, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import softmax as oink_softmax # noqa: E402 + +try: + from quack.softmax import softmax_bwd as quack_softmax_bwd # type: ignore + from quack.softmax import softmax_fwd as quack_softmax_fwd # type: ignore +except Exception: + quack_softmax_fwd = None + quack_softmax_bwd = None + +_VERIFY_TOL = { + # Match Quack's unit-test defaults (tests/test_softmax.py). + torch.float32: dict(atol=1e-4, rtol=1e-4), + torch.float16: dict(atol=1e-3, rtol=1e-3), + torch.bfloat16: dict(atol=1e-2, rtol=1e-2), +} + + +def bytes_io_model_softmax(M: int, N: int, dtype: torch.dtype, *, mode: str) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + if mode == "fwd": + return int(2 * M * N * elem) # read x + write y + if mode == "bwd": + return int(3 * M * N * elem) # read dy + read y + write dx + if mode == "fwd_bwd": + # Logical IO for dx given (x, dy): read x + read dy + write dx. + # (The intermediate y=softmax(x) is an implementation detail and is + # intentionally not counted here.) + return int(3 * M * N * elem) + raise ValueError(f"Unsupported mode: {mode}") + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity(x: torch.Tensor) -> dict[str, object]: + tol = _VERIFY_TOL[x.dtype] + ref_block_rows = 4096 + dy = torch.randn_like(x) # upstream grad + + with torch.no_grad(): + y_o = oink_softmax.softmax_forward(x) + dx_o = oink_softmax.softmax_backward(dy, y_o) + dx_fused_o = oink_softmax.softmax_fwd_bwd(dy, x) + + y_q = None + dx_q = None + if quack_softmax_fwd is not None and quack_softmax_bwd is not None: + y_q = quack_softmax_fwd(x) + dx_q = quack_softmax_bwd(dy, y_q) + + M = int(x.shape[0]) + N = int(x.shape[1]) + y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + dx_fused_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + y_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if (quack_softmax_fwd is not None and quack_softmax_bwd is not None) + else None + ) + dx_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if (quack_softmax_fwd is not None and quack_softmax_bwd is not None) + else None + ) + + # Match Quack tests: compare to PyTorch softmax refs (fwd+bwd), chunked. 
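+ # For y = softmax(x), the backward reference is dx = (dy - sum(dy * y, dim=-1, keepdim=True)) * y; the row-wise dot product below is accumulated in float32.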
+ for start, end in iter_row_blocks(M, ref_block_rows): + x_blk = x[start:end] + dy_blk = dy[start:end] + y_ref_blk = torch.softmax(x_blk, dim=-1) + dot = torch.sum(dy_blk * y_ref_blk, dim=-1, keepdim=True, dtype=torch.float32) + dx_ref_blk = (dy_blk - dot.to(dy_blk.dtype)) * y_ref_blk + + torch.testing.assert_close(y_o[start:end], y_ref_blk, **tol) + torch.testing.assert_close(dx_o[start:end], dx_ref_blk, **tol) + torch.testing.assert_close(dx_fused_o[start:end], dx_ref_blk, **tol) + y_acc_ours.update(y_o[start:end], y_ref_blk) + dx_acc_ours.update(dx_o[start:end], dx_ref_blk) + dx_fused_acc_ours.update(dx_fused_o[start:end], dx_ref_blk) + if y_q is not None and dx_q is not None: + torch.testing.assert_close(y_q[start:end], y_ref_blk, **tol) + torch.testing.assert_close(dx_q[start:end], dx_ref_blk, **tol) + assert y_acc_quack is not None and dx_acc_quack is not None + y_acc_quack.update(y_q[start:end], y_ref_blk) + dx_acc_quack.update(dx_q[start:end], dx_ref_blk) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_dx_fused", dx_fused_acc_ours.finalize())) + if y_acc_quack is not None and dx_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize())) + stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize())) + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + *, + warmup_ms: int, + iters_ms: int, + mode: str, + verify: bool, +) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]: + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + dy = torch.randn_like(x) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(x) + + bytes_io = bytes_io_model_softmax(M, N, dtype, mode=mode) + + if mode == "fwd": + + def fn_oink(): + return oink_softmax.softmax_forward(x) + + fn_quack = None + if quack_softmax_fwd is not None: + + def fn_quack(): + return quack_softmax_fwd(x) + + elif mode == "bwd": + with torch.no_grad(): + y_o = oink_softmax.softmax_forward(x) + y_q = quack_softmax_fwd(x) if quack_softmax_fwd is not None else None + + def fn_oink(): + return oink_softmax.softmax_backward(dy, y_o) + + fn_quack = None + if quack_softmax_bwd is not None and y_q is not None: + + def fn_quack(): + return quack_softmax_bwd(dy, y_q) + + elif mode == "fwd_bwd": + + def fn_oink(): + return oink_softmax.softmax_fwd_bwd(dy, x) + + fn_quack = None + if quack_softmax_fwd is not None and quack_softmax_bwd is not None: + + def fn_quack(): + return quack_softmax_bwd(dy, quack_softmax_fwd(x)) + + else: + raise ValueError(f"Unsupported mode: {mode}") + + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + + if fn_quack is None: + return (ms_oink, gbps_oink), None, stats + + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 + return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument( + 
"--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) + p.add_argument( + "--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"] + ) + p.add_argument( + "--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only)." + ) + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + p.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument( + "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid" + ) + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", + ) + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs PyTorch softmax)", + ) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + meta = collect_device_meta(device) + + rows_out: List[Dict[str, Any]] = [] + for M, N in cfgs: + print( + f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", + flush=True, + ) + (ms_oink, gbps_oink), quack, stats = bench_single( + M=M, + N=N, + dtype=dtype, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + mode=str(args.mode), + verify=not args.skip_verify, + ) + row: Dict[str, Any] = { + "M": M, + "N": N, + "dtype": args.dtype, + "mode": args.mode, + "ours_ms": ms_oink, + "ours_gbps": gbps_oink, + "ours_tbps": gbps_oink / 1000.0, + "ours_hbm_frac": gbps_oink / hbm_peak, + } + if quack is not None: + ms_q, gbps_q = quack + row.update( + { + "quack_ms": ms_q, + "quack_gbps": gbps_q, + "quack_tbps": gbps_q / 1000.0, + "speedup_vs_quack": ms_q / ms_oink, + } + ) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + write_csv(args.csv, rows_out) + if args.json is not None: + write_json( + args.json, + meta, + rows_out, + extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": "mode-dependent: fwd=2*M*N, bwd=3*M*N, fwd_bwd=3*M*N (all * elem_size; fwd_bwd counts logical x+dy+dx)", + }, + ) + + headers = ["M", "N", "mode", "ours_ms", "ours_tbps"] + if quack_softmax_fwd is not None and quack_softmax_bwd is not None: + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows_out: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg new file mode 100644 index 0000000..96b5b83 --- /dev/null +++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg @@ -0,0 +1,2259 @@ + + + + + + + + 2026-01-22T03:16:57.722815 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg new file mode 100644 index 0000000..254623e --- /dev/null +++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg @@ -0,0 +1,2600 @@ [SVG figure data omitted: Matplotlib-generated benchmark plot, 2,600 lines of vector data.]
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg new file mode 100644 index 0000000..9db31a5 --- /dev/null +++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg @@ -0,0 +1,2936 @@ [SVG figure data omitted: Matplotlib-generated benchmark plot, 2,936 lines of vector data.]
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg new file mode 100644 index 0000000..c392959 --- /dev/null +++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg @@ -0,0 +1,1687 @@ [SVG figure data omitted: Matplotlib-generated benchmark plot, 1,687 lines of vector data.]
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg new file mode 100644 index 0000000..0d4c1ae --- /dev/null +++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg @@ -0,0 +1,2600 @@ [SVG figure data omitted: Matplotlib-generated benchmark plot, 2,600 lines of vector data.]
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg new file mode 100644 index 0000000..e3bcd46 --- /dev/null +++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg @@ -0,0 +1,2280 @@ [SVG figure data omitted: Matplotlib-generated benchmark plot, 2,280 lines of vector data.]
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg new file mode 100644 index 0000000..e5cecac --- /dev/null +++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg @@ -0,0 +1,2621 @@ [SVG figure data omitted: Matplotlib-generated benchmark plot, 2,621 lines of vector data.]
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg new file mode 100644 index 0000000..1575906 --- /dev/null +++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg @@ -0,0 +1,2957 @@ + + + + + + + + 2026-01-22T03:17:16.168483 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg new file mode 100644 index 0000000..66a3075 --- /dev/null +++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg @@ -0,0 +1,1708 @@ + + + + + + + + 2026-01-22T03:17:14.531728 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git 
a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg new file mode 100644 index 0000000..d87b7b9 --- /dev/null +++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg @@ -0,0 +1,2621 @@ + + + + + + + + 2026-01-22T03:17:12.903096 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg new file mode 100644 index 0000000..5c849b5 --- /dev/null +++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg @@ -0,0 +1,2601 @@ + + + + + + + + 2026-01-22T03:17:09.483028 + 
image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py new file mode 100644 index 0000000..88eebdf --- /dev/null +++ b/oink/benchmarks/readme/plot_quack_style_svg.py @@ -0,0 +1,471 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Generate Quack-style SVG performance plots (Oink vs Quack) from the SM100 suite +JSON artifacts under `/tmp/kernelagent_oink_sm100_suite_{bf16,fp16}`. + +The intent is to match Quack's README visual style: + - 3 horizontal panels (suite-dependent): + - Quack-suite: RMSNorm / Softmax / CrossEntropy + - DSv3 (hidden-size): Fused Add+RMSNorm / Softmax / LayerNorm + - DSv3 (all ops, 4-panel): Fused Add+RMSNorm / Softmax / LayerNorm / CrossEntropy + - DSv3 CrossEntropy: CrossEntropy-only (single panel) + - y-axis: model memory bandwidth (GB/s) derived from an IO model + - x-axis: a small set of labeled (M, N) shape points + - thick lines + markers, dashed y-grid, compact legend + - optional horizontal roofline line (measured STREAM-like HBM peak) + +Example: + python oink/benchmarks/readme/plot_quack_style_svg.py \\ + --in-dir /tmp/kernelagent_oink_sm100_suite_bf16 \\ + --suite quack_suite \\ + --roofline-json /tmp/hbm_roofline_sm100_bf16.json \\ + --out oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg + +For completeness, we can also include LayerNorm as an extra panel (Quack's +own README plot does not include LayerNorm): + python oink/benchmarks/readme/plot_quack_style_svg.py \\ + --in-dir /tmp/kernelagent_oink_sm100_suite_bf16 \\ + --suite quack_suite \\ + --include-layernorm \\ + --roofline-json /tmp/hbm_roofline_sm100_bf16.json \\ + --out oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg + +Note on DSv3 suite: +- The DSv3 plot intentionally covers only the hidden-size ops (fused Add+RMSNorm, + Softmax, LayerNorm) which share the same `(M, N)` sweep. +- CrossEntropy in DSv3 uses a vocab-size-like `N` sweep and is plotted separately + via `--suite dsv3_cross_entropy` to avoid a mixed x-axis with gaps. +- For README embedding convenience, `--suite dsv3_all` renders a 4-panel + single-row figure where the CrossEntropy panel uses its own x-axis. +- The RMSNorm panel uses the real block primitive (fused residual-add + RMSNorm) + when available: `fused_add_rmsnorm_dsv3.json`. +""" + +from __future__ import annotations + +import argparse +import json +import math +import os +from collections import defaultdict +from statistics import median +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple + + +def _load_json(path: str) -> Dict[str, Any]: + with open(path) as f: + return json.load(f) + + +def _fmt_k(v: int) -> str: + # Match Quack's x-axis labels: "32K" means 32768 (1024-based). + if v % 1024 == 0: + return f"{v // 1024}K" + return str(v) + + +def _shape_label(m: int, n: int) -> str: + return f"({_fmt_k(m)}, {_fmt_k(n)})" + + +def _gbps_from_row(prefix: str, row: Mapping[str, Any]) -> Optional[float]: + # Prefer GB/s in the JSON if present; otherwise fall back to TB/s. 
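+    # Illustrative (assumed) row shape from the suite JSONs; only the keys read
+    # here and in `_aggregate_by_shape` matter, e.g.:
+    #   {"M": 32768, "N": 4096, "ours_tbps": ..., "quack_tbps": ...}
+    # with optional `ours_gbps` / `quack_gbps` taking precedence when present.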
+ gbps_key = f"{prefix}_gbps" + tbps_key = f"{prefix}_tbps" + if gbps_key in row and row[gbps_key] is not None: + return float(row[gbps_key]) + if tbps_key in row and row[tbps_key] is not None: + return float(row[tbps_key]) * 1000.0 + return None + + +def _aggregate_by_shape( + rows: Sequence[Mapping[str, Any]], +) -> Dict[Tuple[int, int], Dict[str, float]]: + """Aggregate duplicate (M, N) rows using median (more robust than mean).""" + buckets: dict[tuple[int, int], dict[str, list[float]]] = defaultdict( + lambda: defaultdict(list) + ) + for r in rows: + m = int(r["M"]) + n = int(r["N"]) + ours = _gbps_from_row("ours", r) + quack = _gbps_from_row("quack", r) + if ours is not None: + buckets[(m, n)]["ours"].append(ours) + if quack is not None: + buckets[(m, n)]["quack"].append(quack) + + out: Dict[Tuple[int, int], Dict[str, float]] = {} + for k, vs in buckets.items(): + if not vs["ours"] or not vs["quack"]: + continue + out[k] = dict(ours=float(median(vs["ours"])), quack=float(median(vs["quack"]))) + return out + + +def _sort_shapes(shapes: Iterable[Tuple[int, int]]) -> List[Tuple[int, int]]: + # Sort by N then M to keep the x-axis stable across panels. + return sorted(set(shapes), key=lambda x: (x[1], x[0])) + + +def _read_roofline_gbps(path: str) -> float: + payload = _load_json(path) + rows = payload.get("rows", []) + best_tbps = max(float(r["tbps"]) for r in rows) + return best_tbps * 1000.0 + + +def _ensure_matplotlib(): + try: + import matplotlib as mpl # noqa: F401 + import matplotlib.pyplot as plt # noqa: F401 + except Exception as e: # pragma: no cover + raise SystemExit( + "matplotlib is required to generate SVG plots.\n" + "Install with: `python -m pip install matplotlib`" + ) from e + + +def _plot( + *, + panels: Sequence[Tuple[str, Dict[Tuple[int, int], Dict[str, float]]]], + roofline_gbps: Optional[float], + out_path: str, + title: str, + shape_policy: str, + per_panel_x: bool, +) -> None: + _ensure_matplotlib() + import matplotlib as mpl + import matplotlib.pyplot as plt + + mpl.rcParams.update( + { + # Quack-style: embed glyphs as paths for consistent rendering. + "svg.fonttype": "path", + "font.family": "DejaVu Sans", + "axes.titlesize": 18, + "axes.labelsize": 16, + "xtick.labelsize": 10, + "ytick.labelsize": 12, + } + ) + + # Colors roughly matching Quack's SVG palette. + COLOR_OINK = "#5ba3f5" + COLOR_QUACK = "#ff4444" + COLOR_ROOF = "#4d4d4d" + + fig, axes = plt.subplots( + nrows=1, + ncols=len(panels), + figsize=(6.0 * len(panels), 5.6), + constrained_layout=False, + sharey=True, + ) + if len(panels) == 1: + axes = [axes] + + max_y = 0.0 + for ax, (panel_title, data) in zip(axes, panels): + if per_panel_x: + shapes = _sort_shapes(data.keys()) + else: + # Quack-style plots use a single shared x-axis across panels. Prefer + # the intersection so every panel has a value at every x tick + # (cleaner than rendering gaps), and fall back to the union if the + # intersection is empty. 
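+            # Recap of the supported policies (mirrors --shape-policy below):
+            #   "intersection"      -> shapes common to every panel (Quack-style)
+            #   "first" / "primary" -> the first panel's shapes only
+            #   "union"             -> all shapes; panels missing a shape plot as NaN gaps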
+ shape_sets = [set(d.keys()) for _n, d in panels] + if shape_policy in {"first", "primary"}: + shapes = _sort_shapes(shape_sets[0]) if shape_sets else [] + elif shape_policy == "intersection" and shape_sets: + common = set.intersection(*shape_sets) + shapes = _sort_shapes(common) if common else [] + elif shape_policy == "union": + shapes = _sort_shapes(s for _n, d in panels for s in d.keys()) + else: + raise ValueError(f"Unsupported shape_policy: {shape_policy}") + if not shapes: + shapes = _sort_shapes(s for _n, d in panels for s in d.keys()) + + x = list(range(len(shapes))) + x_labels = [_shape_label(m, n) for (m, n) in shapes] + + ours_y: List[float] = [] + quack_y: List[float] = [] + for s in shapes: + rec = data.get(s) + if rec is None: # only possible in shared-x mode with union + ours_y.append(math.nan) + quack_y.append(math.nan) + continue + ours_y.append(float(rec["ours"])) + quack_y.append(float(rec["quack"])) + max_y = max( + max_y, + *(v for v in ours_y if math.isfinite(v)), + *(v for v in quack_y if math.isfinite(v)), + ) + + ax.plot( + x, + ours_y, + marker="o", + linewidth=5, + markersize=7, + color=COLOR_OINK, + label="KernelAgent-Oink (ours)", + ) + ax.plot( + x, + quack_y, + marker="o", + linewidth=5, + markersize=7, + color=COLOR_QUACK, + label="Quack", + ) + if roofline_gbps is not None: + ax.axhline( + roofline_gbps, + color=COLOR_ROOF, + linewidth=3, + linestyle=(0, (4, 6)), + label="HBM peak (measured)" if ax is axes[0] else None, + ) + max_y = max(max_y, float(roofline_gbps)) + + ax.set_title(panel_title) + ax.set_xticks(x) + ax.set_xticklabels(x_labels, rotation=-45, ha="left") + if per_panel_x: + # DSv3 "all ops" figure: each panel has its own x-axis. Make the + # semantics explicit so readers don't assume the same `N` meaning + # across panels (CrossEntropy uses a classes/vocab-shard-like axis). + if "cross" in panel_title.lower(): + ax.set_xlabel("Shape (M, C classes)") + else: + ax.set_xlabel("Shape (M, N hidden)") + + # Quack-like dashed y-grid. + ax.grid(axis="y", linestyle=(0, (4, 7.2)), linewidth=0.8, color="#b0b0b0") + ax.set_axisbelow(True) + + # Light spines (Quack SVG uses a light gray frame). + for spine in ax.spines.values(): + spine.set_color("#d3d3d3") + spine.set_linewidth(1.5) + + axes[0].set_ylabel("Memory Bandwidth (GB/s)") + + # A little headroom above the tallest curve/roofline. + ymax = max_y * 1.08 if max_y > 0 else 1.0 + for ax in axes: + ax.set_ylim(0.0, ymax) + + # Tight layout for the axes area, reserving headroom for the suptitle and a + # shared legend. In some matplotlib versions, figure-level legends can + # overlap the middle panel title unless we reserve a slightly taller header + # band. + fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.70)) + + # Single shared legend across the top (like Quack), but keep it inside the + # reserved header band so it doesn't overlap the middle panel title. + handles, labels = axes[0].get_legend_handles_labels() + # Quack's legend fits nicely in one row because their plots are 3-panel and + # therefore wide. For single-panel figures, a 3-column legend can overflow + # the canvas and get clipped in the SVG, so we stack it vertically. + legend_ncol = min(3, len(labels)) + legend_fontsize = 13 + if len(panels) == 1: + legend_ncol = 1 + legend_fontsize = 12 + fig.legend( + handles, + labels, + loc="upper center", + ncol=legend_ncol, + frameon=False, + bbox_to_anchor=(0.5, 0.91), + fontsize=legend_fontsize, + handlelength=2.5, + ) + # Single-panel figures (e.g. 
DSv3 CrossEntropy) are much narrower than the + # Quack-style 3-panel plots; use a slightly smaller suptitle font to avoid + # clipping in the exported SVG. + suptitle_fs = 22 if len(panels) > 1 else 18 + fig.suptitle(title, y=0.98, fontsize=suptitle_fs) + + out_path = os.path.abspath(out_path) + os.makedirs(os.path.dirname(out_path), exist_ok=True) + # Use a tight bounding box so rotated x tick labels and the figure-level + # legend don't get clipped in SVG exports (matplotlib can be fragile here + # across versions). + fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.02) + plt.close(fig) + + +def _panel_files_for_suite(suite: str) -> List[Tuple[str, str]]: + if suite == "quack_suite": + return [ + ("RMSNorm (fp32 weight)", "rmsnorm_fwd_quack_suite_wfp32.json"), + ("Softmax (fwd+bwd)", "softmax_fwd_bwd_quack_suite.json"), + ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_quack_suite.json"), + ] + if suite == "dsv3": + return [ + ("Fused Add+RMSNorm (fwd)", "fused_add_rmsnorm_dsv3.json"), + ("Softmax (fwd+bwd)", "softmax_fwd_bwd_dsv3.json"), + ("LayerNorm (fwd)", "layernorm_fwd_dsv3.json"), + ] + if suite == "dsv3_all": + return [ + ("Fused Add+RMSNorm (fwd)", "fused_add_rmsnorm_dsv3.json"), + ("Softmax (fwd+bwd)", "softmax_fwd_bwd_dsv3.json"), + ("LayerNorm (fwd)", "layernorm_fwd_dsv3.json"), + ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_dsv3.json"), + ] + if suite == "dsv3_cross_entropy": + return [ + ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_dsv3.json"), + ] + raise ValueError(f"Unsupported suite: {suite}") + + +def _layernorm_file_for_suite(suite: str) -> str: + if suite == "quack_suite": + return "layernorm_fwd_quack_suite.json" + raise ValueError(f"Unsupported suite: {suite}") + + +def main() -> None: + p = argparse.ArgumentParser( + description="Generate Quack-style SVG plots from KernelAgent-Oink suite JSONs." + ) + p.add_argument( + "--in-dir", + type=str, + required=True, + help="Directory containing suite JSON outputs", + ) + p.add_argument( + "--suite", + type=str, + default="quack_suite", + choices=["quack_suite", "dsv3", "dsv3_all", "dsv3_cross_entropy"], + ) + p.add_argument( + "--include-layernorm", + action="store_true", + help="Add a LayerNorm (fwd) panel (only meaningful for `--suite quack_suite`).", + ) + p.add_argument( + "--shape-policy", + type=str, + default="intersection", + choices=["intersection", "union", "first"], + help=( + "How to pick x-axis shapes across panels. " + "`intersection` matches Quack-style (only shapes common to every panel). " + "`first` uses the first panel's shapes (keeps DSv3 N=7168 visible). " + "`union` includes every shape across panels (may create gaps)." + ), + ) + p.add_argument( + "--roofline-json", + type=str, + default=None, + help="Optional /tmp/hbm_roofline_sm100_*.json path", + ) + p.add_argument("--out", type=str, required=True, help="Output SVG path") + p.add_argument( + "--title", type=str, default=None, help="Optional figure title override" + ) + args = p.parse_args() + + in_dir = os.path.abspath(args.in_dir) + if not os.path.isdir(in_dir): + raise SystemExit(f"--in-dir is not a directory: {in_dir}") + + roofline_gbps = ( + _read_roofline_gbps(args.roofline_json) if args.roofline_json else None + ) + + panel_files = list(_panel_files_for_suite(str(args.suite))) + if args.include_layernorm: + if args.suite != "quack_suite": + raise SystemExit( + "--include-layernorm is only supported for `--suite quack_suite`." 
+ ) + panel_files.append( + ("LayerNorm (fwd)", _layernorm_file_for_suite(str(args.suite))) + ) + + panels: List[Tuple[str, Dict[Tuple[int, int], Dict[str, float]]]] = [] + for panel_title, filename in panel_files: + path = os.path.join(in_dir, filename) + if not os.path.exists(path): + raise SystemExit(f"Missing required JSON: {path}") + payload = _load_json(path) + rows = payload.get("rows", []) + if not isinstance(rows, list): + rows = [] + panels.append((panel_title, _aggregate_by_shape(rows))) + + if args.title is not None: + title = str(args.title) + else: + # Try to infer dtype from the first panel's JSON. + first_json = os.path.join(in_dir, panel_files[0][1]) + payload = _load_json(first_json) + rows = payload.get("rows", []) + dtype = rows[0].get("dtype", "") if rows else "" + if args.suite == "quack_suite": + suite_name = "Quack-suite" + elif args.suite == "dsv3": + suite_name = "DSv3 (hidden-size ops)" + elif args.suite == "dsv3_all": + suite_name = "DSv3 (4 ops)" + elif args.suite == "dsv3_cross_entropy": + # Keep this short: this suite is rendered as a single panel, so the + # figure is much narrower than the 3-panel plots. + suite_name = "DSv3 CrossEntropy" + else: + suite_name = str(args.suite) + suffix = ( + " (+LayerNorm)" + if (args.suite == "quack_suite" and args.include_layernorm) + else "" + ) + if args.suite == "dsv3_cross_entropy": + title = f"SM100 {dtype.upper()} — {suite_name}{suffix}" + else: + title = f"SM100 {dtype.upper()} Kernel Benchmarks (Oink vs Quack) — {suite_name}{suffix}" + + _plot( + panels=panels, + roofline_gbps=roofline_gbps, + out_path=str(args.out), + title=title, + shape_policy=str(args.shape_policy), + per_panel_x=(str(args.suite) == "dsv3_all"), + ) + print(f"Wrote: {os.path.abspath(args.out)}") + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py new file mode 100644 index 0000000..af33e38 --- /dev/null +++ b/oink/benchmarks/readme/run_sm100_suite.py @@ -0,0 +1,337 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import argparse +import os +import subprocess +import sys +from datetime import datetime +from typing import List, Tuple + + +def _ts() -> str: + return datetime.now().strftime("%Y%m%d_%H%M%S") + + +def _run(cmd: List[str], *, dry_run: bool) -> None: + print("+", " ".join(cmd), flush=True) + if dry_run: + return + subprocess.run(cmd, check=True) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) + p.add_argument( + "--out-dir", + type=str, + default=None, + help="Directory to write JSON outputs (default: /tmp/kernelagent_oink_sm100_suite_)", + ) + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs PyTorch / pure-PyTorch references)", + ) + p.add_argument( + "--dry-run", action="store_true", help="Print commands without executing them" + ) + args = p.parse_args() + + # Standardize env for standalone runs outside the vLLM plugin. + os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + + out_dir = args.out_dir or f"/tmp/kernelagent_oink_sm100_suite_{_ts()}" + os.makedirs(out_dir, exist_ok=True) + + here = os.path.dirname(os.path.abspath(__file__)) + bench_dir = os.path.abspath(os.path.join(here, "..", "benchmark")) + py = sys.executable + + def script(name: str) -> str: + return os.path.join(bench_dir, name) + + common = ["--dtype", args.dtype] + if args.skip_verify: + common = [*common, "--skip-verify"] + + runs: List[Tuple[str, List[str]]] = [ + ( + "rmsnorm_fwd_quack_suite_wfp32", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--quack-suite", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wfp32.json"), + ], + ), + ( + "rmsnorm_fwd_dsv3_wfp32", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--dsv3", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_dsv3_wfp32.json"), + ], + ), + ( + "rmsnorm_bwd_quack_suite_wfp32", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--quack-suite", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wfp32.json"), + ], + ), + ( + "rmsnorm_bwd_dsv3_wfp32", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--dsv3", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_dsv3_wfp32.json"), + ], + ), + # vLLM inference-style RMSNorm (weight dtype == activation dtype). 
+ ( + "rmsnorm_fwd_quack_suite_wsame", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "same", + "--quack-suite", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wsame.json"), + ], + ), + ( + "rmsnorm_fwd_dsv3_wsame", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "same", + "--dsv3", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_dsv3_wsame.json"), + ], + ), + ( + "rmsnorm_bwd_quack_suite_wsame", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "same", + "--quack-suite", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wsame.json"), + ], + ), + ( + "rmsnorm_bwd_dsv3_wsame", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "same", + "--dsv3", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_dsv3_wsame.json"), + ], + ), + ( + "fused_add_rmsnorm_dsv3", + [ + py, + script("benchmark_fused_add_rmsnorm_sm100.py"), + *common, + "--dsv3", + "--quack-baseline", + "kernel_inplace", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "fused_add_rmsnorm_dsv3.json"), + ], + ), + ( + "softmax_fwd_bwd_quack_suite", + [ + py, + script("benchmark_softmax_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--quack-suite", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "softmax_fwd_bwd_quack_suite.json"), + ], + ), + ( + "softmax_fwd_bwd_dsv3", + [ + py, + script("benchmark_softmax_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--dsv3", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "softmax_fwd_bwd_dsv3.json"), + ], + ), + ( + "cross_entropy_fwd_bwd_quack_suite", + [ + py, + script("benchmark_cross_entropy_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--quack-suite", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "cross_entropy_fwd_bwd_quack_suite.json"), + ], + ), + ( + "cross_entropy_fwd_bwd_dsv3", + [ + py, + script("benchmark_cross_entropy_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--dsv3", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "cross_entropy_fwd_bwd_dsv3.json"), + ], + ), + ( + "layernorm_fwd_quack_suite", + [ + py, + script("benchmark_layernorm_sm100.py"), + *common, + "--quack-suite", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "layernorm_fwd_quack_suite.json"), + ], + ), + ( + "layernorm_fwd_dsv3", + [ + py, + script("benchmark_layernorm_sm100.py"), + *common, + "--dsv3", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "layernorm_fwd_dsv3.json"), + ], + ), + ] + + print(f"Writing results to: {out_dir}", flush=True) + for name, cmd in runs: + print(f"\n== {name} ==", flush=True) + _run(cmd, dry_run=bool(args.dry_run)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py new file mode 100644 index 0000000..684694d --- /dev/null +++ b/oink/benchmarks/readme/summarize_results.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import json +import math +import os +from typing import Any, Dict, Iterable, List, Optional, Sequence + + +def _load_json(path: str) -> Dict[str, Any]: + with open(path) as f: + return json.load(f) + + +def _fmt_cell(v: object) -> str: + if v is None: + return "" + if isinstance(v, float): + if math.isfinite(v): + av = abs(v) + # Use scientific notation for very small values so we don't render + # meaningful error stats as "0.0000". + if av != 0.0 and av < 1e-3: + return f"{v:.2e}" + return f"{v:.4f}" + return str(v) + return str(v) + + +def _md_table(rows: Sequence[Dict[str, Any]], columns: Sequence[str]) -> str: + header = "| " + " | ".join(columns) + " |" + sep = "|" + "|".join(["---"] * len(columns)) + "|" + lines = [header, sep] + for r in rows: + lines.append("| " + " | ".join(_fmt_cell(r.get(c)) for c in columns) + " |") + return "\n".join(lines) + + +def _pick_columns(rows: Sequence[Dict[str, Any]]) -> List[str]: + preferred = [ + "M", + "N", + "dtype", + "weight_dtype", + "mode", + "eps", + "store_rstd", + "return_rstd", + "return_mean", + "ignore_index", + "ours_ms", + "ours_tbps", + "ours_hbm_frac", + "quack_ms", + "quack_tbps", + "speedup_vs_quack", + ] + present = set().union(*(r.keys() for r in rows)) if rows else set() + cols = [c for c in preferred if c in present] + # Fall back to a stable sorted view if we missed everything (shouldn't happen). + return cols or sorted(present) + + +def _geomean(values: Iterable[float]) -> Optional[float]: + logs: List[float] = [] + for v in values: + if v <= 0 or not math.isfinite(v): + continue + logs.append(math.log(v)) + if not logs: + return None + return math.exp(sum(logs) / len(logs)) + + +def _collect_error_prefixes(rows: Sequence[Dict[str, Any]]) -> List[str]: + """Infer error-stat prefixes like `ours_err_dx` from row keys.""" + prefixes: set[str] = set() + for r in rows: + for k in r.keys(): + if not isinstance(k, str): + continue + if not k.endswith("_max_abs"): + continue + if "err_" not in k: + continue + prefixes.add(k[: -len("_max_abs")]) + return sorted(prefixes) + + +def _summarize_error_stats(rows: Sequence[Dict[str, Any]]) -> str: + prefixes = _collect_error_prefixes(rows) + if not prefixes: + return "" + + out_rows: List[Dict[str, Any]] = [] + for pfx in prefixes: + # Per-prefix worst-case across rows. 
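+        # e.g. for the prefix "ours_err_dx" this collects "ours_err_dx_max_abs",
+        # "ours_err_dx_p99_abs" and "ours_err_dx_rel_l2" from every row and
+        # reports the maximum (worst) value observed across all shapes.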
+ max_abs_vals = [ + float(r[pfx + "_max_abs"]) for r in rows if (pfx + "_max_abs") in r + ] + p99_abs_vals = [ + float(r[pfx + "_p99_abs"]) for r in rows if (pfx + "_p99_abs") in r + ] + rel_l2_vals = [ + float(r[pfx + "_rel_l2"]) for r in rows if (pfx + "_rel_l2") in r + ] + if not max_abs_vals and not p99_abs_vals and not rel_l2_vals: + continue + out_rows.append( + { + "metric": pfx, + "max_abs (max over shapes)": max(max_abs_vals) + if max_abs_vals + else None, + "p99_abs (max over shapes)": max(p99_abs_vals) + if p99_abs_vals + else None, + "rel_l2 (max over shapes)": max(rel_l2_vals) if rel_l2_vals else None, + } + ) + + if not out_rows: + return "" + + cols = [ + "metric", + "max_abs (max over shapes)", + "p99_abs (max over shapes)", + "rel_l2 (max over shapes)", + ] + return "\n".join( + ["", "### Error Stats (vs PyTorch ref)", "", _md_table(out_rows, cols), ""] + ) + + +def summarize_one(path: str) -> str: + payload = _load_json(path) + meta = payload.get("meta", {}) + rows = payload.get("rows", []) + if not isinstance(rows, list): + rows = [] + + cols = _pick_columns(rows) + parts: List[str] = [] + + base = os.path.basename(path) + parts.append(f"## `{base}`") + if meta: + device = meta.get("device") + cap = meta.get("capability") + torch_ver = meta.get("torch") + cuda_ver = meta.get("cuda") + git_sha = meta.get("git_sha") + ts = meta.get("timestamp") + parts.append("") + parts.append( + f"- device: `{device}` | capability: `{cap}` | torch: `{torch_ver}` | cuda: `{cuda_ver}` | git_sha: `{git_sha}` | timestamp: `{ts}`" + ) + method = meta.get("method") + if method is not None: + parts.append(f"- method: `{method}`") + if meta.get("warmup_ms") is not None and meta.get("rep_ms") is not None: + parts.append( + f"- warmup_ms: `{meta.get('warmup_ms')}` | rep_ms: `{meta.get('rep_ms')}`" + ) + + if rows: + parts.append("") + parts.append(_md_table(rows, cols)) + + speeds = [float(r["speedup_vs_quack"]) for r in rows if "speedup_vs_quack" in r] + gm = _geomean(speeds) + if gm is not None: + parts.append("") + parts.append( + f"- geomean speedup vs Quack: `{gm:.3f}x` (over {len(speeds)} shapes)" + ) + + err_block = _summarize_error_stats(rows) + if err_block: + parts.append(err_block.rstrip()) + else: + parts.append("") + parts.append("_No rows found in JSON._") + + parts.append("") + return "\n".join(parts) + + +def main() -> None: + p = argparse.ArgumentParser( + description="Summarize KernelAgent-Oink benchmark JSONs into Markdown tables." 
+ ) + p.add_argument( + "--in-dir", + type=str, + required=True, + help="Directory containing benchmark JSON files", + ) + p.add_argument( + "--out", + type=str, + default=None, + help="Optional output markdown path (default: stdout)", + ) + args = p.parse_args() + + in_dir = os.path.abspath(args.in_dir) + if not os.path.isdir(in_dir): + raise SystemExit(f"--in-dir is not a directory: {in_dir}") + + json_paths = sorted( + os.path.join(in_dir, name) + for name in os.listdir(in_dir) + if name.endswith(".json") + ) + if not json_paths: + raise SystemExit(f"No .json files found under: {in_dir}") + + out_parts: List[str] = [] + out_parts.append("# KernelAgent-Oink SM100 Benchmark Summary") + out_parts.append("") + out_parts.append(f"Input directory: `{in_dir}`") + out_parts.append("") + for path in json_paths: + out_parts.append(summarize_one(path)) + + text = "\n".join(out_parts).rstrip() + "\n" + if args.out is None: + print(text, end="") + return + + out_path = os.path.abspath(args.out) + os.makedirs(os.path.dirname(out_path), exist_ok=True) + with open(out_path, "w") as f: + f.write(text) + + +if __name__ == "__main__": + main() diff --git a/oink/pyproject.toml b/oink/pyproject.toml new file mode 100644 index 0000000..0d19d6e --- /dev/null +++ b/oink/pyproject.toml @@ -0,0 +1,51 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "kernelagent-oink" +version = "0.1.0" +description = "CuTeDSL kernels for Blackwell (SM100), shipped as a vLLM plugin" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "Apache-2.0"} +authors = [{name = "PyTorch Labs"}] +keywords = ["cuda", "cutlass", "cute", "cutedsl", "blackwell", "sm100", "vllm"] +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] + +[project.urls] +Repository = "https://github.com/meta-pytorch/KernelAgent" +Documentation = "https://github.com/meta-pytorch/KernelAgent/tree/main/oink" +Issues = "https://github.com/meta-pytorch/KernelAgent/issues" + +# Keep dependencies minimal, but include the CuTeDSL stack required by the +# Blackwell RMSNorm implementation. +# +# We intentionally do NOT depend on `torch` here because vLLM already pins and +# provides a compatible PyTorch build. +dependencies = [ + "nvidia-cutlass-dsl", + "cuda-python", +] + +[project.optional-dependencies] +# Optional extras for running the in-repo benchmark suite (not needed for vLLM integration). +bench = [ + "matplotlib", + "triton", +] + +[project.entry-points."vllm.general_plugins"] +oink = "kernelagent_oink:register" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["kernelagent_oink*"] diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py new file mode 100644 index 0000000..f5c36a6 --- /dev/null +++ b/oink/src/kernelagent_oink/__init__.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +KernelAgent-Oink: SM100 CuTeDSL kernels + optional vLLM plugin. + +This package can be loaded as a vLLM "general plugin" (entrypoint group +`vllm.general_plugins`). In that mode it registers Oink custom ops only when +explicitly enabled via an environment variable (so installing the package does +not change behavior by default). + +For standalone usage (outside vLLM), call `kernelagent_oink.register(force=True)` +to register the custom ops explicitly. +""" + +from __future__ import annotations + +import logging +import os + +logger = logging.getLogger(__name__) + +_OPS_REGISTERED = False + + +def _env_truthy(name: str) -> bool: + val = os.environ.get(name) + if val is None: + return False + return val.strip().lower() in ("1", "true", "yes", "on") + + +def _infer_cuda_device_index() -> int: + local_rank = os.environ.get("LOCAL_RANK") + if local_rank is not None: + try: + return int(local_rank) + except ValueError: + pass + return 0 + + +def _compute_cutedsl_arch(major: int, minor: int) -> str: + # CuTeDSL uses an "a" suffix for >= Hopper. + suffix = "a" if major >= 9 else "" + # Match cutlass/base_dsl/env_manager.py: map sm_110 -> sm_101. + if major == 11 and minor == 0: + major, minor = 10, 1 + return f"sm_{major}{minor}{suffix}" + + +def register(*, force: bool = False) -> None: + """Register Oink torch custom ops. + + - vLLM plugin mode (default): no-op unless `VLLM_USE_OINK_RMSNORM` is truthy. + - Standalone mode: pass `force=True` to register explicitly. + + This function must be safe to call multiple times and must not raise. vLLM + executes it in multiple processes (engine + workers). + """ + global _OPS_REGISTERED + + if _OPS_REGISTERED: + return + + # Gate on the vLLM integration flag so installing the package does not + # change behavior unless explicitly enabled. For standalone usage (outside + # vLLM), callers can pass force=True to register the ops explicitly. + if not force and not _env_truthy("VLLM_USE_OINK_RMSNORM"): + return + + try: + import torch + except Exception as e: # pragma: no cover + logger.debug("Oink plugin: torch import failed: %s", e) + return + + try: + if not torch.cuda.is_available(): + return + device_index = _infer_cuda_device_index() + major, minor = torch.cuda.get_device_capability(device_index) + sm = 10 * int(major) + int(minor) + if sm < 100: + return + + # Ensure required deps are importable before registering ops so that vLLM + # doesn't detect ops that would later fail at first use. + try: + import cutlass # noqa: F401 + import cuda.bindings.driver as _cuda # noqa: F401 + except Exception as e: + logger.warning( + "Oink plugin: CuTeDSL deps missing; skipping op registration. " + "Install `nvidia-cutlass-dsl` + `cuda-python`. Error: %s", + e, + ) + return + + # Ensure CuTeDSL sees a target arch early. If the user has already set it, + # respect their choice. + os.environ.setdefault( + "CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor)) + ) + + # Import registers the ops via torch.library.custom_op decorators. 
+ from .blackwell import oink_custom_ops # noqa: F401 + except Exception as e: # pragma: no cover + # Do not raise: vLLM plugin loader does not guard plugin execution. + logger.exception("Oink plugin: failed to register ops: %s", e) + return + + _OPS_REGISTERED = True + + +__all__ = ["register"] diff --git a/oink/src/kernelagent_oink/blackwell/__init__.py b/oink/src/kernelagent_oink/blackwell/__init__.py new file mode 100644 index 0000000..a92109a --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +__all__ = [] diff --git a/oink/src/kernelagent_oink/blackwell/cross_entropy.py b/oink/src/kernelagent_oink/blackwell/cross_entropy.py new file mode 100644 index 0000000..d8b37ea --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/cross_entropy.py @@ -0,0 +1,2257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Cross-entropy forward + backward kernels for SM100 (Blackwell) in CuteDSL. + +This module implements numerically stable cross-entropy over the last +dimension of 2D logits tensors `(M, N)` together with its backward pass, +targeting SM100 with Quack-style tiling, cp.async pipelines, and (for the +forward pass) optional cluster-wide online softmax reductions, but without +depending on the external `quack` package at runtime. + +Public APIs: + +- ``cross_entropy_forward(logits, target, ignore_index=-100, reduction="none")`` + returns ``(loss, lse)`` where ``loss`` follows the requested reduction and + ``lse`` is always per-example log-sum-exp (shape ``(M,)``). +- ``cross_entropy_backward(dloss, logits, target, lse, ignore_index=-100)`` + returns per-logit gradients ``dlogits`` matching PyTorch / + ``quack.cross_entropy_bwd`` semantics for ``reduction="none"``. +- ``cross_entropy(logits, target, ignore_index=-100, reduction="mean"|"sum"|"none")`` + is a convenience wrapper that mirrors ``torch.nn.functional.cross_entropy`` + reductions using the SM100 CuteDSL kernels for the forward pass. + +The kernels are self-contained and use only local helpers in +`kernelagent_oink.blackwell.lite_quack` plus CuTeDSL/CUTLASS. 
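+
+Illustrative usage (a sketch; assumes an SM100 GPU, bf16 logits, and int64
+class targets):
+
+    import torch
+    from kernelagent_oink.blackwell.cross_entropy import (
+        cross_entropy_backward,
+        cross_entropy_forward,
+    )
+
+    logits = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
+    target = torch.randint(0, 8192, (4096,), device="cuda")
+    loss, lse = cross_entropy_forward(logits, target, reduction="none")
+    dlogits = cross_entropy_backward(torch.ones_like(loss), logits, target, lse)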
+""" + +from __future__ import annotations + +import importlib.metadata +import math +import os +import re +from typing import Literal, Optional, Type + +import torch +from torch import Tensor + +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python + +# CuTeDSL caches generated MLIR into a tempdir under a global default +# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across +# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy +# warnings (and disables cache reuse). +if "CUTE_DSL_CACHE_DIR" not in os.environ: + try: + _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl") + except Exception: + _dsl_ver = "unknown" + _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver) + _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user" + _tmp = os.environ.get("TMPDIR") or "/tmp" + os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join( + _tmp, _user, f"cutlass_python_cache_{_dsl_ver}" + ) + +try: + import cutlass # type: ignore # noqa: F401 +except Exception as e: + raise ImportError( + "kernelagent_oink.blackwell.cross_entropy requires CuTeDSL's Python package " + "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)." + ) from e + +import cutlass.cute as cute +from cutlass import Boolean, Float32, Int32, const_expr +from cutlass.cute import runtime as rt +from cutlass.cute.runtime import from_dlpack + +from kernelagent_oink.blackwell.fast_launch import ( + StableI32Arg, + disable_fast_launch, + fast_launch_enabled, + set_runtime_ptr, + tls_cache as _tls_fast_launch_cache, +) +from kernelagent_oink.blackwell.lite_quack import ( + _KERNEL_ACCEPTS_LAYOUT_ARGS, + TORCH2CUTE_DTYPE, + ReductionBase, + fill_oob, + online_softmax_reduce, + predicate_k, +) + +_FWD_COMPILE_CACHE: dict[tuple[type[cutlass.Numeric], int], cute.Kernel] = {} +_BWD_COMPILE_CACHE: dict[tuple[type[cutlass.Numeric], int], cute.Kernel] = {} +_PTR_FWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} +_PTR_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} +_PTR_FWDBWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +class _PtrCrossEntropyFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_logits: object, + ptr_target: object, + ptr_aux_a: object, + ptr_aux_b: object, + ptr_aux_c: object | None, + arg_m: StableI32Arg, + arg_ld: StableI32Arg, + arg_ignore_index: StableI32Arg, + stream: cuda.CUstream, + packed_args: object, + keepalive: tuple[object, ...], + logits_align: int, + target_align: int, + aux_a_align: int, + aux_b_align: int, + aux_c_align: int | None, + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_logits = ptr_logits + self._ptr_target = ptr_target + self._ptr_aux_a = ptr_aux_a + self._ptr_aux_b = ptr_aux_b + self._ptr_aux_c = ptr_aux_c + self._arg_m = arg_m + self._arg_ld = arg_ld + self._arg_ignore_index = arg_ignore_index + self._stream = stream + self._packed_args = packed_args + self._keepalive = keepalive + self._logits_align = int(logits_align) + self._target_align = int(target_align) + self._aux_a_align = int(aux_a_align) + self._aux_b_align = int(aux_b_align) + self._aux_c_align = int(aux_c_align) if aux_c_align is not None else None + + self._use_fast_launch = True + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_logits_ptr = -1 + self._last_target_ptr = -1 + self._last_aux_a_ptr = -1 + self._last_aux_b_ptr = -1 + self._last_aux_c_ptr = -1 + self._last_m = -1 + self._last_ld = -1 
+ self._last_ignore_index = None + + def launch( + self, + *, + logits_ptr: int, + target_ptr: int, + aux_a_ptr: int, + aux_b_ptr: int, + aux_c_ptr: int | None, + M: int, + ld: int, + ignore_index: int, + stream_handle: int, + dtype_logits: type[cutlass.Numeric], + aux_a_dtype: type[cutlass.Numeric], + aux_b_dtype: type[cutlass.Numeric], + aux_c_dtype: type[cutlass.Numeric] | None, + ) -> None: + if not fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if logits_ptr != self._last_logits_ptr: + try: + set_runtime_ptr(self._ptr_logits, logits_ptr) + self._last_logits_ptr = logits_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if target_ptr != self._last_target_ptr: + try: + set_runtime_ptr(self._ptr_target, target_ptr) + self._last_target_ptr = target_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if aux_a_ptr != self._last_aux_a_ptr: + try: + set_runtime_ptr(self._ptr_aux_a, aux_a_ptr) + self._last_aux_a_ptr = aux_a_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if aux_b_ptr != self._last_aux_b_ptr: + try: + set_runtime_ptr(self._ptr_aux_b, aux_b_ptr) + self._last_aux_b_ptr = aux_b_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if self._ptr_aux_c is not None and aux_c_ptr is not None: + if aux_c_ptr != self._last_aux_c_ptr: + try: + set_runtime_ptr(self._ptr_aux_c, aux_c_ptr) + self._last_aux_c_ptr = aux_c_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if M != self._last_m: + 
self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + if ignore_index != self._last_ignore_index: + self._arg_ignore_index.set(ignore_index) + self._last_ignore_index = int(ignore_index) + + if self._cuda_result is not None: + self._cuda_result.value = 0 + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + self._use_fast_launch = False + disable_fast_launch() + + def _fallback_launch( + self, + *, + logits_ptr: int, + target_ptr: int, + aux_a_ptr: int, + aux_b_ptr: int, + aux_c_ptr: int | None, + M: int, + ld: int, + ignore_index: int, + stream_handle: int, + dtype_logits: type[cutlass.Numeric], + aux_a_dtype: type[cutlass.Numeric], + aux_b_dtype: type[cutlass.Numeric], + aux_c_dtype: type[cutlass.Numeric] | None, + ) -> None: + stream = cuda.CUstream(int(stream_handle)) + ptr_logits = rt.make_ptr( + dtype_logits, + int(logits_ptr), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._logits_align, + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + int(target_ptr), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._target_align, + ) + ptr_aux_a = rt.make_ptr( + aux_a_dtype, + int(aux_a_ptr), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._aux_a_align, + ) + ptr_aux_b = rt.make_ptr( + aux_b_dtype, + int(aux_b_ptr), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._aux_b_align, + ) + if ( + self._ptr_aux_c is not None + and aux_c_ptr is not None + and aux_c_dtype is not None + ): + ptr_aux_c = rt.make_ptr( + aux_c_dtype, + int(aux_c_ptr), + mem_space=rt.AddressSpace.gmem, + assumed_align=int(self._aux_c_align or 0), + ) + self._compiled( + ptr_logits, + ptr_target, + ptr_aux_a, + ptr_aux_b, + ptr_aux_c, + Int32(int(M)), + Int32(int(ld)), + Int32(int(ignore_index)), + stream, + ) + else: + self._compiled( + ptr_logits, + ptr_target, + ptr_aux_a, + ptr_aux_b, + Int32(int(M)), + Int32(int(ld)), + Int32(int(ignore_index)), + stream, + ) + + +def _get_fast_ptr_cross_entropy_launcher( + *, + compiled: object, + dtype_logits: type[cutlass.Numeric], + N: int, + device_index: int, + stream_handle: int, + mode: Literal["fwd", "bwd", "fwd_bwd"], +) -> _PtrCrossEntropyFastLaunch | None: + if not fast_launch_enabled(): + return None + key = ( + f"ptr_fast_{mode}", + id(compiled), + int(N), + dtype_logits, + int(device_index), + int(stream_handle), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + ptr_logits = rt.make_ptr( + dtype_logits, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, 0, mem_space=rt.AddressSpace.gmem, assumed_align=8 + ) + if mode == "fwd": + ptr_aux_a = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) # loss + ptr_aux_b = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) # lse + ptr_aux_c = None + aux_align_b = 4 + aux_align_c = None + elif mode == "bwd": + ptr_aux_a = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) # dloss + ptr_aux_b = rt.make_ptr( + dtype_logits, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) # dx + ptr_aux_c = rt.make_ptr( + 
cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) # lse + aux_align_b = 16 + aux_align_c = 4 + elif mode == "fwd_bwd": + ptr_aux_a = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) # dloss + ptr_aux_b = rt.make_ptr( + dtype_logits, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) # dx + ptr_aux_c = None + aux_align_b = 16 + aux_align_c = None + else: + raise ValueError(f"Unsupported mode: {mode}") + + arg_m = StableI32Arg(0) + arg_ld = StableI32Arg(N) + arg_ignore_index = StableI32Arg(-100) + stream = cuda.CUstream(int(stream_handle)) + executor = compiled.to(device_index) # type: ignore[attr-defined] + + try: + if ptr_aux_c is not None: + exe_args, adapted_args = executor.generate_execution_args( + ptr_logits, + ptr_target, + ptr_aux_a, + ptr_aux_b, + ptr_aux_c, + arg_m, + arg_ld, + arg_ignore_index, + stream, + ) + else: + exe_args, adapted_args = executor.generate_execution_args( + ptr_logits, + ptr_target, + ptr_aux_a, + ptr_aux_b, + arg_m, + arg_ld, + arg_ignore_index, + stream, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + disable_fast_launch() + return None + + keepalive: tuple[object, ...] = ( + executor, + ptr_logits, + ptr_target, + ptr_aux_a, + ptr_aux_b, + ptr_aux_c, + arg_m, + arg_ld, + arg_ignore_index, + stream, + *adapted_args, + ) + launcher = _PtrCrossEntropyFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_logits=ptr_logits, + ptr_target=ptr_target, + ptr_aux_a=ptr_aux_a, + ptr_aux_b=ptr_aux_b, + ptr_aux_c=ptr_aux_c, + arg_m=arg_m, + arg_ld=arg_ld, + arg_ignore_index=arg_ignore_index, + stream=stream, + packed_args=packed_args, + keepalive=keepalive, + logits_align=16, + target_align=8, + aux_a_align=4, + aux_b_align=aux_align_b, + aux_c_align=aux_align_c, + ) + cache[key] = launcher + return launcher + + +def _convert_logits_2d(x: Tensor) -> cute.Tensor: + """Convert a 2D logits tensor (M, N) into a CuTe tensor. + + We assume 16-byte alignment and mark the layout compact and row-major + in the last dimension, matching the conventions used in the SM100 + softmax and RMSNorm kernels. + """ + assert x.dim() == 2, "Input logits must be 2D (M, N)" + return from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic( + mode=0, stride_order=(0, 1) + ) + + +def _convert_1d(t: Tensor, assumed_align: int) -> cute.Tensor: + """Convert a 1D tensor with a fully dynamic layout.""" + assert t.dim() == 1, "Expected a 1D tensor" + return from_dlpack(t.detach(), assumed_align=assumed_align).mark_layout_dynamic() + + +class CrossEntropyFwdSM100(ReductionBase): + """SM100-tuned cross-entropy forward kernel. + + This mirrors the structure of ``quack.cross_entropy.CrossEntropy`` but + is simplified to always use the single-pass online softmax reduction and + never computes gradients inside the forward kernel. + """ + + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + # Use one stage with an Int64 reduction buffer packing (max, sum_exp) + # pairs via lite_quack.online_softmax_reduce. 
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64) + + def _calculate_threads_per_row(self) -> int: + N = self.N + return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else ( + 32 + if N <= 3072 + else (64 if N <= 6144 else (128 if N <= 16384 else 256)) + ) + ) + ) + + def _set_cluster_n(self) -> None: + # Match Quack's cluster_n growth policy while keeping it explicit so + # we can tune SM100-specific shapes later if needed. + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: # fp32 + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + @cute.jit + def __call__( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mLoss: cute.Tensor, # (M,) + mLSE: Optional[cute.Tensor], # (M,) + ignore_index: Int32, + stream: cuda.CUstream, + ) -> None: + assert mX.element_type == self.dtype + self._set_cluster_n() + # If N is not divisible by the full 128-bit vector width, step down + # to the largest compatible vector size as in Quack. + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + kernel = ( + self.kernel( + mX, + mTarget, + mLoss, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mTarget, + mLoss, + mLSE, + ignore_index, + ) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_logits: cute.Pointer, + ptr_target: cute.Pointer, + ptr_loss: cute.Pointer, + ptr_lse: cute.Pointer, + M: Int32, + ld: Int32, + ignore_index: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions.""" + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + layout_m = cute.make_layout((M,), stride=(1,)) + mX = cute.make_tensor(ptr_logits, layout_mn) + mTarget = cute.make_tensor(ptr_target, layout_m) + mLoss = cute.make_tensor(ptr_loss, layout_m) + mLSE = cute.make_tensor(ptr_lse, layout_m) + self.__call__(mX, mTarget, mLoss, mLSE, ignore_index, stream) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mLoss: cute.Tensor, # (M,) + mLSE: Optional[cute.Tensor], # (M,) + ignore_index: Int32, # Index to ignore in loss computation + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) + + shape: cute.Shape = mX.shape + idX = cute.make_identity_tensor(shape) + + # Quack-style CTA tiling: let CuTe compute the CTA offsets directly. + # (Avoids the extra 64-bit address arithmetic in `domain_offset_i64` on + # the common inference/benchmark sizes.) 
+ gX = cute.local_tile(mX, tiler_mn, (bidx, cluster_y)) + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) + + # Copy setup: gmem -> smem via cp.async, 128-bit or narrower as needed. + num_copy_elems_X = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + num_copy_bits_X = mX.element_type.width * num_copy_elems_X + copy_atom_load_X = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + gX.element_type, + num_bits_per_copy=num_copy_bits_X, + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems_X)) + thr_copy_X = cute.make_tiled_copy_tv( + copy_atom_load_X, thr_layout, val_layout + ).get_slice(tidx) + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + tXrX = cute.make_fragment_like(tXgX) + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + row = tXcX[0][0] + target = Int32.zero + if row < shape[0]: + target = Int32(mTarget[row]) + + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + if row < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + # Fill out-of-bounds values with -inf so they are ignored in max/sum. + if const_expr(not is_even_N): + fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + + should_ignore = Boolean(target == ignore_index) + + # Load the target logit if this row is not ignored. + target_logit = Float32.zero + if row < shape[0] and tXcX[0][1] == 0 and not should_ignore: + target_logit = Float32(mX[row, target]) + + max_x, denom, _ = online_softmax_reduce( + x, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + phase=None, + return_exp_x=False, + ) + + # Write loss and lse to gmem. Only one CTA in the cluster writes to + # avoid duplicate stores. 
+ if ( + tXcX[0][1] == 0 + and row < shape[0] + and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + lse = max_x + cute.math.log(denom, fastmath=True) + loss_val = (lse - target_logit) if not should_ignore else Float32.zero + mLoss[row] = mLoss.element_type(loss_val) + if const_expr(mLSE is not None): + mLSE[row] = lse + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mLoss: cute.Tensor, # (M,) + mLSE: Optional[cute.Tensor], # (M,) + ignore_index: Int32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl( + mX, + mTarget, + mLoss, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mLoss: cute.Tensor, # (M,) + mLSE: Optional[cute.Tensor], # (M,) + ignore_index: Int32, + ) -> None: + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) + self._kernel_impl( + mX, + mTarget, + mLoss, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + + +class CrossEntropyFwdBwdSM100(ReductionBase): + """Fused cross-entropy forward+backward producing dx from (logits, target, dloss). + + This avoids materializing the intermediate `lse` (and loss) in global memory + when the only desired output is `dx` for `reduction="none"` semantics. + """ + + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64) + + def _calculate_threads_per_row(self) -> int: + N = self.N + return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else ( + 32 + if N <= 3072 + else (64 if N <= 6144 else (128 if N <= 16384 else 256)) + ) + ) + ) + + def _set_cluster_n(self) -> None: + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + @cute.jit + def __call__( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + ignore_index: Int32, + stream: cuda.CUstream, + ) -> None: + assert mX.element_type == self.dtype + assert mdX.element_type == self.dtype + self._set_cluster_n() + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + kernel = ( + self.kernel( + mX, + mTarget, + mDLoss, + mdX, + ignore_index, + tv_layout, + tiler_mn, + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mTarget, + mDLoss, + mdX, + ignore_index, + ) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_logits: cute.Pointer, + ptr_target: cute.Pointer, + ptr_dloss: cute.Pointer, + ptr_dx: cute.Pointer, + M: Int32, + ld: 
Int32, + ignore_index: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions.""" + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + layout_m = cute.make_layout((M,), stride=(1,)) + mX = cute.make_tensor(ptr_logits, layout_mn) + mdX = cute.make_tensor(ptr_dx, layout_mn) + mTarget = cute.make_tensor(ptr_target, layout_m) + mDLoss = cute.make_tensor(ptr_dloss, layout_m) + self.__call__(mX, mTarget, mDLoss, mdX, ignore_index, stream) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + ignore_index: Int32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + cluster_y = ( + const_expr(0) + if const_expr(self.cluster_n == 1) + else cute.arch.block_idx()[1] + ) + + shape: cute.Shape = mX.shape + idX = cute.make_identity_tensor(shape) + + gX, gdX, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mdX, idX) + ] + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) + + num_copy_elems_X = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + num_copy_bits_X = mX.element_type.width * num_copy_elems_X + copy_atom_load_X = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + gX.element_type, + num_bits_per_copy=num_copy_bits_X, + ) + copy_atom_store_dX = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gdX.element_type, + num_bits_per_copy=num_copy_bits_X, + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems_X)) + thr_copy_X = cute.make_tiled_copy_tv( + copy_atom_load_X, thr_layout, val_layout + ).get_slice(tidx) + thr_copy_dX = cute.make_tiled_copy_tv( + copy_atom_store_dX, thr_layout, val_layout + ).get_slice(tidx) + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + tXcFull = thr_copy_X.partition_S(cX) + tXgdX = thr_copy_dX.partition_D(gdX) + + tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)] + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + row = tXcX[0][0] + target = Int32.zero + dloss = Float32.zero + if row < shape[0]: + target = Int32(mTarget[row]) + should_ignore = Boolean(target == ignore_index) + dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero + + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + if row < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + if const_expr(not is_even_N): + fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + + _max_x, denom, exp_x = online_softmax_reduce( + 
x, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + phase=None, + return_exp_x=True, + ) + assert exp_x is not None + probs = exp_x * cute.arch.rcp_approx(denom) + prob_shifted = probs - 1.0 + + mask = cute.make_fragment_like(tXrX, cutlass.Boolean) + for i in cutlass.range(cute.size(tXcFull), unroll_full=True): + mask[i] = tXcFull[i][1] == target + grad = cute.where(mask.load(), prob_shifted, probs) + grad = grad * dloss + + tXrdX.store(grad.to(tXrdX.element_type)) + + tXpdX = ( + predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + if row < shape[0]: + cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + ignore_index: Int32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl( + mX, + mTarget, + mDLoss, + mdX, + ignore_index, + tv_layout, + tiler_mn, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + ignore_index: Int32, + ) -> None: + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) + self._kernel_impl( + mX, + mTarget, + mDLoss, + mdX, + ignore_index, + tv_layout, + tiler_mn, + ) + + +class CrossEntropyBackwardSM100: + """SM100-tuned cross-entropy backward kernel. + + This is a direct port of ``quack.cross_entropy.CrossEntropyBackward`` to + the local lite_quack helpers, using cp.async tiling over the (M, N) + logits and broadcasting ``dloss`` / ``lse`` across the row dimension. + """ + + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + self.dtype = dtype + self.N = N + + def _get_num_threads(self) -> int: + # Keep in sync with _get_tv_layout() (we tile N in 16k blocks). + N = min(self.N, 16384) + return 128 if N <= 16384 else 256 + + def _calculate_threads_per_row(self) -> int: + N = min(self.N, 16384) # We split by blocks of 16k in N. 
+ return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else ( + 32 + if N <= 3072 + else (64 if N <= 6144 else (128 if N <= 16384 else 256)) + ) + ) + ) + + def _get_tv_layout( + self, num_copy_bits: int = 128 + ) -> tuple[cute.Shape, cute.Layout]: + vecsize = num_copy_bits // self.dtype.width + assert self.N % vecsize == 0, ( + f"Input N {self.N} is not divisible by vector size {vecsize}" + ) + N = min(self.N, 16384) + num_threads = 128 if N <= 16384 else 256 + threads_per_row = self._calculate_threads_per_row() + cols_per_block = num_threads // threads_per_row + num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row) + tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) + tv_layout = cute.make_layout( + ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)), + stride=( + (vecsize * cols_per_block, 1), + (cols_per_block, cols_per_block * vecsize * threads_per_row), + ), + ) + return tiler_mn, tv_layout + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mTarget: cute.Tensor, + mDLoss: cute.Tensor, + mdX: cute.Tensor, + mLSE: cute.Tensor, + ignore_index: Int32, # Index to ignore in gradient computation + stream: cuda.CUstream, + ) -> None: + assert mX.element_type == self.dtype + assert mdX.element_type == self.dtype + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + # Broadcast (M,) tensors along the N dimension with stride 0. + mDLoss, mTarget, mLSE = [ + cute.make_tensor( + X.iterator, + cute.append(X.layout, cute.make_layout((self.N,), stride=(0,))), + ) + for X in (mDLoss, mTarget, mLSE) + ] + smem_size = cute.size_in_bytes( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + ) + kernel = ( + self.kernel( + mX, + mTarget, + mDLoss, + mdX, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mTarget, + mDLoss, + mdX, + mLSE, + ignore_index, + ) + ) + kernel.launch( + grid=[ + cute.ceil_div(mX.shape[0], tiler_mn[0]), + cute.ceil_div(mX.shape[1], tiler_mn[1]), + 1, + ], + block=[num_threads, 1, 1], + smem=smem_size, + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_logits: cute.Pointer, + ptr_target: cute.Pointer, + ptr_dloss: cute.Pointer, + ptr_dx: cute.Pointer, + ptr_lse: cute.Pointer, + M: Int32, + ld: Int32, + ignore_index: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions.""" + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + layout_m = cute.make_layout((M,), stride=(1,)) + mX = cute.make_tensor(ptr_logits, layout_mn) + mdX = cute.make_tensor(ptr_dx, layout_mn) + mTarget = cute.make_tensor(ptr_target, layout_m) + mDLoss = cute.make_tensor(ptr_dloss, layout_m) + mLSE = cute.make_tensor(ptr_lse, layout_m) + self.__call__(mX, mTarget, mDLoss, mdX, mLSE, ignore_index, stream) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + mLSE: cute.Tensor, # (M,) + ignore_index: Int32, # Index to ignore in gradient computation + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, _ = cute.arch.block_idx() + shape = mX.shape + + smem = 
cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + + idX = cute.make_identity_tensor(shape) + # Quack-style CTA tiling: avoid extra 64-bit address arithmetic by + # letting CuTe compute the CTA offsets directly. + gX, gdX, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, bidy)) for mT in (mX, mdX, idX) + ] + + num_copy_elems_X = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + num_copy_bits_X = mX.element_type.width * num_copy_elems_X + copy_atom_load_X = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + gX.element_type, + num_bits_per_copy=num_copy_bits_X, + ) + copy_atom_store_dX = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gdX.element_type, + num_bits_per_copy=num_copy_bits_X, + ) + thr_copy_X = cute.make_tiled_copy( + copy_atom_load_X, tv_layout, tiler_mn + ).get_slice(tidx) + thr_copy_dX = cute.make_tiled_copy( + copy_atom_store_dX, tv_layout, tiler_mn + ).get_slice(tidx) + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + tXcFull = thr_copy_X.partition_S(cX) + tXgdX = thr_copy_dX.partition_D(gdX) + + tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)] + + is_even_N = const_expr(shape[1] % tiler_mn[1] == 0) + row = tXcX[0][0] + tXpX = ( + predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if not is_even_N + else None + ) + if row < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + if const_expr(not is_even_N): + fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + + target = Int32.zero + dloss = Float32.zero + lse = Float32.zero + if row < shape[0]: + target = Int32(mTarget[row]) + should_ignore = Boolean(target == ignore_index) + dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero + lse = Float32(mLSE[row]) + + log2_e = math.log2(math.e) + probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True) + prob_shifted = probs - 1.0 + mask = cute.make_fragment_like(tXrX, cutlass.Boolean) + for i in cutlass.range(cute.size(tXcFull), unroll_full=True): + mask[i] = tXcFull[i][1] == target + grad = cute.where(mask.load(), prob_shifted, probs) + grad = grad * dloss + + tXrdX.store(grad.to(tXrdX.element_type)) + tXpdX = ( + predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1]) + if not is_even_N + else None + ) + if row < shape[0]: + cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + mLSE: cute.Tensor, # (M,) + ignore_index: Int32, # Index to ignore in gradient computation + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl( + mX, + mTarget, + mDLoss, + mdX, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + mLSE: cute.Tensor, # (M,) + ignore_index: Int32, # Index to ignore in gradient computation + ) -> None: + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = 
self._get_tv_layout(num_copy_bits=num_copy_bits) + self._kernel_impl( + mX, + mTarget, + mDLoss, + mdX, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + + +def cross_entropy_forward( + logits: Tensor, + target: Tensor, + ignore_index: int = -100, + reduction: Literal["none", "mean", "sum"] = "none", +) -> tuple[Tensor, Tensor]: + """SM100 CuteDSL cross-entropy forward pass. + + Args: + logits: Tensor of shape ``(M, N)`` on CUDA. + target: Tensor of shape ``(M,)`` with integer class indices. + ignore_index: Target value to ignore when computing the loss. + reduction: One of ``"none"``, ``"mean"``, or ``"sum"`` following + ``torch.nn.functional.cross_entropy`` semantics. + + Returns: + A tuple ``(loss, lse)`` where: + - ``loss`` has shape ``(M,)`` if ``reduction="none"`` or is a scalar + otherwise. + - ``lse`` is the per-example log-sum-exp with shape ``(M,)``. + """ + assert logits.dim() == 2, "logits must be 2D (M, N)" + assert target.dim() == 1, "target must be 1D (M,)" + assert logits.shape[0] == target.shape[0], "Batch dimensions must match" + assert logits.is_cuda and target.is_cuda, "logits and target must be on CUDA device" + assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype" + assert target.dtype in (torch.int32, torch.int64), "target must be int32 or int64" + + M, N = logits.shape + device = logits.device + dtype_cute = TORCH2CUTE_DTYPE[logits.dtype] + + loss = torch.empty(M, device=device, dtype=torch.float32) + lse = torch.empty(M, device=device, dtype=torch.float32) + + if _can_use_ptr_path_logits(logits) and _can_use_ptr_path_target(target): + _cross_entropy_forward_ptr_into( + logits=logits, + target=target, + loss=loss, + lse=lse, + ignore_index=int(ignore_index), + ) + if reduction == "none": + return loss, lse + with torch.no_grad(): + mask = target != ignore_index + if reduction == "sum": + reduced = loss.sum() + elif reduction == "mean": + valid = mask.sum() + if valid > 0: + reduced = loss[mask].sum() / valid.to(loss.dtype) + else: + reduced = loss.sum() * 0.0 + else: + raise ValueError( + f"Invalid reduction mode: {reduction}. Expected 'none', 'mean', or 'sum'." + ) + return reduced, lse + + mX = _convert_logits_2d(logits) + mTarget = _convert_1d(target.to(torch.int64), assumed_align=8) + mLoss = _convert_1d(loss, assumed_align=4) + mLSE = _convert_1d(lse, assumed_align=4) + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compile_key = (dtype_cute, N) + kernel = _FWD_COMPILE_CACHE.get(compile_key) + if kernel is None: + op = CrossEntropyFwdSM100(dtype_cute, N) + kernel = cute.compile( + op, + mX, + mTarget, + mLoss, + mLSE, + Int32(ignore_index), + current_stream, + ) + _FWD_COMPILE_CACHE[compile_key] = kernel + + kernel(mX, mTarget, mLoss, mLSE, Int32(ignore_index), current_stream) + + if reduction == "none": + return loss, lse + + with torch.no_grad(): + mask = target != ignore_index + if reduction == "sum": + reduced = loss.sum() + elif reduction == "mean": + valid = mask.sum() + if valid > 0: + reduced = loss[mask].sum() / valid.to(loss.dtype) + else: + reduced = loss.sum() * 0.0 + else: + raise ValueError( + f"Invalid reduction mode: {reduction}. Expected 'none', 'mean', or 'sum'." 
+ ) + return reduced, lse + + +def _cross_entropy_backward_sm100( + logits: Tensor, + target: Tensor, + dloss: Tensor, + lse: Tensor, + dx: Tensor, + ignore_index: int = -100, +) -> None: + """Internal SM100 cross-entropy backward dispatch using CuteDSL.""" + assert logits.dim() == 2, "logits must be 2D (M, N)" + assert target.dim() == 1, "target must be 1D (M,)" + assert dloss.dim() == 1, "dloss must be 1D (M,)" + assert lse.dim() == 1, "lse must be 1D (M,)" + assert logits.shape[0] == target.shape[0] == dloss.shape[0] == lse.shape[0], ( + "Batch dimensions must match" + ) + assert logits.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, ( + "All tensors must be on CUDA device" + ) + assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype" + assert target.dtype in (torch.int32, torch.int64), "target must be int32 or int64" + + M, N = logits.shape + dtype_cute = TORCH2CUTE_DTYPE[logits.dtype] + + if ( + _can_use_ptr_path_logits(logits) + and _can_use_ptr_path_logits(dx) + and _can_use_ptr_path_target(target) + and _can_use_ptr_path_f32_1d(dloss) + and _can_use_ptr_path_f32_1d(lse) + and logits.stride() == dx.stride() + ): + _cross_entropy_backward_ptr_into( + logits=logits, + target=target, + dloss=dloss, + lse=lse, + dx=dx, + ignore_index=int(ignore_index), + ) + return + + mX = _convert_logits_2d(logits) + mdX = _convert_logits_2d(dx) + mTarget = _convert_1d(target.to(torch.int64), assumed_align=8) + mDLoss = _convert_1d(dloss.to(torch.float32), assumed_align=4) + mLSE = _convert_1d(lse.to(torch.float32), assumed_align=4) + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compile_key = (dtype_cute, N) + kernel = _BWD_COMPILE_CACHE.get(compile_key) + if kernel is None: + op = CrossEntropyBackwardSM100(dtype_cute, N) + kernel = cute.compile( + op, + mX, + mTarget, + mDLoss, + mdX, + mLSE, + Int32(ignore_index), + current_stream, + ) + _BWD_COMPILE_CACHE[compile_key] = kernel + + kernel(mX, mTarget, mDLoss, mdX, mLSE, Int32(ignore_index), current_stream) + + +def _can_use_ptr_path_logits(x: Tensor) -> bool: + if not x.is_cuda or x.dim() != 2: + return False + if x.dtype not in TORCH2CUTE_DTYPE: + return False + if x.stride(1) != 1: + return False + if (x.data_ptr() % 16) != 0: + return False + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + divby = 128 // dtype_x.width + if (x.stride(0) % divby) != 0: + return False + return True + + +def _can_use_ptr_path_target(t: Tensor) -> bool: + if not t.is_cuda or t.dim() != 1: + return False + if t.dtype is not torch.int64: + return False + if not t.is_contiguous(): + return False + if t.stride(0) != 1: + return False + if (t.data_ptr() % 8) != 0: + return False + return True + + +def _can_use_ptr_path_f32_1d(t: Tensor) -> bool: + if not t.is_cuda or t.dim() != 1: + return False + if t.dtype is not torch.float32: + return False + if not t.is_contiguous(): + return False + if t.stride(0) != 1: + return False + if (t.data_ptr() % 4) != 0: + return False + return True + + +def _cross_entropy_forward_ptr_into( + *, + logits: Tensor, + target: Tensor, + loss: Tensor, + lse: Tensor, + ignore_index: int, +) -> None: + assert logits.is_cuda and logits.dim() == 2 + assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0] + assert target.dtype is torch.int64 + assert ( + loss.is_cuda + and loss.shape == (logits.shape[0],) + and loss.dtype is torch.float32 + ) + assert ( + lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32 + ) + + M, N = logits.shape + device_index 
= logits.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + dtype_x = TORCH2CUTE_DTYPE[logits.dtype] + key = ("ptr_fwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_FWD_COMPILE_CACHE.get(key) + if compiled is None: + op = CrossEntropyFwdSM100(dtype_x, int(N)) + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_loss = rt.make_ptr( + cutlass.Float32, + loss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_lse = rt.make_ptr( + cutlass.Float32, + lse.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_logits, + ptr_target, + ptr_loss, + ptr_lse, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + _PTR_FWD_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_cross_entropy_launcher( + compiled=compiled, + dtype_logits=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + mode="fwd", + ) + if launcher is not None: + launcher.launch( + logits_ptr=int(logits.data_ptr()), + target_ptr=int(target.data_ptr()), + aux_a_ptr=int(loss.data_ptr()), + aux_b_ptr=int(lse.data_ptr()), + aux_c_ptr=None, + M=int(M), + ld=int(logits.stride(0)), + ignore_index=int(ignore_index), + stream_handle=stream_handle, + dtype_logits=dtype_x, + aux_a_dtype=cutlass.Float32, + aux_b_dtype=cutlass.Float32, + aux_c_dtype=None, + ) + return + + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_loss = rt.make_ptr( + cutlass.Float32, + loss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_lse = rt.make_ptr( + cutlass.Float32, + lse.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + compiled( + ptr_logits, + ptr_target, + ptr_loss, + ptr_lse, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + + +def _cross_entropy_backward_ptr_into( + *, + logits: Tensor, + target: Tensor, + dloss: Tensor, + lse: Tensor, + dx: Tensor, + ignore_index: int, +) -> None: + assert logits.is_cuda and logits.dim() == 2 + assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0] + assert target.dtype is torch.int64 + assert ( + dloss.is_cuda + and dloss.shape == (logits.shape[0],) + and dloss.dtype is torch.float32 + ) + assert ( + lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32 + ) + assert dx.is_cuda and dx.shape == logits.shape and dx.dtype == logits.dtype + assert dx.stride() == logits.stride(), ( + "Pointer path expects dx to match logits strides" + ) + + M, N = logits.shape + device_index = logits.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + dtype_x = TORCH2CUTE_DTYPE[logits.dtype] + key = ("ptr_bwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_BWD_COMPILE_CACHE.get(key) + if compiled is None: + op = 
CrossEntropyBackwardSM100(dtype_x, int(N)) + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_dloss = rt.make_ptr( + cutlass.Float32, + dloss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_lse = rt.make_ptr( + cutlass.Float32, + lse.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_logits, + ptr_target, + ptr_dloss, + ptr_dx, + ptr_lse, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + _PTR_BWD_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_cross_entropy_launcher( + compiled=compiled, + dtype_logits=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + mode="bwd", + ) + if launcher is not None: + launcher.launch( + logits_ptr=int(logits.data_ptr()), + target_ptr=int(target.data_ptr()), + aux_a_ptr=int(dloss.data_ptr()), + aux_b_ptr=int(dx.data_ptr()), + aux_c_ptr=int(lse.data_ptr()), + M=int(M), + ld=int(logits.stride(0)), + ignore_index=int(ignore_index), + stream_handle=stream_handle, + dtype_logits=dtype_x, + aux_a_dtype=cutlass.Float32, + aux_b_dtype=dtype_x, + aux_c_dtype=cutlass.Float32, + ) + return + + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_dloss = rt.make_ptr( + cutlass.Float32, + dloss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_lse = rt.make_ptr( + cutlass.Float32, + lse.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + compiled( + ptr_logits, + ptr_target, + ptr_dloss, + ptr_dx, + ptr_lse, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + + +def _cross_entropy_fwd_bwd_ptr_into( + *, + logits: Tensor, + target: Tensor, + dloss: Tensor, + dx: Tensor, + ignore_index: int, +) -> None: + """Launch the fused pointer-based cross-entropy fwd+bwd kernel into preallocated `dx`.""" + assert logits.is_cuda and logits.dim() == 2 + assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0] + assert target.dtype is torch.int64 + assert ( + dloss.is_cuda + and dloss.shape == (logits.shape[0],) + and dloss.dtype is torch.float32 + ) + assert dx.is_cuda and dx.shape == logits.shape and dx.dtype == logits.dtype + assert dx.stride() == logits.stride(), ( + "Pointer path expects dx to match logits strides" + ) + + M, N = logits.shape + device_index = logits.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + dtype_x = TORCH2CUTE_DTYPE[logits.dtype] + key = ("ptr_fwd_bwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_FWDBWD_COMPILE_CACHE.get(key) + if compiled is None: + op = CrossEntropyFwdBwdSM100(dtype_x, int(N)) + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = 
rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_dloss = rt.make_ptr( + cutlass.Float32, + dloss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_logits, + ptr_target, + ptr_dloss, + ptr_dx, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + _PTR_FWDBWD_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_cross_entropy_launcher( + compiled=compiled, + dtype_logits=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + mode="fwd_bwd", + ) + if launcher is not None: + launcher.launch( + logits_ptr=int(logits.data_ptr()), + target_ptr=int(target.data_ptr()), + aux_a_ptr=int(dloss.data_ptr()), + aux_b_ptr=int(dx.data_ptr()), + aux_c_ptr=None, + M=int(M), + ld=int(logits.stride(0)), + ignore_index=int(ignore_index), + stream_handle=stream_handle, + dtype_logits=dtype_x, + aux_a_dtype=cutlass.Float32, + aux_b_dtype=dtype_x, + aux_c_dtype=None, + ) + return + + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_dloss = rt.make_ptr( + cutlass.Float32, + dloss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + compiled( + ptr_logits, + ptr_target, + ptr_dloss, + ptr_dx, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + + +def cross_entropy_backward( + dloss: Tensor, + logits: Tensor, + target: Tensor, + lse: Tensor, + ignore_index: int = -100, +) -> Tensor: + """SM100 CuteDSL cross-entropy backward pass. + + Args: + dloss: Upstream gradient of shape ``(M,)`` corresponding to + ``reduction="none"``. + logits: Input logits tensor of shape ``(M, N)``. + target: Integer class indices of shape ``(M,)``. + lse: Per-example log-sum-exp tensor of shape ``(M,)`` as returned + by :func:`cross_entropy_forward`. + ignore_index: Target value to ignore in gradient computation. + + Returns: + ``dlogits`` of shape ``(M, N)`` with the same dtype as ``logits``. + """ + assert logits.dim() == 2, "logits must be 2D (M, N)" + assert dloss.dim() == 1, "dloss must be 1D (M,)" + assert logits.size(0) == dloss.size(0), "Batch dimensions must match" + assert logits.is_cuda and dloss.is_cuda, "logits and dloss must be on CUDA device" + + dx = torch.empty_like(logits) + _cross_entropy_backward_sm100( + logits, + target, + dloss, + lse, + dx, + ignore_index=ignore_index, + ) + return dx + + +def cross_entropy_fwd_bwd( + dloss: Tensor, + logits: Tensor, + target: Tensor, + ignore_index: int = -100, +) -> Tensor: + """Fused cross-entropy forward+backward producing ``dx`` for ``reduction='none'``. + + Computes per-logit gradients ``dx`` given: + - ``logits``: (M, N) + - ``target``: (M,) + - ``dloss``: (M,) upstream gradients (float32 recommended) + + The fast path avoids materializing intermediate ``lse`` in global memory. 
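+
+    For reference, each output element computed by the fused kernel is
+    ``dx[i, j] = dloss[i] * (softmax(logits[i])[j] - (1 if j == target[i] else 0))``,
+    and rows with ``target[i] == ignore_index`` receive zero gradient.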
+ """ + assert logits.dim() == 2, "logits must be 2D (M, N)" + assert target.dim() == 1, "target must be 1D (M,)" + assert dloss.dim() == 1, "dloss must be 1D (M,)" + assert logits.shape[0] == target.shape[0] == dloss.shape[0], ( + "Batch dimensions must match" + ) + assert logits.is_cuda and target.is_cuda and dloss.is_cuda, ( + "All tensors must be on CUDA device" + ) + assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype" + + dx = torch.empty_like(logits) + + if ( + _can_use_ptr_path_logits(logits) + and _can_use_ptr_path_logits(dx) + and _can_use_ptr_path_target(target) + and _can_use_ptr_path_f32_1d(dloss) + and logits.stride() == dx.stride() + ): + _cross_entropy_fwd_bwd_ptr_into( + logits=logits, + target=target, + dloss=dloss, + dx=dx, + ignore_index=int(ignore_index), + ) + return dx + + # Fallback: reuse the existing forward+backward kernels (DLPack path handles + # any necessary dtype conversions). + with torch.no_grad(): + _loss, lse = cross_entropy_forward( + logits, + target, + ignore_index=int(ignore_index), + reduction="none", + ) + _cross_entropy_backward_sm100( + logits, + target, + dloss, + lse, + dx, + ignore_index=int(ignore_index), + ) + return dx + + +def cross_entropy( + logits: Tensor, + target: Tensor, + ignore_index: int = -100, + reduction: Literal["none", "mean", "sum"] = "mean", +) -> Tensor: + """Convenience wrapper mirroring ``torch.nn.functional.cross_entropy`` reductions. + + This uses :func:`cross_entropy_forward` under the hood but returns only + the reduced loss tensor. + """ + loss, _lse = cross_entropy_forward( + logits, + target, + ignore_index=ignore_index, + reduction="none", + ) + if reduction == "none": + return loss + mask = target != ignore_index + if reduction == "sum": + return loss.sum() + if reduction == "mean": + valid = mask.sum() + if valid > 0: + return loss[mask].sum() / valid.to(loss.dtype) + return loss.sum() * 0.0 + raise ValueError( + f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'." + ) + + +def verify_cross_entropy_parity( + M: int, + N: int, + dtype: torch.dtype = torch.bfloat16, + ignore_index: int = -100, +) -> None: + """Compare SM100 CuteDSL cross-entropy against PyTorch for a single shape.""" + device = torch.device("cuda") + torch.manual_seed(0) + + logits = 0.1 * torch.randn(M, N, device=device, dtype=dtype) + logits.requires_grad_(True) + target = torch.randint(0, N, (M,), device=device, dtype=torch.int64) + + # Optionally sprinkle some ignore_index entries for robustness. 
+ if ignore_index != -100: + mask = torch.rand(M, device=device) < 0.1 + target[mask] = ignore_index + + loss, lse = cross_entropy_forward( + logits, target, ignore_index=ignore_index, reduction="none" + ) + + logits_ref = logits.detach().clone().requires_grad_() + target_ref = target.detach().clone() + loss_ref = torch.nn.functional.cross_entropy( + logits_ref.float(), + target_ref, + ignore_index=ignore_index, + reduction="none", + ) + + # Forward parity + if dtype in (torch.float16, torch.bfloat16): + atol = 5e-2 + rtol = 5e-2 + else: + atol = 1e-4 + rtol = 1e-4 + torch.testing.assert_close(loss, loss_ref, atol=atol, rtol=rtol) + + # Backward parity + dloss = torch.randn_like(loss_ref) + (dx_ref,) = torch.autograd.grad(loss_ref, logits_ref, grad_outputs=dloss) + dx = cross_entropy_backward(dloss, logits, target, lse, ignore_index=ignore_index) + torch.testing.assert_close(dx, dx_ref.to(logits.dtype), atol=atol, rtol=rtol) + + +if __name__ == "__main__": + # Minimal functional check when executed directly. For performance + # comparisons and detailed tuning, use the dedicated benchmark harness. + if not torch.cuda.is_available(): + print("CUDA not available; cross-entropy parity check skipped.") + raise SystemExit(0) + + M, N = 1024, 8192 + dtype = torch.bfloat16 + verify_cross_entropy_parity(M, N, dtype=dtype, ignore_index=-100) + print("SM100 cross-entropy CuteDSL parity check passed.") diff --git a/oink/src/kernelagent_oink/blackwell/fast_launch.py b/oink/src/kernelagent_oink/blackwell/fast_launch.py new file mode 100644 index 0000000..9b288f2 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/fast_launch.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Host-side fast-launch helpers for CuTeDSL pointer entrypoints. + +CuTeDSL's Python runtime typically marshals each kernel call by allocating +`Int32` / `Float32` wrappers and runtime `Pointer` descriptors per invocation. +For latency-sensitive cases (small/medium M), this overhead can dominate. + +These helpers provide: +- Stable scalar argument wrappers (`StableI32Arg`, `StableF32Arg`) that avoid + per-call ctypes allocations. +- In-place mutation of runtime pointer descriptors (`set_runtime_ptr`) so a + compiled kernel can be launched repeatedly with different raw device pointers + without rebuilding argument objects. +- A small thread-local cache to store packed args objects (when supported by the + installed CuTeDSL version). + +All of this relies on a few private-ish CuTeDSL internals. Callers must treat +fast-launch as an optional optimization and fall back to the normal launch +path if those internals are unavailable. 
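+
+The pointer-based launchers in this package use these helpers roughly as
+follows (illustrative sketch; ``ptr_x`` / ``new_device_ptr`` stand for any
+cached runtime pointer and the new raw device address)::
+
+    try:
+        set_runtime_ptr(ptr_x, new_device_ptr)
+    except AttributeError:
+        # CuTeDSL internals changed in this version: disable fast-launch
+        # globally and fall back to the regular compiled(...) launch path.
+        disable_fast_launch()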
+""" + +from __future__ import annotations + +import ctypes +import os +import threading +from typing import Any + +_FAST_LAUNCH_TLS = threading.local() + + +def _env_flag(name: str, default: bool) -> bool: + val = os.environ.get(name) + if val is None: + return default + return val.strip().lower() not in {"0", "false", "no", "off", ""} + + +# Fast-launch uses internal CuTeDSL plumbing (packed args + pointer descriptors). +# Keep it enabled by default in our pinned environment, but allow disabling it +# via env var and auto-disable it if CuTeDSL internals change. +_ENABLE_FAST_LAUNCH = _env_flag("OINK_CUTEDSL_FAST_LAUNCH", default=True) +_FAST_LAUNCH_SUPPORTED = True + + +def fast_launch_enabled() -> bool: + return _ENABLE_FAST_LAUNCH and _FAST_LAUNCH_SUPPORTED + + +def disable_fast_launch() -> None: + global _FAST_LAUNCH_SUPPORTED + _FAST_LAUNCH_SUPPORTED = False + + +def tls_cache() -> dict[tuple[Any, ...], Any]: + cache = getattr(_FAST_LAUNCH_TLS, "cache", None) + if cache is None: + cache = {} + _FAST_LAUNCH_TLS.cache = cache + return cache + + +class StableI32Arg: + """A stable Int32 runtime arg (avoids per-call Int32().__c_pointers__ allocations).""" + + def __init__(self, value: int): + self._c_value = ctypes.c_int32(int(value)) + self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p) + + def set(self, value: int) -> None: + self._c_value.value = int(value) + + def __c_pointers__(self): + return [self._c_pointer] + + +class StableF32Arg: + """A stable Float32 runtime arg (avoids per-call Float32().__c_pointers__ allocations).""" + + def __init__(self, value: float): + self._c_value = ctypes.c_float(float(value)) + self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p) + + def set(self, value: float) -> None: + self._c_value.value = float(value) + + def __c_pointers__(self): + return [self._c_pointer] + + +def set_runtime_ptr(ptr: Any, device_ptr: int) -> None: + """Update a CuTeDSL runtime Pointer descriptor in-place. + + This relies on internal runtime pointer fields (`_desc`, `_pointer`, etc.). + If these internals change in a future CuTeDSL upgrade, this function may + raise AttributeError; callers should catch it and fall back. + """ + device_ptr = int(device_ptr) + ptr._pointer = device_ptr # type: ignore[attr-defined] + if getattr(ptr, "_c_pointer", None) is None: + ptr.__c_pointers__() # type: ignore[attr-defined] + ptr._desc.value = device_ptr # type: ignore[attr-defined] diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py new file mode 100644 index 0000000..ada51ec --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/layernorm.py @@ -0,0 +1,1981 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +LayerNorm kernel for SM100 (Blackwell) in CuteDSL. 
+ +This implementation: +- Mirrors Quack's LayerNorm tiling / cluster policy / cp.async pipeline + but uses only local helpers so that it does not depend on the external + `quack` package at runtime. +- Supports fp16 / bf16 / fp32 inputs with fp32 accumulation. +- Optionally writes out per-row `rstd` and `mean` buffers for reuse in + backward or fused kernels. + +Backward is implemented with dedicated CuteDSL kernels for input and +parameter gradients (dx, dweight, dbias), avoiding PyTorch autograd +while matching `torch.nn.functional.layer_norm`'s gradients numerically. +""" + +from __future__ import annotations + +import importlib.metadata +import math +import os +import re +import operator +from typing import Optional, Tuple, Type + +import torch +from torch import Tensor + +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python + +# CuTeDSL caches generated MLIR into a tempdir under a global default +# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across +# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy +# warnings (and disables cache reuse). +if "CUTE_DSL_CACHE_DIR" not in os.environ: + try: + _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl") + except Exception: + _dsl_ver = "unknown" + _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver) + _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user" + _tmp = os.environ.get("TMPDIR") or "/tmp" + os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join( + _tmp, _user, f"cutlass_python_cache_{_dsl_ver}" + ) + +try: + import cutlass # type: ignore # noqa: F401 +except Exception as e: + raise ImportError( + "kernelagent_oink.blackwell.layernorm requires CuTeDSL's Python package " + "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)." + ) from e + +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute import runtime as rt +from cutlass.cute.runtime import from_dlpack + +from kernelagent_oink.blackwell.lite_quack import ( + _KERNEL_ACCEPTS_LAYOUT_ARGS, + TORCH2CUTE_DTYPE, + ReductionBase as _ReductionBase, + convert_from_dlpack as convert_from_dlpack_cute, + get_sm_count, + predicate_k, + row_reduce, + warp_reduce, +) +from kernelagent_oink.blackwell.fast_launch import ( + StableF32Arg, + StableI32Arg, + disable_fast_launch, + fast_launch_enabled, + set_runtime_ptr, + tls_cache as _tls_fast_launch_cache, +) + +# Simple compile cache for the forward kernel +_COMPILE_CACHE: dict[Tuple[int, type[cutlass.Numeric], bool, bool, bool], object] = {} +_PTR_COMPILE_CACHE: dict[Tuple[object, ...], object] = {} + +# Backward compile caches: one for dx, one for parameter gradients. 
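+# Keyed by (N, x dtype) for dx and (N, x dtype, has_bias) for dweight/dbias,
+# matching the `key = (...)` tuples built in the backward host helpers below.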
+_BWD_DX_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric]], object] = {} +_BWD_PARAM_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric], bool], object] = {} + + +class _PtrLayernormFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_x: object, + ptr_w: object, + ptr_b: Optional[object], + ptr_out: object, + ptr_rstd: Optional[object], + ptr_mean: Optional[object], + arg_m: StableI32Arg, + arg_ld: StableI32Arg, + arg_eps: StableF32Arg, + stream: cuda.CUstream, + assumed_align_xo: int, + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_x = ptr_x + self._ptr_w = ptr_w + self._ptr_b = ptr_b + self._ptr_out = ptr_out + self._ptr_rstd = ptr_rstd + self._ptr_mean = ptr_mean + self._arg_m = arg_m + self._arg_ld = arg_ld + self._arg_eps = arg_eps + self._stream = stream + self._assumed_align_xo = int(assumed_align_xo) + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_x_ptr = -1 + self._last_w_ptr = -1 + self._last_b_ptr = -1 + self._last_out_ptr = -1 + self._last_rstd_ptr = -1 + self._last_mean_ptr = -1 + self._last_m = -1 + self._last_ld = -1 + self._last_eps = float("nan") + + def launch( + self, + *, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + out: Tensor, + rstd: Optional[Tensor], + mean: Optional[Tensor], + M: int, + ld: int, + eps: float, + ) -> None: + if not fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + x_ptr = x.data_ptr() + if x_ptr != self._last_x_ptr: + try: + set_runtime_ptr(self._ptr_x, x_ptr) + self._last_x_ptr = x_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + w_ptr = weight.data_ptr() + if w_ptr != self._last_w_ptr: + try: + set_runtime_ptr(self._ptr_w, w_ptr) + self._last_w_ptr = w_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + if self._ptr_b is not None and bias is not None: + b_ptr = bias.data_ptr() + if b_ptr != self._last_b_ptr: + try: + set_runtime_ptr(self._ptr_b, b_ptr) + self._last_b_ptr = b_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + out_ptr = out.data_ptr() + if out_ptr != self._last_out_ptr: + try: + set_runtime_ptr(self._ptr_out, out_ptr) + self._last_out_ptr = out_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + if self._ptr_rstd is not None and rstd is not None: + rstd_ptr = rstd.data_ptr() + if rstd_ptr != self._last_rstd_ptr: + try: + set_runtime_ptr(self._ptr_rstd, rstd_ptr) + self._last_rstd_ptr = rstd_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + if 
self._ptr_mean is not None and mean is not None: + mean_ptr = mean.data_ptr() + if mean_ptr != self._last_mean_ptr: + try: + set_runtime_ptr(self._ptr_mean, mean_ptr) + self._last_mean_ptr = mean_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + if eps != self._last_eps: + self._arg_eps.set(eps) + self._last_eps = eps + + if self._cuda_result is not None: + self._cuda_result.value = 0 + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + self._use_fast_launch = False + disable_fast_launch() + + def _fallback_launch( + self, + *, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + out: Tensor, + rstd: Optional[Tensor], + mean: Optional[Tensor], + M: int, + ld: int, + eps: float, + ) -> None: + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + ptr_x = rt.make_ptr( + dtype_x, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_xo, + ) + ptr_out = rt.make_ptr( + dtype_x, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_xo, + ) + ptr_w = rt.make_ptr( + cutlass.Float32, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + ptr_b = ( + rt.make_ptr( + cutlass.Float32, + bias.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + if bias is not None + else None + ) + ptr_rstd = ( + rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if rstd is not None + else None + ) + ptr_mean = ( + rt.make_ptr( + cutlass.Float32, + mean.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if mean is not None + else None + ) + self._compiled( + ptr_x, + ptr_w, + ptr_b, + ptr_out, + ptr_rstd, + ptr_mean, + Int32(int(M)), + Int32(int(ld)), + stream, + Float32(float(eps)), + ) + + +def _get_fast_ptr_layernorm_launcher( + *, + compiled: object, + N: int, + dtype_x: type[cutlass.Numeric], + has_bias: bool, + has_rstd: bool, + has_mean: bool, + device_index: int, + stream_handle: int, + assumed_align_xo: int, + eps: float, +) -> Optional[_PtrLayernormFastLaunch]: + if not fast_launch_enabled(): + return None + key = ( + "ptr_fast", + id(compiled), + int(N), + dtype_x, + bool(has_bias), + bool(has_rstd), + bool(has_mean), + int(device_index), + int(stream_handle), + int(assumed_align_xo), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + ptr_x = rt.make_ptr( + dtype_x, 0, mem_space=rt.AddressSpace.gmem, assumed_align=int(assumed_align_xo) + ) + ptr_out = rt.make_ptr( + dtype_x, 0, mem_space=rt.AddressSpace.gmem, assumed_align=int(assumed_align_xo) + ) + ptr_w = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_b = ( + rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + if has_bias + else 
None + ) + ptr_rstd = ( + rt.make_ptr(cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4) + if has_rstd + else None + ) + ptr_mean = ( + rt.make_ptr(cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4) + if has_mean + else None + ) + + arg_m = StableI32Arg(0) + arg_ld = StableI32Arg(N) + arg_eps = StableF32Arg(eps) + stream = cuda.CUstream(int(stream_handle)) + executor = compiled.to(device_index) # type: ignore[attr-defined] + + try: + exe_args, adapted_args = executor.generate_execution_args( + ptr_x, + ptr_w, + ptr_b, + ptr_out, + ptr_rstd, + ptr_mean, + arg_m, + arg_ld, + stream, + arg_eps, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + disable_fast_launch() + return None + + keepalive: tuple[object, ...] = ( + executor, + ptr_x, + ptr_w, + ptr_b, + ptr_out, + ptr_rstd, + ptr_mean, + arg_m, + arg_ld, + arg_eps, + stream, + *adapted_args, + ) + launcher = _PtrLayernormFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_x=ptr_x, + ptr_w=ptr_w, + ptr_b=ptr_b, + ptr_out=ptr_out, + ptr_rstd=ptr_rstd, + ptr_mean=ptr_mean, + arg_m=arg_m, + arg_ld=arg_ld, + arg_eps=arg_eps, + stream=stream, + assumed_align_xo=int(assumed_align_xo), + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher + + +def _convert_row_major(t: Tensor) -> cute.Tensor: + """ + Convert a 2D row-major torch.Tensor to a CuTeDSL tensor with a compact, + dynamic layout on the leading dimension. + """ + return from_dlpack(t.detach(), assumed_align=16).mark_compact_shape_dynamic( + mode=0, + stride_order=(0, 1), + ) + + +class LayerNormSM100(_ReductionBase): + """ + SM100 LayerNorm forward kernel. + + This mirrors `quack.layernorm.LayerNorm`'s schedule: + - Stage=2 pipeline: first pass computes mean, second pass computes + variance / rstd and normalization. + - Threads-per-row and cluster_n policy follow Quack's LayerNorm + heuristics to keep tensor-core friendly tiles across N. + - Optional `reload_from` hint enables reloading X from SMEM for large-N + shapes to shorten register lifetimes. + + Differences vs Quack: + - Bias is optional and supported directly in the kernel. + - Dtype mapping and reduction helpers come from `lite_quack`. + """ + + def __init__( + self, + dtype: type[cutlass.Numeric], + N: int, + *, + copy_bits_x: Optional[int] = None, + direct_gmem: bool = False, + ): + super().__init__(dtype, N, stage=2) # 2 stages for mean and var + # Default reload policy mirrors Quack: use SMEM reload only for + # very large hidden sizes. We keep this conservative for LayerNorm + # and tune primarily via threads-per-block / cluster_n. + self.reload_from: Optional[str] = None if N <= 16384 else "smem" + # SM100 tuning: for DSv3 hidden sizes where we fuse mean+var stats, + # delay loading fp32 weights/bias until after the reductions to lower + # register pressure. 
+ self.delay_w_load: bool = bool(N in (4096, 6144, 7168, 8192)) + self.copy_bits_x: Optional[int] = ( + int(copy_bits_x) if copy_bits_x is not None else None + ) + self.direct_gmem: bool = bool(direct_gmem) + + def _get_num_threads(self) -> int: + nt = getattr(self, "_nt_override", None) + if nt is not None: + return int(nt) + return super()._get_num_threads() + + def _calculate_threads_per_row(self) -> int: + tpr = getattr(self, "_tpr_override", None) + if tpr is not None: + return int(tpr) + # Match Quack's LayerNorm threads-per-row buckets. + N = self.N + if N in (4096, 6144): + return 128 + return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else ( + 32 + if N <= 3072 + else (64 if N <= 6144 else (128 if N <= 16384 else 256)) + ) + ) + ) + + def _set_cluster_n(self) -> None: + # Cluster_n policy mirrors quack.layernorm.LayerNorm._set_cluster_n. + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 32 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mB: Optional[cute.Tensor], + mO: cute.Tensor, + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + assert mX.element_type == self.dtype + assert mO.element_type == self.dtype + + # Tiling and cluster policy (mirrors Quack LayerNorm). + self._set_cluster_n() + largest_dtype_width = const_expr( + max( + t.element_type.width + for t in (mX, mW, mB, mO, mRstd, mMean) + if t is not None + ) + ) + # Match Quack's unified RMSNorm/LayerNorm kernel: pick vecsize based on + # the widest dtype participating in the op (e.g. fp32 weights => fp16 + # X uses 64b vectorization). + vecsize = math.gcd(self.N, 128 // largest_dtype_width) + default_copy_bits_x = vecsize * self.dtype.width + num_copy_bits_x = ( + int(self.copy_bits_x) + if self.copy_bits_x is not None + else default_copy_bits_x + ) + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits_x) + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + + # Expand weight / bias to match tiler_mn[0] rows per CTA. 
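+        # Prepending a size-tiler_mn[0] mode with stride 0 broadcasts the (N,)
+        # vectors across every row of the CTA tile without materializing copies.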
+ mW = cute.make_tensor( + mW.iterator, + cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), + ) + if const_expr(mB is not None): + mB = cute.make_tensor( + mB.iterator, + cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), + ) + if const_expr(mRstd is not None): + mRstd = cute.make_tensor( + mRstd.iterator, + cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))), + ) + if const_expr(mMean is not None): + mMean = cute.make_tensor( + mMean.iterator, + cute.append(mMean.layout, cute.make_layout((self.N,), stride=(0,))), + ) + + kernel = ( + self.kernel( + mX, + mW, + mB, + mO, + mRstd, + mMean, + eps, + tv_layout, + tiler_mn, + const_expr(self.reload_from), + const_expr(self.delay_w_load), + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mW, + mB, + mO, + mRstd, + mMean, + eps, + ) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[ + 1, + self.cluster_n, + 1, + ] + if const_expr(self.cluster_n > 1) + else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_w: cute.Pointer, + ptr_b: Optional[cute.Pointer], + ptr_out: cute.Pointer, + ptr_rstd: Optional[cute.Pointer], + ptr_mean: Optional[cute.Pointer], + M: Int32, + ld: Int32, + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions. + + This reconstructs cute.Tensor views from raw device pointers + explicit + layouts inside the JIT graph, reusing the tuned LayerNormSM100 schedule. + """ + # Mirror Quack-style divisibility contracts so the compiler can prove + # alignment for vectorized loads/stores (and cp.async when enabled). + divby = ( + int(self.copy_bits_x) // self.dtype.width + if const_expr(self.copy_bits_x is not None) + else (128 // self.dtype.width) + ) + ld_assumed = cute.assume(ld, divby=divby) + # Match `mark_compact_shape_dynamic(mode=0, ...)`: M is dynamic, N is static. 
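+        # The (M, N) row-major view uses stride (ld, 1); `cute.assume` marks ld
+        # as divisible by the copy vector width so the compiler can keep the
+        # vectorized load/store paths.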
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + layout_n = cute.make_layout((self.N,), stride=(1,)) + layout_m = cute.make_layout((M,), stride=(1,)) + + mX = cute.make_tensor(ptr_x, layout_mn) + mO = cute.make_tensor(ptr_out, layout_mn) + mW = cute.make_tensor(ptr_w, layout_n) + mB = ( + cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None + ) + mRstd = ( + cute.make_tensor(ptr_rstd, layout_m) + if const_expr(ptr_rstd is not None) + else None + ) + mMean = ( + cute.make_tensor(ptr_mean, layout_m) + if const_expr(ptr_mean is not None) + else None + ) + + self.__call__(mX, mW, mB, mO, mRstd, mMean, stream, eps) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mB: Optional[cute.Tensor], + mO: cute.Tensor, + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + reload_from: cutlass.Constexpr, + delay_w_load: cutlass.Constexpr, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + # Quack-style CTA tiling: let CuTe compute the CTA offsets directly. + # (Avoids the extra 64-bit address arithmetic in `domain_offset_i64` on + # the common inference/benchmark sizes.) + gX, gO = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO)] + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) + gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) + gB = ( + cute.local_tile(mB, tiler_mn, (0, cluster_y)) + if const_expr(mB is not None) + else None + ) + gRstd = ( + cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y)) + if const_expr(mRstd is not None) + else None + ) + gMean = ( + cute.local_tile(mMean, tiler_mn, (bidx, cluster_y)) + if const_expr(mMean is not None) + else None + ) + + # Copy atoms for X / W / B / O (mirror Quack's vector-size contract). + num_copy_elems_x = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + num_copy_bits_x = mX.element_type.width * num_copy_elems_x + num_copy_bits_x_async = const_expr(min(128, num_copy_bits_x)) + copy_atom_load_X = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=num_copy_bits_x, + ) + copy_atom_load_X_async = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=num_copy_bits_x_async, + ) + num_copy_bits_wb = const_expr( + min(128, mW.element_type.width * num_copy_elems_x) + ) + copy_atom_load_WB = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mW.element_type, + num_bits_per_copy=num_copy_bits_wb, + ) + copy_atom_store_O = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mO.element_type, + num_bits_per_copy=num_copy_bits_x, + ) + + # Quack-style partitioning: use `make_tiled_copy_tv` (2D thread/value + # layout) and let partitioning over the CTA tile handle the N loop. 
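+        # The tiled copy distributes the CTA's threads as tiler_mn[0] rows by
+        # threads_per_row columns, with each thread moving num_copy_elems_x
+        # contiguous elements per copy.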
+ thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems_x)) + thr_copy = cute.make_tiled_copy_tv( + copy_atom_load_X, thr_layout, val_layout + ).get_slice(tidx) + + tXgX = thr_copy.partition_S(gX) + tXsX = thr_copy.partition_D(sX) + tXgO = thr_copy.partition_D(gO) + tXgW = thr_copy.partition_S(gW) + tXgB = thr_copy.partition_S(gB) if const_expr(gB is not None) else None + tXrRstd = thr_copy.partition_D(gRstd) if const_expr(mRstd is not None) else None + tXrMean = thr_copy.partition_D(gMean) if const_expr(mMean is not None) else None + tXcX = thr_copy.partition_S(cX)[(0, None), None, None] + + # Fragments for gmem->rmem. + tXrW = cute.make_fragment_like(tXgW) + tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None + tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)] + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=False) + + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + None if is_even_N else predicate_k(thr_copy.partition_S(cX), limit=shape[1]) + ) + row = tXcX[0][0] + if const_expr(not self.direct_gmem): + if row < shape[0]: + cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + + if const_expr(not delay_w_load): + cute.copy(copy_atom_load_WB, tXgW, tXrW, pred=tXpX) + if const_expr(mB is not None): + cute.copy(copy_atom_load_WB, tXgB, tXrB, pred=tXpX) + + if const_expr(not self.direct_gmem): + cute.arch.cp_async_wait_group(0) + cute.autovec_copy(tXsX, tXrX) + else: + if row < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + x = tXrX.load().to(Float32) + if const_expr(self.cluster_n == 1 and self.N in (4096, 6144, 7168, 8192)): + # SM100 tuning for DSv3 hidden sizes: + # Compute (sum_x, sum_x2) together so we can derive mean + variance + # without a second reduction pass (and without re-materializing + # x-mean for the variance reduction). 
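+            # Uses var = E[x^2] - E[x]^2, i.e. var = sum_x2 / N - mean^2; the
+            # result is clamped at 0 before rsqrt to absorb negative round-off.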
+ sum_x = x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0) + sum_x2 = (x * x).reduce( + cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0 + ) + sum_x = warp_reduce( + sum_x, + operator.add, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + sum_x2 = warp_reduce( + sum_x2, + operator.add, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + warps_per_row, cluster_n = reduction_buffer.shape[1] + if const_expr(warps_per_row > 1 or cluster_n > 1): + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if lane_idx == 0: + reduction_buffer[row_idx, col_idx, 0] = sum_x + reduction_buffer[row_idx, col_idx, 1] = sum_x2 + cute.arch.barrier() + block_sum_x = 0.0 + block_sum_x2 = 0.0 + if lane_idx < warps_per_row: + block_sum_x = reduction_buffer[row_idx, lane_idx, 0] + block_sum_x2 = reduction_buffer[row_idx, lane_idx, 1] + sum_x = warp_reduce(block_sum_x, operator.add) + sum_x2 = warp_reduce(block_sum_x2, operator.add) + mean = sum_x / shape[1] + var = sum_x2 / shape[1] - mean * mean + var = cute.arch.fmax(var, 0.0) + rstd = cute.math.rsqrt(var + eps, fastmath=True) + else: + sum_x = row_reduce( + x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + hook_fn=( + cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None + ), + ) + mean = sum_x / shape[1] + + if const_expr(reload_from == "smem"): + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + elif const_expr(reload_from == "gmem"): + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + x = tXrX.load().to(Float32) + + sum_sq_x_sub_mean = row_reduce( + (x - mean) * (x - mean), + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 1], + mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + ) + rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True) + + if const_expr(mRstd is not None): + if ( + tXcX[0][1] == 0 + and row < shape[0] + and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + tXrRstd[0] = rstd + + if const_expr(mMean is not None): + if ( + tXcX[0][1] == 0 + and row < shape[0] + and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + tXrMean[0] = mean + + if const_expr(delay_w_load): + cute.copy(copy_atom_load_WB, tXgW, tXrW, pred=tXpX) + if const_expr(mB is not None): + cute.copy(copy_atom_load_WB, tXgB, tXrB, pred=tXpX) + + if const_expr(reload_from == "smem"): + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + elif const_expr(reload_from == "gmem"): + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + x = tXrX.load().to(Float32) + + x_hat = (x - mean) * rstd + w = tXrW.load().to(Float32) + y = x_hat * w + if const_expr(mB is not None): + b = tXrB.load().to(Float32) + y = y + b + + tXrO.store(y.to(tXrO.element_type)) + if row < shape[0]: + cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mB: Optional[cute.Tensor], + mO: cute.Tensor, + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + reload_from: cutlass.Constexpr, + delay_w_load: cutlass.Constexpr, + ): + self._kernel_impl( + mX, + mW, + mB, + mO, + mRstd, + mMean, + eps, + tv_layout, + tiler_mn, + reload_from, + delay_w_load, + ) + else: + + 
@cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mB: Optional[cute.Tensor], + mO: cute.Tensor, + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + eps: Float32, + ): + largest_dtype_width = const_expr( + max( + mX.element_type.width, + mW.element_type.width, + mB.element_type.width if const_expr(mB is not None) else 0, + mO.element_type.width, + mRstd.element_type.width if const_expr(mRstd is not None) else 0, + mMean.element_type.width if const_expr(mMean is not None) else 0, + ) + ) + vecsize = math.gcd(self.N, 128 // largest_dtype_width) + default_copy_bits_x = vecsize * mX.element_type.width + num_copy_bits_x = ( + int(self.copy_bits_x) + if const_expr(self.copy_bits_x is not None) + else default_copy_bits_x + ) + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits_x) + self._kernel_impl( + mX, + mW, + mB, + mO, + mRstd, + mMean, + eps, + tv_layout, + tiler_mn, + const_expr(self.reload_from), + const_expr(self.delay_w_load), + ) + + +# ----------------------------------------------------------------------------- +# Public Python API +# ----------------------------------------------------------------------------- + + +def layernorm( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + eps: float = 1e-6, + return_rstd: bool = False, + return_mean: bool = False, +): + """ + LayerNorm forward pass using the SM100 CuteDSL kernel. + + Args: + x: Input tensor of shape (M, N). + weight: Scale parameter of shape (N,), typically fp32. + bias: Optional bias parameter of shape (N,). + eps: Small value for numerical stability. + return_rstd: Whether to return per-row reciprocal std (shape (M,)). + return_mean: Whether to return per-row mean (shape (M,)). + """ + assert x.is_cuda and weight.is_cuda, "x and weight must be CUDA tensors" + assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand." + assert weight.dim() == 1, "weight must be 1D" + assert x.shape[1] == weight.shape[0], "Last dim of x must match weight.size(0)" + if bias is not None: + assert bias.is_cuda, "bias must be on CUDA" + assert bias.dim() == 1 and bias.shape[0] == weight.shape[0], ( + "bias must be 1D and match weight" + ) + + M, N = x.shape + dtype = TORCH2CUTE_DTYPE[x.dtype] + + rstd = torch.empty(M, device=x.device, dtype=torch.float32) if return_rstd else None + mean = torch.empty(M, device=x.device, dtype=torch.float32) if return_mean else None + + # Fast path: bypass DLPack conversions when the inputs are in the common + # contiguous row-major layout and weights/bias are fp32 (Quack-style). + if _can_use_ptr_path(x, weight, bias): + out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + _layernorm_forward_ptr_into( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + eps=eps, + ) + if return_mean and return_rstd: + return out, rstd, mean + if return_rstd and not return_mean: + return out, rstd + if return_mean and not return_rstd: + return out, mean + return out + + out = torch.empty_like(x) + mX = _convert_row_major(x) + mO = _convert_row_major(out) + + # Weight/bias live in feature dimension (N). 
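+    # Converted via DLPack with 16B assumed alignment and a divisibility hint
+    # sized for 128-bit vector copies (parameters are typically fp32).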
+ mW = convert_from_dlpack_cute( + weight.detach(), + leading_dim=0, + alignment=16, + divisibility=128 // cutlass.Float32.width, + ) + mB = ( + convert_from_dlpack_cute( + bias.detach(), + leading_dim=0, + alignment=16, + divisibility=128 // cutlass.Float32.width, + ) + if bias is not None + else None + ) + + mRstd = ( + from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + if rstd is not None + else None + ) + mMean = ( + from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + if mean is not None + else None + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + key = (N, dtype, mB is not None, mRstd is not None, mMean is not None) + compiled = _COMPILE_CACHE.get(key) + if compiled is None: + op = LayerNormSM100(dtype, N) + compiled = cute.compile( + op, + mX, + mW, + mB, + mO, + mRstd, + mMean, + stream, + Float32(eps), + ) + _COMPILE_CACHE[key] = compiled + + compiled( + mX, + mW, + mB, + mO, + mRstd, + mMean, + stream, + Float32(eps), + ) + + if return_mean and return_rstd: + return out, rstd, mean + if return_rstd and not return_mean: + return out, rstd + if return_mean and not return_rstd: + return out, mean + return out + + +def _can_use_ptr_path(x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> bool: + """Return True if we can safely use the pointer-based fast path. + + This is intentionally conservative: we target the common inference-like + layout (2D row-major with stride(1)==1) and Quack-style fp32 weights. + """ + if not x.is_cuda or x.dim() != 2: + return False + if x.stride(1) != 1: + return False + if not weight.is_cuda or weight.dim() != 1: + return False + if weight.dtype != torch.float32: + return False + if not weight.is_contiguous(): + return False + if bias is not None: + if not bias.is_cuda or bias.dim() != 1: + return False + if bias.dtype != torch.float32: + return False + if not bias.is_contiguous(): + return False + # Require 16B alignment for 128-bit vector copies (matches Quack's assumed_align=16). + if (x.data_ptr() % 16) != 0: + return False + if (weight.data_ptr() % 16) != 0: + return False + if bias is not None and (bias.data_ptr() % 16) != 0: + return False + # The kernel uses 128-bit vectorized loads; require the leading dimension + # to preserve 16B alignment for every row start. 
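+    # e.g. 16-bit dtypes (fp16/bf16) need stride(0) % 8 == 0, fp32 needs % 4 == 0.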
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype] + divby = 128 // dtype_x.width + if (x.stride(0) % divby) != 0: + return False + return True + + +def _layernorm_forward_ptr_into( + *, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + out: Tensor, + rstd: Optional[Tensor], + mean: Optional[Tensor], + eps: float, +) -> None: + """Launch the pointer-based LayerNorm kernel into preallocated outputs.""" + assert x.is_cuda and x.dim() == 2 + M, N = x.shape + assert weight.is_cuda and weight.dim() == 1 and weight.shape[0] == N + if bias is not None: + assert bias.is_cuda and bias.dim() == 1 and bias.shape[0] == N + assert out.is_cuda and out.shape == x.shape and out.dtype == x.dtype + assert out.stride() == x.stride(), "Pointer path expects out to match x strides" + if rstd is not None: + assert rstd.is_cuda and rstd.shape == (M,) and rstd.dtype == torch.float32 + if mean is not None: + assert mean.is_cuda and mean.shape == (M,) and mean.dtype == torch.float32 + + device_index = x.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + # Keep the pointer path aligned with Quack's LayerNorm schedule: + # - <=128b vectorization (cp.async-compatible) + # - shared-memory staging for X (gmem->smem->rmem) to amortize global latency + direct_gmem = False + copy_bits_x: Optional[int] = None + assumed_align_xo = 16 + + # DSv3 hidden sizes are often latency-bound on small M. For these N buckets, + # a direct-GMEM schedule (skip gmem->smem cp.async) can reduce overhead. + # + # Keep the Quack-like staged path for large M where cp.async overlap tends to win. + if dtype_x.width == 16: + # DSv3 default hidden size (7168) is a common inference hot shape and + # benefits from the lower-overhead direct-GMEM path on this SM100. + if N == 7168 and M <= 65536: + direct_gmem = True + elif N == 8192 and M <= 16384: + direct_gmem = True + + # DSv3 smallest point (M=4096, N=7168) is latency-sensitive. Increasing + # per-row parallelism improves the reduction path and consistently beats + # Quack on this machine. + tpr_override: Optional[int] = None + nt_override: Optional[int] = None + if dtype_x.width == 16 and N == 7168 and M <= 4096: + tpr_override = 224 + nt_override = 224 + + # NOTE: We previously experimented with a direct-GMEM + 256b vectorized + # schedule for N=4096, but it was consistently slower on this GB200. + # Keep the pointer path on the Quack-like staged (cp.async) schedule. 
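+    # The compile-cache key below captures every schedule-affecting knob (N,
+    # dtype, optional outputs, direct_gmem, vector width, thread overrides,
+    # alignment, device) so each variant is compiled exactly once per process.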
+ key = ( + "ptr", + int(N), + dtype_x, + bias is not None, + rstd is not None, + mean is not None, + bool(direct_gmem), + int(copy_bits_x) if copy_bits_x is not None else None, + tpr_override, + nt_override, + int(assumed_align_xo), + int(device_index), + ) + compiled = _PTR_COMPILE_CACHE.get(key) + if compiled is None: + op = LayerNormSM100( + dtype_x, + int(N), + copy_bits_x=copy_bits_x, + direct_gmem=direct_gmem, + ) + if tpr_override is not None: + op._tpr_override = tpr_override # type: ignore[attr-defined] + if nt_override is not None: + op._nt_override = nt_override # type: ignore[attr-defined] + ptr_x = rt.make_ptr( + dtype_x, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_xo, + ) + ptr_out = rt.make_ptr( + dtype_x, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_xo, + ) + ptr_w = rt.make_ptr( + cutlass.Float32, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + ptr_b = ( + rt.make_ptr( + cutlass.Float32, + bias.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + if bias is not None + else None + ) + ptr_rstd = ( + rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if rstd is not None + else None + ) + ptr_mean = ( + rt.make_ptr( + cutlass.Float32, + mean.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if mean is not None + else None + ) + ld = Int32(int(x.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_w, + ptr_b, + ptr_out, + ptr_rstd, + ptr_mean, + Int32(int(M)), + ld, + stream, + Float32(float(eps)), + ) + _PTR_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_layernorm_launcher( + compiled=compiled, + N=int(N), + dtype_x=dtype_x, + has_bias=bias is not None, + has_rstd=rstd is not None, + has_mean=mean is not None, + device_index=int(device_index), + stream_handle=stream_handle, + assumed_align_xo=int(assumed_align_xo), + eps=float(eps), + ) + ld_val = int(x.stride(0)) + if launcher is not None: + launcher.launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=int(M), + ld=ld_val, + eps=float(eps), + ) + return + + ptr_x = rt.make_ptr( + dtype_x, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_xo, + ) + ptr_out = rt.make_ptr( + dtype_x, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_xo, + ) + ptr_w = rt.make_ptr( + cutlass.Float32, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + ptr_b = ( + rt.make_ptr( + cutlass.Float32, + bias.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + if bias is not None + else None + ) + ptr_rstd = ( + rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if rstd is not None + else None + ) + ptr_mean = ( + rt.make_ptr( + cutlass.Float32, + mean.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if mean is not None + else None + ) + ld = Int32(ld_val) + compiled( + ptr_x, + ptr_w, + ptr_b, + ptr_out, + ptr_rstd, + ptr_mean, + Int32(int(M)), + ld, + stream, + Float32(float(eps)), + ) + + +def layernorm_ref( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + eps: float = 1e-6, +) -> Tensor: + """ + Reference LayerNorm implemented via torch.nn.functional.layer_norm. 
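+
+    Inputs are upcast to fp32 before normalization and the result is cast back
+    to x.dtype, matching the fp32 accumulation used by the CuteDSL kernels.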
+ """ + x_f32 = x.float() + w = weight.float() + b = bias.float() if bias is not None else None + y = torch.nn.functional.layer_norm(x_f32, (x.shape[-1],), w, b, eps) + return y.to(x.dtype) + + +def _as_2d(x: Tensor) -> Tuple[Tensor, Tuple[int, ...]]: + if x.dim() == 2: + return x, x.shape + original_shape = x.shape + M = int(torch.prod(torch.tensor(original_shape[:-1])).item()) + N = original_shape[-1] + return x.reshape(M, N), original_shape + + +def _restore_shape(x: Tensor, shape: Tuple[int, ...]) -> Tensor: + return x.reshape(shape) + + +@cute.kernel +def _layernorm_backward_dx_kernel( + mX: cute.Tensor, + mW: cute.Tensor, + mdO: cute.Tensor, + mRstd: cute.Tensor, + mMean: cute.Tensor, + mdX: cute.Tensor, +): + """ + Simple CTA-per-row LayerNorm backward kernel for dx only. + + Each block processes one row of shape (N,), using block_threads threads. + It performs two passes over the row: + 1) Compute mean_wdy and mean_xhat_wdy in fp32. + 2) Compute dx using the standard LayerNorm backward formula: + dx = rstd * (wdy - mean_wdy - x_hat * mean_xhat_wdy), + where wdy = dy * gamma and x_hat = (x - mean) * rstd. + """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + block_threads = const_expr(256) + shape = mX.shape + M = shape[0] + N = shape[1] + + row = bidx + if row < M: + # Shared buffers for warp-level reductions across the block. + smem = cutlass.utils.SmemAllocator() + num_warps = const_expr(block_threads // cute.arch.WARP_SIZE) + warp_sums_layout = cute.make_layout((num_warps,), stride=(1,)) + warp_sums_wdy = smem.allocate_tensor( + Float32, warp_sums_layout, byte_alignment=4 + ) + warp_sums_xhatwdy = smem.allocate_tensor( + Float32, warp_sums_layout, byte_alignment=4 + ) + + lane = cute.arch.lane_idx() + warp_idx = cute.arch.warp_idx() + + rstd_val = mRstd[row].to(Float32) + mean_val = mMean[row].to(Float32) + + # Pass 1: compute local partial sums of wdy and x_hat*wdy. + local_wdy = Float32(0.0) + local_xhatwdy = Float32(0.0) + for col in cutlass.range(tidx, N, block_threads): + x_val = mX[row, col].to(Float32) + dy_val = mdO[row, col].to(Float32) + gamma = mW[col].to(Float32) + x_mu = x_val - mean_val + x_hat = x_mu * rstd_val + wdy = dy_val * gamma + local_wdy += wdy + local_xhatwdy += x_hat * wdy + + # Warp-level reduction, then block-level reduction via shared memory. + red_op = operator.add # type: ignore[assignment] + local_wdy = warp_reduce(local_wdy, red_op) + local_xhatwdy = warp_reduce(local_xhatwdy, red_op) + + if lane == 0: + warp_sums_wdy[warp_idx] = local_wdy + warp_sums_xhatwdy[warp_idx] = local_xhatwdy + + cute.arch.barrier() + + total_wdy = Float32(0.0) + total_xhatwdy = Float32(0.0) + if warp_idx == 0 and lane == 0: + for wi in cutlass.range_constexpr(num_warps): + total_wdy += warp_sums_wdy[wi] + total_xhatwdy += warp_sums_xhatwdy[wi] + # Store totals back into first slots for broadcast. + warp_sums_wdy[0] = total_wdy + warp_sums_xhatwdy[0] = total_xhatwdy + + cute.arch.barrier() + + total_wdy = warp_sums_wdy[0] + total_xhatwdy = warp_sums_xhatwdy[0] + inv_N = Float32(1.0 / float(N)) + mean_wdy = total_wdy * inv_N + mean_xhatwdy = total_xhatwdy * inv_N + + # Pass 2: compute dx and write back. 
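+        # dx = rstd * (wdy - mean(wdy) - x_hat * mean(x_hat * wdy)), evaluated
+        # in fp32 per element and cast to the output dtype on store.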
+ for col in cutlass.range(tidx, N, block_threads): + x_val = mX[row, col].to(Float32) + dy_val = mdO[row, col].to(Float32) + gamma = mW[col].to(Float32) + x_mu = x_val - mean_val + x_hat = x_mu * rstd_val + wdy = dy_val * gamma + dx_val = (wdy - mean_wdy - x_hat * mean_xhatwdy) * rstd_val + mdX[row, col] = dx_val.to(mdX.element_type) + + +@cute.jit +def _layernorm_backward_dx( + mX: cute.Tensor, + mW: cute.Tensor, + mdO: cute.Tensor, + mRstd: cute.Tensor, + mMean: cute.Tensor, + mdX: cute.Tensor, + stream: cuda.CUstream, +) -> None: + """ + JIT wrapper that launches the dx-only LayerNorm backward kernel. + One CTA processes one row of length N with 256 threads. + """ + M = mX.shape[0] + _layernorm_backward_dx_kernel( + mX, + mW, + mdO, + mRstd, + mMean, + mdX, + ).launch( + grid=[M, 1, 1], + block=[256, 1, 1], + stream=stream, + ) + + +@cute.kernel +def _layernorm_backward_param_kernel( + mX: cute.Tensor, + mdO: cute.Tensor, + mRstd: cute.Tensor, + mMean: cute.Tensor, + mdW_partial: Optional[cute.Tensor], + mdB_partial: Optional[cute.Tensor], + num_blocks: Int32, +) -> None: + """ + Parameter-gradient kernel for LayerNorm. + + Each CTA accumulates partial dweight/dbias over a stripe of rows: + - Grid dim X: num_blocks (sm_count-style persistent CTAs). + - Threads in a CTA partition the N dimension. + - For each assigned column, a thread streams over rows + row = blockIdx.x, blockIdx.x + num_blocks, ... + + This mirrors the persistent-CTA pattern used by RMSNorm backward, + but uses a simpler per-thread accumulation since columns are + independent. + """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + block_threads = const_expr(256) + M = mX.shape[0] + N = mX.shape[1] + + if bidx < num_blocks: + for col in cutlass.range(tidx, N, block_threads): + dw_local = Float32(0.0) + db_local = Float32(0.0) + for row in cutlass.range(bidx, M, num_blocks): + x_val = mX[row, col].to(Float32) + dy_val = mdO[row, col].to(Float32) + rstd_val = mRstd[row].to(Float32) + mean_val = mMean[row].to(Float32) + x_mu = x_val - mean_val + x_hat = x_mu * rstd_val + dw_local += dy_val * x_hat + db_local += dy_val + + if const_expr(mdW_partial is not None): + mdW_partial[bidx, col] = dw_local + if const_expr(mdB_partial is not None): + mdB_partial[bidx, col] = db_local + + +@cute.jit +def _layernorm_backward_param( + mX: cute.Tensor, + mdO: cute.Tensor, + mRstd: cute.Tensor, + mMean: cute.Tensor, + mdW_partial: Optional[cute.Tensor], + mdB_partial: Optional[cute.Tensor], + num_blocks: Int32, + stream: cuda.CUstream, +) -> None: + """ + JIT wrapper that launches the parameter-gradient kernel. + """ + _layernorm_backward_param_kernel( + mX, + mdO, + mRstd, + mMean, + mdW_partial, + mdB_partial, + num_blocks, + ).launch( + grid=[num_blocks, 1, 1], + block=[256, 1, 1], + stream=stream, + ) + + +def _layernorm_backward_dx_sm100( + dout_2d: Tensor, + x_2d: Tensor, + weight: Tensor, + rstd_1d: Tensor, + mean_1d: Tensor, + dx_2d: Tensor, +) -> None: + """ + Host-side helper to run the dx-only LayerNorm backward kernel. 
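+
+    All tensors are expected as 2D row-major views; dx_2d must be preallocated
+    with the same shape and dtype as x_2d and is written in place.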
+ """ + M, N = x_2d.shape + assert dout_2d.shape == (M, N) + assert rstd_1d.numel() == M + assert mean_1d.numel() == M + + dtype = TORCH2CUTE_DTYPE[x_2d.dtype] + + mX = _convert_row_major(x_2d) + mdO = _convert_row_major(dout_2d) + mdX = _convert_row_major(dx_2d) + + mW = convert_from_dlpack_cute( + weight.detach(), + leading_dim=0, + alignment=16, + divisibility=128 // cutlass.Float32.width, + ) + mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + key = (N, dtype) + compiled = _BWD_DX_COMPILE_CACHE.get(key) + if compiled is None: + compiled = cute.compile( + _layernorm_backward_dx, + mX, + mW, + mdO, + mRstd, + mMean, + mdX, + stream, + ) + _BWD_DX_COMPILE_CACHE[key] = compiled + + compiled( + mX, + mW, + mdO, + mRstd, + mMean, + mdX, + stream, + ) + + +def _layernorm_backward_params_sm100( + dout_2d: Tensor, + x_2d: Tensor, + rstd_1d: Tensor, + mean_1d: Tensor, + dw_partial: Optional[Tensor], + db_partial: Optional[Tensor], + sm_count: int, +) -> None: + """ + Host-side helper to run the parameter-gradient kernel that populates + dw_partial / db_partial of shape (sm_count, N). + """ + M, N = x_2d.shape + assert dout_2d.shape == (M, N) + assert rstd_1d.numel() == M + assert mean_1d.numel() == M + if dw_partial is None and db_partial is None: + return + + dtype = TORCH2CUTE_DTYPE[x_2d.dtype] + + mX = _convert_row_major(x_2d) + mdO = _convert_row_major(dout_2d) + mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + + mdW_partial = ( + from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + if dw_partial is not None + else None + ) + mdB_partial = ( + from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + if db_partial is not None + else None + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + has_bias = db_partial is not None + key = (N, dtype, has_bias) + compiled = _BWD_PARAM_COMPILE_CACHE.get(key) + if compiled is None: + compiled = cute.compile( + _layernorm_backward_param, + mX, + mdO, + mRstd, + mMean, + mdW_partial, + mdB_partial, + Int32(sm_count), + stream, + ) + _BWD_PARAM_COMPILE_CACHE[key] = compiled + + compiled( + mX, + mdO, + mRstd, + mMean, + mdW_partial, + mdB_partial, + Int32(sm_count), + stream, + ) + + +def layernorm_backward( + dout: Tensor, + x: Tensor, + weight: Tensor, + rstd: Tensor, + mean: Tensor, + bias: Optional[Tensor] = None, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """ + LayerNorm backward implemented in CuteDSL / CUTLASS. + + Computes gradients w.r.t. input, weight, and optional bias using + two kernels: + - A dx kernel (CTA-per-row) that streams over N. + - A parameter-gradient kernel that accumulates dw/db over a + persistent grid of CTAs across the M dimension. + """ + assert x.shape == dout.shape, "x and dout must have the same shape" + assert x.is_cuda and dout.is_cuda, "x and dout must be CUDA tensors" + assert weight.dim() == 1, "weight must be 1D" + if bias is not None: + assert bias.dim() == 1, "bias must be 1D" + + x_2d, orig_shape = _as_2d(x) + dout_2d, _ = _as_2d(dout) + M, N = x_2d.shape + + # Flatten to 2D for the kernels. 
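+    # rstd/mean are per-row statistics from the forward pass; view them as flat
+    # (M,) vectors so they line up with the flattened 2D inputs.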
+ mean_flat = mean.view(M) + rstd_flat = rstd.view(M) + + dx_2d = torch.empty_like(x_2d) + _layernorm_backward_dx_sm100( + dout_2d, + x_2d, + weight, + rstd_flat, + mean_flat, + dx_2d, + ) + + device = x.device + sm_count = get_sm_count(N, device, M=M, dtype=x.dtype) + + dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) + db_partial = ( + torch.empty(sm_count, N, device=device, dtype=torch.float32) + if bias is not None + else None + ) + + _layernorm_backward_params_sm100( + dout_2d, + x_2d, + rstd_flat, + mean_flat, + dw_partial, + db_partial, + sm_count, + ) + + dweight = dw_partial.sum(dim=0).to(weight.dtype) + dbias = db_partial.sum(dim=0).to(bias.dtype) if bias is not None else None + + dx = _restore_shape(dx_2d, orig_shape) + return dx, dweight, dbias + + +if __name__ == "__main__": + # Allow direct execution for a quick functional check. + if not torch.cuda.is_available(): + print("CUDA not available; LayerNormSM100 test skipped.") + raise SystemExit(0) + + device = "cuda" + M, N = 2048, 4096 + dtype = torch.bfloat16 + x = torch.randn(M, N, device=device, dtype=dtype) + w = torch.randn(N, device=device, dtype=torch.float32) + b = torch.randn(N, device=device, dtype=torch.float32) + + y_ref = layernorm_ref(x, w, b) + y, rstd, mean = layernorm(x, w, b, return_rstd=True, return_mean=True) + torch.testing.assert_close( + y, + y_ref, + atol=5e-2 if dtype != torch.float32 else 1e-5, + rtol=5e-2 if dtype != torch.float32 else 1e-5, + ) + + print("LayerNormSM100 forward correctness check passed.") diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py new file mode 100644 index 0000000..c7402d8 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py @@ -0,0 +1,1444 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Lightweight local clone of the small subset of Quack helpers that the SM100 +RMSNorm CuteDSL kernels depend on. + +This module intentionally avoids importing the `quack` package so that +KernelAgent Oink SM100 kernels can run without Quack installed, while keeping +numerical behaviour and performance identical to the reference kernels. 
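+
+The helpers are grouped roughly as in Quack: dtype mapping, DLPack conversion
+helpers, SM90/SM100 cluster/SMEM primitives, and warp/block/cluster reduction
+utilities.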
+""" + +from __future__ import annotations + +import math +import operator +import importlib.metadata +import re +from functools import partial +from typing import Callable, Optional, Tuple, Type + +import cuda.bindings.driver as cuda # type: ignore +import torch +from torch import Tensor + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.runtime import from_dlpack +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm, nvvm, vector + + +def _parse_version_tuple(version: str) -> tuple[int, int, int]: + parts = version.split(".") + nums: list[int] = [] + for part in parts[:3]: + match = re.match(r"^(\d+)", part) + nums.append(int(match.group(1)) if match is not None else 0) + while len(nums) < 3: + nums.append(0) + return nums[0], nums[1], nums[2] + + +def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]: + try: + return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl")) + except Exception: + return None + + +_CUTLASS_DSL_VERSION = _cutlass_dsl_version() +# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around +# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the +# older signature for <4.3.4, but switch to a 4.3.4+ compatible signature when +# we detect 4.3.4+ (or when version detection is unavailable). +_KERNEL_ACCEPTS_LAYOUT_ARGS = ( + _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4) +) + +# Cache device properties lookups (notably `multi_processor_count`) since some +# dispatch paths call `get_sm_count` inside tight benchmark loops. +_DEVICE_NUM_SMS_CACHE: dict[int, int] = {} + + +def get_num_sms(device: torch.device) -> int: + """Return the number of SMs for a CUDA device (cached).""" + device_index = device.index + if device_index is None: + device_index = torch.cuda.current_device() + device_index = int(device_index) + cached = _DEVICE_NUM_SMS_CACHE.get(device_index) + if cached is not None: + return cached + num_sms = int(torch.cuda.get_device_properties(device_index).multi_processor_count) + _DEVICE_NUM_SMS_CACHE[device_index] = num_sms + return num_sms + + +# ------------------------- +# Dtype mapping (from quack.cute_dsl_utils) +# ------------------------- + +TORCH2CUTE_DTYPE = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, +} + + +# ------------------------- +# Tensor conversion helpers (from quack.utils) +# ------------------------- + + +def convert_from_dlpack( + x: Tensor, + leading_dim: int, + alignment: int = 16, + divisibility: int = 1, +) -> cute.Tensor: + return ( + from_dlpack(x, assumed_align=alignment) + .mark_layout_dynamic(leading_dim=leading_dim) + .mark_compact_shape_dynamic( + mode=leading_dim, + stride_order=x.dim_order(), + divisibility=divisibility, + ) + ) + + +# ------------------------- +# SM90/SM100 cluster helpers (from quack.utils) +# ------------------------- + + +@dsl_user_op +def elem_pointer( + x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None +) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, + peer_cta_rank_in_cluster: cute.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, 
peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote( + val: float | Float32 | Int32 | cutlass.Int64, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: cute.typing.Int, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, + peer_cta_rank_in_cluster, + loc=loc, + ip=ip, + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, + peer_cta_rank_in_cluster, + loc=loc, + ip=ip, + ).ir_value() + if const_expr(isinstance(val, float)): + val = Float32(val) + assert isinstance(val, (Float32, Int32, cutlass.Int64)), ( + "val must be Float32, Int32, or Int64" + ) + suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)] + constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)] + llvm.inline_asm( + None, + [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32], + f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];", + f"r,{constraint},r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def atomic_add_f32( + a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None +) -> Float32: + """Atomic add into global memory (float32).""" + return nvvm.atomicrmw( + res=T.f32(), + op=nvvm.AtomicOpKind.FADD, + ptr=gmem_ptr.llvm_ptr, + a=Float32(a).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@cute.jit +def atomic_add_tensor_f32( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, +) -> None: + """Atomic-add a register fragment into a GMEM tile (float32).""" + if const_expr(pred is None): + for i in cutlass.range_constexpr(cute.size(src.shape)): + coord = cute.idx2crd(i, src.shape) + atomic_add_f32(src[i], elem_pointer(dst, coord)) + else: + for i in cutlass.range_constexpr(cute.size(src.shape)): + # CuTeDSL 4.3.4+ disallows introducing new tuple-typed values inside + # a dynamic `if`. Compute `coord` unconditionally, then predicate the + # atomic update. + coord = cute.idx2crd(i, src.shape) + if pred[i]: + atomic_add_f32(src[i], elem_pointer(dst, coord)) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if". 
+ tApA = cute.make_fragment( + cute.make_layout( + ( + cute.size(tAcA, mode=[0, 1]), + cute.size(tAcA, mode=[1]), + cute.size(tAcA, mode=[2]), + ), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less( + tAcA[(0, rest_v), 0, rest_k][1], limit + ) + return tApA + + +@dsl_user_op +def domain_offset_i64( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(tensor.stride) + assert len(flat_coord_i64) == len(flat_stride), ( + "Coordinate and stride must have the same length" + ) + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + assert isinstance(tensor.iterator, cute.Pointer) + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@dsl_user_op +def coord_offset_i64( + idx: cute.typing.Int, + tensor: cute.Tensor, + dim: int, + *, + loc=None, + ip=None, +) -> cute.Tensor: + offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim]) + assert isinstance(tensor.iterator, cute.Pointer) + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@cute.jit +def fill_oob( + tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cutlass.Numeric +) -> None: + """Fill out-of-bounds values in shared memory tensor.""" + tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0]) + tXrX_fill.fill(fill_value) + for rest_v in cutlass.range_constexpr(const_expr(tXsX.shape[0][1])): + for rest_k in cutlass.range_constexpr(const_expr(tXsX.shape[2])): + if const_expr(tXpX is not None): + if not tXpX[rest_v, 0, rest_k]: + cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) + else: + cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) + + +@dsl_user_op +def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64: + """Pack two f32 values into a single i64. + + This mirrors quack.utils.f32x2_to_i64 and is used by online_softmax_reduce + to store (max, sum_exp) pairs in an Int64 reduction buffer. 
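+
+    The packing is a pure bitcast, so `i64_to_f32x2` below recovers both values
+    bit-exactly.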
+ """ + vec_f32x2 = vector.from_elements( + T.vector(2, T.f32()), + (a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)), + loc=loc, + ip=ip, + ) + vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2, loc=loc, ip=ip) + res = cutlass.Int64( + vector.extract( + vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip + ) + ) + return res + + +@dsl_user_op +def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + """Unpack a single i64 into two f32 values, inverse of f32x2_to_i64.""" + vec_i64x1 = vector.from_elements( + T.vector(1, T.i64()), + (c.ir_value(loc=loc, ip=ip),), + loc=loc, + ip=ip, + ) + vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1, loc=loc, ip=ip) + res0 = Float32( + vector.extract( + vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip + ) + ) + res1 = Float32( + vector.extract( + vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip + ) + ) + return res0, res1 + + +# ------------------------- +# Reduction helpers (from quack.reduce) +# ------------------------- + + +@cute.jit +def warp_reduce( + val: cute.TensorSSA | cute.Numeric, + op: Callable, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.TensorSSA | cute.Numeric: + if cutlass.const_expr(isinstance(val, cute.TensorSSA)): + res = cute.make_fragment(val.shape, val.dtype) + res.store(val) + for i in cutlass.range_constexpr(cute.size(val.shape)): + res[i] = warp_reduce(res[i], op, width) + return res.load() + return cute.arch.warp_reduction(val, op, threads_in_group=width) + + +@cute.jit +def block_reduce( + val: cute.Numeric, + op: Callable, + reduction_buffer: cute.Tensor, + init_val: cute.Numeric = 0.0, +) -> cute.Numeric: + """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row).""" + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + warps_per_row = cute.size(reduction_buffer.shape[1]) + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if lane_idx == 0: + reduction_buffer[row_idx, col_idx] = val + cute.arch.barrier() + block_reduce_val = init_val + if lane_idx < warps_per_row: + block_reduce_val = reduction_buffer[row_idx, lane_idx] + return warp_reduce(block_reduce_val, op) + + +@cute.jit +def cluster_reduce( + val: cute.Numeric, + op: Callable, + reduction_buffer: cute.Tensor, + mbar_ptr: cute.Pointer, + init_val: cute.Numeric = 0.0, + phase: Optional[cutlass.Int32] = None, +) -> cute.Numeric: + """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n)).""" + cta_rank_in_cluster = cute.arch.block_idx_in_cluster() + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if warp_idx == 0: + with cute.arch.elect_one(): + num_warps = rows_per_block * warps_per_row + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_ptr, + num_warps * cluster_n * reduction_buffer.element_type.width // 8, + ) + if lane_idx < cluster_n: + store_shared_remote( + val, + elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))), + mbar_ptr, + peer_cta_rank_in_cluster=lane_idx, + ) + cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0) + block_reduce_val = init_val + num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE) + for i in cutlass.range_constexpr(num_iter): + idx = lane_idx + i * cute.arch.WARP_SIZE + if idx < cute.size(reduction_buffer, 
mode=[1]): + block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx]) + return warp_reduce(block_reduce_val, op) + + +@cute.jit +def block_or_cluster_reduce( + val: cute.Numeric, + op: Callable, + reduction_buffer: cute.Tensor, + mbar_ptr: Optional[cute.Pointer], + phase: Optional[cutlass.Int32] = None, + init_val: cute.Numeric = 0.0, +) -> cute.Numeric: + """Perform either block or cluster reduction based on whether mbar_ptr is provided.""" + if cutlass.const_expr(mbar_ptr is None): + return block_reduce(val, op, reduction_buffer, init_val=init_val) + return cluster_reduce( + val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase + ) + + +@cute.jit +def row_reduce( + x: cute.TensorSSA | cute.Numeric, + op: cute.ReductionOp, + threads_per_row: cutlass.Constexpr[int], + reduction_buffer: Optional[cute.Tensor] = None, + mbar_ptr: Optional[cute.Pointer] = None, + phase: Optional[cutlass.Int32] = None, + init_val: cute.Numeric = 0.0, + hook_fn: Optional[Callable] = None, +) -> cute.Numeric: + """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n)).""" + if cutlass.const_expr(isinstance(x, cute.TensorSSA)): + val = x.reduce(op, init_val=init_val, reduction_profile=0) + else: + val = x + warp_op = { + cute.ReductionOp.ADD: operator.add, + cute.ReductionOp.MAX: cute.arch.fmax + if cutlass.const_expr(x.dtype == Float32) + else max, + cute.ReductionOp.MIN: min, + cute.ReductionOp.MUL: operator.mul, + }[op] + val = warp_reduce( + val, + warp_op, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + if cutlass.const_expr(hook_fn is not None): + hook_fn() + if cutlass.const_expr(reduction_buffer is not None): + warps_per_row, cluster_n = reduction_buffer.shape[1] + assert cluster_n == 1 or mbar_ptr is not None, ( + "mbar_ptr must be provided for cluster reduction" + ) + if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1): + val = block_or_cluster_reduce( + val, + warp_op, + reduction_buffer, + mbar_ptr, + phase=phase, + init_val=init_val, + ) + return val + + +@cute.jit +def row_reduce_add( + x: cute.TensorSSA | cute.Numeric, + threads_per_row: cutlass.Constexpr[int], + reduction_buffer: Optional[cute.Tensor] = None, + mbar_ptr: Optional[cute.Pointer] = None, + phase: Optional[cutlass.Int32] = None, + init_val: cute.Numeric = 0.0, + hook_fn: Optional[Callable] = None, +) -> cute.Numeric: + """Specialized row_reduce for ADD reductions. + + This mirrors row_reduce but hardcodes the ADD operation so we avoid + dynamic dispatch on the reduction op. It is used by bandwidth-bound + kernels like RMSNorm backward where the reduction is always ADD in + Float32. 
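+
+    A lightly abridged call site from the RMSNormBackward kernel later in this
+    file (N is shape[1] there):
+
+        mean_xhat_wdy = row_reduce_add(
+            x_hat * wdy,
+            threads_per_row,
+            reduction_buffer[None, None, stage],
+            mbar_full_ptr + stage if cluster_n > 1 else None,
+            phase=consumer_phase,
+            init_val=0.0,
+        ) / N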
+ """ + if cutlass.const_expr(isinstance(x, cute.TensorSSA)): + val = x.reduce(cute.ReductionOp.ADD, init_val=init_val, reduction_profile=0) + else: + val = x + val = warp_reduce( + val, + operator.add, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + if cutlass.const_expr(hook_fn is not None): + hook_fn() + if cutlass.const_expr(reduction_buffer is not None): + warps_per_row, cluster_n = reduction_buffer.shape[1] + assert cluster_n == 1 or mbar_ptr is not None, ( + "mbar_ptr must be provided for cluster reduction" + ) + if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1): + val = block_or_cluster_reduce( + val, + operator.add, + reduction_buffer, + mbar_ptr, + phase=phase, + init_val=init_val, + ) + return val + + +@cute.jit +def online_softmax_reduce( + x: cute.TensorSSA, + threads_per_row: cutlass.Constexpr[int], + reduction_buffer: Optional[cute.Tensor] = None, + mbar_ptr: Optional[cute.Pointer] = None, + hook_fn: Optional[Callable] = None, + phase: Optional[cutlass.Int32] = None, + return_exp_x: bool = False, +) -> tuple[Float32, Float32, Optional[cute.TensorSSA]]: + """Online softmax reduction over a row. + + This mirrors quack.reduce.online_softmax_reduce and computes: + - max_x: row-wise maximum of x + - sum_exp_x: row-wise sum of exp(x - max_x) + - exp_x (optional): per-element exp(x - max_x_final) if return_exp_x is True + """ + assert x.dtype == Float32, "x must be of type Float32" + # reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2) + max_x = warp_reduce( + x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + cute.arch.fmax, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + log2_e = math.log2(math.e) + exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True) + sum_exp_x = warp_reduce( + exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0), + operator.add, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + if cutlass.const_expr(hook_fn is not None): + hook_fn() + if cutlass.const_expr(reduction_buffer is not None): + rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape + assert cluster_n == 1 or mbar_ptr is not None, ( + "mbar_ptr must be provided for cluster reduction" + ) + if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1): + assert reduction_buffer.element_type == cutlass.Int64, ( + "reduction_buffer must be of type Int64" + ) + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if cutlass.const_expr(mbar_ptr is None): + if lane_idx == 0: + reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x) + cute.arch.barrier() + max_x_single_warp = -Float32.inf + sum_exp_x = 0.0 + if lane_idx < warps_per_row: + max_x_single_warp, sum_exp_x = i64_to_f32x2( + reduction_buffer[row_idx, lane_idx] + ) + max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax) + sum_exp_x *= cute.math.exp( + max_x_single_warp - max_x_final, fastmath=True + ) + sum_exp_x = warp_reduce(sum_exp_x, operator.add) + if cutlass.const_expr(return_exp_x): + exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True) + max_x = max_x_final + else: + cta_rank_in_cluster = cute.arch.block_idx_in_cluster() + if warp_idx == 0: + with cute.arch.elect_one(): + num_warps = rows_per_block * warps_per_row + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_ptr, + num_warps + * cluster_n + * reduction_buffer.element_type.width + // 8, + ) + if lane_idx < cluster_n: + 
store_shared_remote( + f32x2_to_i64(max_x, sum_exp_x), + elem_pointer( + reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster)) + ), + mbar_ptr, + peer_cta_rank_in_cluster=lane_idx, + ) + cute.arch.mbarrier_wait( + mbar_ptr, phase=phase if phase is not None else 0 + ) + num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE) + max_x_single_warp = cute.make_fragment(num_iter, Float32) + max_x_single_warp.fill(-Float32.inf) + sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32) + sum_exp_x_single_warp.fill(0.0) + for i in cutlass.range_constexpr(num_iter): + idx = lane_idx + i * cute.arch.WARP_SIZE + if idx < cute.size(reduction_buffer, mode=[1]): + max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2( + reduction_buffer[row_idx, idx] + ) + max_x_final = max_x_single_warp.load().reduce( + cute.ReductionOp.MAX, + init_val=-Float32.inf, + reduction_profile=0, + ) + max_x_final = warp_reduce(max_x_final, cute.arch.fmax) + sum_exp_x = 0.0 + for i in cutlass.range_constexpr(num_iter): + sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp( + max_x_single_warp[i] - max_x_final, + fastmath=True, + ) + sum_exp_x = warp_reduce(sum_exp_x, operator.add) + if cutlass.const_expr(return_exp_x): + exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True) + max_x = max_x_final + return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None) + + +# ------------------------- +# Copy helpers (minimal subset of quack.copy_utils) +# ------------------------- + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], + num_copy_elems: int, + is_async: bool = False, + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + from cutlass.cute.nvgpu import cpasync + + # cp.async is limited to 128b per op; synchronous vectorized copies can go wider. + max_bits = const_expr(128 if is_async else 256) + num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width)) + # Match Quack's default cp.async cache policy (leave cache_mode unspecified). + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom( + copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip + ) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + copy_atom = get_copy_atom( + src.element_type, num_copy_elems, is_async, loc=loc, ip=ip + ) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +# ------------------------- +# Reduction base (from quack.reduction_base) +# ------------------------- + + +class ReductionBase: + def __init__( + self, + dtype: Type[cutlass.Numeric], + N: int, + stage: int, + reduction_dtype: Type[cutlass.Numeric] = cutlass.Float32, + ): + self.dtype = dtype + self.N = N + self.stage = stage + self.reduction_dtype = reduction_dtype + + def _calculate_threads_per_row(self) -> int: + raise NotImplementedError() + + def _set_cluster_n(self) -> None: + self.cluster_n = 1 + + def _get_num_threads(self) -> int: + return 128 if self.N <= 16384 else 256 + + def _get_tv_layout( + self, num_copy_bits: int = 128 + ) -> Tuple[cute.Shape, cute.Layout]: + """Return (tiler_mn, tv_layout) for SM100 reduction kernels. + + This intentionally mirrors Quack's `ReductionBase._get_tiled_copy(...)`: + - `tiler_mn` spans the full N range for the CTA, including any "K-loop" + repeats (`num_blocks_N`). 
+ - `tv_layout` is the *tiled* thread/value layout used by CuTe's copy + partitioning (does **not** bake in `num_blocks_N`), matching + `quack.copy_utils.tiled_copy_2d(...).layout_tv_tiled`. + """ + if num_copy_bits > 128: + raise ValueError( + f"num_copy_bits={num_copy_bits} exceeds 128b; Quack-style SM100 reduction " + "tiling assumes <=128b vectorization (cp.async and common CopyAtoms)." + ) + vecsize = num_copy_bits // self.dtype.width + assert self.N % vecsize == 0, ( + f"Input N {self.N} is not divisible by vector size {vecsize}" + ) + num_threads = self._get_num_threads() + assert num_threads % cute.arch.WARP_SIZE == 0 + + threads_per_row = self._calculate_threads_per_row() + self._set_cluster_n() + num_blocks_N = cute.ceil_div( + self.N // vecsize, threads_per_row * self.cluster_n + ) + cols_per_block = num_threads // threads_per_row + tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) + + # Construct the same tv layout that Quack gets from `tiled_copy_2d(...).layout_tv_tiled`. + copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=num_copy_bits, + ) + thr_layout = cute.make_ordered_layout( + (cols_per_block, threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, vecsize)) + tv_layout = cute.make_tiled_copy_tv( + copy_atom, thr_layout, val_layout + ).layout_tv_tiled + return tiler_mn, tv_layout + + def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int: + # Mirror the allocation order used by the SM100 reduction kernels: + # 1) sX (byte_alignment=16) + # 2) reduction_buffer (byte_alignment=8) + # 3) mbar_ptr (Int64, 8B) + # + # CuTeDSL's SmemAllocator may insert padding between allocations to satisfy + # alignment. Be conservative and round up offsets accordingly so we never + # under-allocate dynamic shared memory. 
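+        # Worked example (illustrative numbers only): bf16 input,
+        # tiler_mn = (16, 4096), num_warps = 8, stage = 1, cluster_n = 1,
+        # reduction_dtype = f32:
+        #   sX        : 16 * 4096 * 2 B = 131072 B
+        #   reduction :  1 * 8 * 1 * 4 B =    32 B
+        #   mbar      :  1 * 8 B         =     8 B
+        # -> 131112 B of dynamic SMEM after the alignment round-ups below.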
+ + def _align_up(x: int, align: int) -> int: + return ((x + align - 1) // align) * align + + sx_bytes = int(cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))) + reduction_bytes = int( + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8) + ) + mbar_bytes = int(self.stage * (cutlass.Int64.width // 8)) + + offset = _align_up(sx_bytes, 16) + offset = _align_up(offset, 8) + reduction_bytes + offset = _align_up(offset, 8) + mbar_bytes + return int(offset) + + def _get_reduction_buffer_layout( + self, tv_layout: cute.Layout, cluster_n: int + ) -> cute.Layout: + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + warps_per_row = ( + num_warps + if cutlass.const_expr(cute.rank(tv_layout.shape[0]) == 1) + else max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1) + ) + return cute.make_ordered_layout( + (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), + order=(1, 0, 2), + ) + + def _allocate_reduction_buffer_and_mbar( + self, + smem: cutlass.utils.SmemAllocator, + tv_layout: cute.Layout, + is_persistent: bool = False, + ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]: + reduction_buffer = smem.allocate_tensor( + self.reduction_dtype, + self._get_reduction_buffer_layout(tv_layout, self.cluster_n), + byte_alignment=8, + ) + if cutlass.const_expr(self.cluster_n > 1): + mbar_ptr = smem.allocate_array( + cutlass.Int64, + num_elems=self.stage if not is_persistent else self.stage * 2, + ) + else: + mbar_ptr = None + return reduction_buffer, mbar_ptr + + @cute.jit + def _initialize_cluster( + self, + tidx: cutlass.Int32, + mbar_ptr: Optional[cute.Pointer], + num_warps: int, + is_persistent: bool = False, + ) -> None: + if cutlass.const_expr(self.cluster_n > 1 and mbar_ptr is not None): + if tidx < self.stage: + cute.arch.mbarrier_init(mbar_ptr + tidx, 1) + if cutlass.const_expr(is_persistent): + cute.arch.mbarrier_init( + mbar_ptr + self.stage + tidx, + num_warps * self.cluster_n, + ) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + + +# ------------------------- +# RMSNorm backward base (from quack.rmsnorm.RMSNormBackward) +# ------------------------- + + +class RMSNormBackward(ReductionBase): + def __init__(self, dtype: cutlass.Numeric, N: int): + # 2 stages for double buffering when computing mean of x_hat * wdy + super().__init__(dtype, N, stage=2, reduction_dtype=Float32) + self.reload_wdy = None if N <= 16 * 1024 else "smem" + # Optional optimization: atomically accumulate mdW into a single (N,) + # buffer instead of writing an (sm_count, N) partial buffer + torch.sum. 
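+        # Trade-off: the atomic path avoids the (sm_count, N) fp32 partial
+        # buffer and its follow-up reduction, but float atomics accumulate in
+        # a non-deterministic order, so dw is not bitwise reproducible
+        # run-to-run.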
+ self.atomic_dw = False + if self.N > 128 * 1024 and self.dtype.width >= 32: + raise ValueError( + "RMSNormBackward does not support N > 128k with dtype >= 32 bits" + ) + + def _get_num_threads(self) -> int: + return 128 if self.N <= 4096 else 256 + + def _calculate_threads_per_row(self) -> int: + N = self.N + return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else ( + 32 + if N <= 256 + else (64 if N <= 512 else (128 if N <= 4096 else 256)) + ) + ) + ) + + def _set_cluster_n(self) -> None: + N = self.N + cluster_n = ( + 1 + if N <= 8 * 1024 + else ( + 2 + if N <= 16 * 1024 + else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + def _smem_size_in_bytes(self, tiler_mn, num_warps: int, do_dtype=None) -> int: + if do_dtype is None: + do_dtype = self.dtype + return ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 + + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2 + + self.stage + * num_warps + * self.cluster_n + * (self.reduction_dtype.width // 8) + + self.stage * (cutlass.Int64.width // 8) * 2 + ) + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + sm_count: Int32, + stream: cuda.CUstream, + ): + semistatic_shape = (*mX.shape[:-1], self.N) + + def new_stride(t): + return ( + cute.assume(t.stride[0], divby=128 // t.element_type.width), + t.stride[1], + ) + + mX, mdO, mdResO, mdX, mdRes = [ + cute.make_tensor( + t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)) + ) + if const_expr(t is not None) + else None + for t in (mX, mdO, mdResO, mdX, mdRes) + ] + self._set_cluster_n() + largest_dtype_width = const_expr( + max( + mX.element_type.width, + mW.element_type.width if mW is not None else 0, + mdO.element_type.width, + mdX.element_type.width, + mdResO.element_type.width if mdResO is not None else 0, + mdRes.element_type.width if mdRes is not None else 0, + ) + ) + # Quack-style policy: cap the *largest* dtype to 128b, then scale the + # activation copy width down proportionally (e.g. fp16 + fp32-weight + # => 64b activation vectors so the fp32 path stays at 128b). 
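+        # Concretely: bf16 activations with an fp32 weight give
+        # 128 // 32 * 16 = 64 copy bits, while all-16-bit operands keep
+        # 128 // 16 * 16 = 128.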
+ num_copy_bits = const_expr(128 // largest_dtype_width * mX.element_type.width) + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=int(num_copy_bits)) + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + if const_expr(mW is not None): + mW_expanded_layout = cute.prepend( + mW.layout, + cute.make_layout((tiler_mn[0],), stride=(0,)), + ) + mW = cute.make_tensor(mW.iterator, mW_expanded_layout) + + num_blocks = sm_count + kernel = ( + self.kernel( + mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes) + ) + kernel.launch( + grid=[num_blocks, self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None, + smem=self._smem_size_in_bytes( + tiler_mn, num_warps, do_dtype=mdO.element_type + ), + stream=stream, + ) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx_start, _, _ = cute.arch.block_idx() + gdim, _, _ = cute.arch.grid_dim() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) + + shape = mX.shape + M = shape[0] + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + + idX = cute.make_identity_tensor(shape) + + smem = cutlass.utils.SmemAllocator() + smem_layout = cute.make_ordered_layout( + (tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2) + ) + sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16) + sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, + tv_layout, + is_persistent=True, + ) + if const_expr(mbar_ptr is not None): + mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2 + else: + mbar_full_ptr, mbar_empty_ptr = None, None + + num_copy_elems_X = ( + tv_layout.shape[1] + if cutlass.const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if cutlass.const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + copy_atom_load_X = get_copy_atom( + mX.element_type, num_copy_elems_X, is_async=False + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems_X)) + thr_copy_X = cute.make_tiled_copy_tv( + copy_atom_load_X, thr_layout, val_layout + ).get_slice(tidx) + copy_fn = partial(copy, num_copy_elems=num_copy_elems_X) + + gX, gdO, gdResO, gdX, gdRes, cX = [ + cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None + for mT in (mX, mdO, mdResO, mdX, mdRes, idX) + ] + gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) if mW is not None else None + gdW, gdB = [ + cute.local_tile(mT, (1, tiler_mn[1]), (bidx_start, cluster_y)) + if const_expr(mT is not None) + else None + for mT in (mdW, mdB) + ] + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXgdO = thr_copy_X.partition_S(gdO) + tXsdO = thr_copy_X.partition_D(sdO) + tXgdX = thr_copy_X.partition_D(gdX) + if 
const_expr(mdResO is not None): + tXgdResO = thr_copy_X.partition_S(gdResO) + if const_expr(mdRes is not None): + tXgdRes = thr_copy_X.partition_D(gdRes) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None] + + tXrX, tXrdO, tXrdX = [ + cute.make_fragment_like(thr[None, None, None, 0]) + for thr in (tXgX, tXgdO, tXgdX) + ] + tXrdResO = None + if const_expr(mdResO is not None): + tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0]) + tXrdRes = None + if const_expr(mdRes is not None): + tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0]) + + tXpX = ( + predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1]) + if not is_even_N + else None + ) + + tXgdW, tXrdW = None, None + tXgdB, tXrdB = None, None + if const_expr(mdW is not None): + tXgdW = thr_copy_X.partition_S(gdW) + tXrdW = cute.make_fragment_like(tXgdW, Float32) + if const_expr(mdB is not None): + tXgdB = thr_copy_X.partition_S(gdB) + tXrdB = cute.make_fragment_like(tXgdB, Float32) + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + + self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True) + + tXrW = None + if const_expr(mW is not None): + tXgW = thr_copy_X.partition_S(gW) + tXrW = cute.make_fragment_like(tXgW) + if not is_even_N: + tXrW.fill(0.0) + copy_fn(tXgW, tXrW, pred=tXpX) + + row = tXcX[None, None, None, bidx_start][0][0] + if row < M: + tXgX_cur = coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0] + tXgdO_cur = coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0] + copy_fn(tXgX_cur, tXsX[None, None, None, 0], pred=tXpX, is_async=True) + copy_fn(tXgdO_cur, tXsdO[None, None, None, 0], pred=tXpX, is_async=True) + elif tiler_mn[0] > 1: + fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero) + fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero) + cute.arch.cp_async_commit_group() + + if const_expr(self.cluster_n > 1): + cute.arch.cluster_wait() + + if const_expr(mdW is not None): + tXrdW.fill(0.0) + if const_expr(mdB is not None): + tXrdB.fill(0.0) + stage = Int32(0) + producer_phase = Int32(1) + consumer_phase = Int32(0) + for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim): + row = tXcX[None, None, None, bidx][0][0] + if row + gdim * tiler_mn[0] < M: + tXgX_cur = coord_offset_i64(bidx + gdim, tXgX, dim=3)[ + None, None, None, 0 + ] + tXgdO_cur = coord_offset_i64(bidx + gdim, tXgdO, dim=3)[ + None, None, None, 0 + ] + copy_fn( + tXgX_cur, + tXsX[None, None, None, stage ^ 1], + pred=tXpX, + is_async=True, + ) + copy_fn( + tXgdO_cur, + tXsdO[None, None, None, stage ^ 1], + pred=tXpX, + is_async=True, + ) + elif tiler_mn[0] > 1: + fill_oob( + tXsX[None, None, None, stage ^ 1], + None, + fill_value=mX.element_type.zero, + ) + fill_oob( + tXsdO[None, None, None, stage ^ 1], + None, + fill_value=mdO.element_type.zero, + ) + cute.arch.cp_async_commit_group() + rstd_val = cutlass.Float.zero + if row < M or tiler_mn[0] == 1: + rstd_val = mRstd[row] + if const_expr(mdResO is not None): + tXgdResO_cur = coord_offset_i64(bidx, tXgdResO, dim=3)[ + None, None, None, 0 + ] + if row < M or tiler_mn[0] == 1: + copy_fn(tXgdResO_cur, tXrdResO, pred=tXpX) + elif tiler_mn[0] > 1: + tXrdResO.fill(0.0) + cute.arch.cp_async_wait_group(1) + cute.autovec_copy(tXsX[None, None, None, stage], tXrX) + x = tXrX.load().to(cute.Float32) + cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO) + dout = tXrdO.load().to(cute.Float32) + x_hat = x * rstd_val + wdy = dout + if const_expr(mW 
is not None): + wdy *= tXrW.load().to(Float32) + if const_expr(self.cluster_n > 1): + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase) + mean_xhat_wdy = ( + row_reduce_add( + x_hat * wdy, + threads_per_row, + reduction_buffer[None, None, stage], + (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None), + phase=consumer_phase, + init_val=0.0, + ) + / shape[1] + ) + + if const_expr(self.cluster_n > 1): + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.sync_warp() + lane_idx = cute.arch.lane_idx() + if lane_idx < self.cluster_n: + cute.arch.mbarrier_arrive( + mbar_empty_ptr + stage, + peer_cta_rank_in_cluster=lane_idx, + ) + + if const_expr(self.reload_wdy == "smem"): + cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO) + dout = tXrdO.load().to(cute.Float32) + wdy = dout + if const_expr(mW is not None): + wdy *= tXrW.load().to(Float32) + + dx = (wdy - x_hat * mean_xhat_wdy) * rstd_val + if const_expr(mdResO is not None): + dx += tXrdResO.load().to(cute.Float32) + tXrdX.store(dx.to(tXrdX.element_type)) + if row < M or tiler_mn[0] == 1: + tXgdX_cur = coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0] + copy_fn(tXrdX, tXgdX_cur, pred=tXpX) + if const_expr(mdRes is not None): + tXrdRes.store(dx.to(tXrdRes.element_type)) + tXgdRes_cur = coord_offset_i64(bidx, tXgdRes, dim=3)[ + None, None, None, 0 + ] + if row < M or tiler_mn[0] == 1: + copy_fn(tXrdRes, tXgdRes_cur, pred=tXpX) + if const_expr(mdW is not None): + tXrdW.store(tXrdW.load() + dout * x_hat) + if const_expr(mdB is not None): + tXrdB.store(tXrdB.load() + dout) + + stage ^= 1 + if stage == 0: + consumer_phase ^= 1 + producer_phase ^= 1 + + if const_expr(tiler_mn[0] > 1): + if const_expr(mdW is not None): + sdW = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=cute.Float32), + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + ) + tXsdW = thr_copy_X.partition_D(sdW) + cute.arch.barrier() + row0 = tXcX[None, None, None, 0][0][0] + if row0 > 0: + cute.autovec_copy(tXrdW, tXsdW) + cute.arch.barrier() + if row0 == 0: + for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])): + tXrdW_other = cute.make_fragment_like(tXrdW) + tXsdW_other = cute.make_tensor( + tXsdW.iterator + i * sdW.stride[0], + tXsdW.layout, + ) + cute.autovec_copy(tXsdW_other, tXrdW_other) + tXrdW.store(tXrdW.load() + tXrdW_other.load()) + if const_expr(self.atomic_dw): + atomic_add_tensor_f32(tXrdW, tXgdW, pred=tXpX) + else: + copy_fn(tXrdW, tXgdW, pred=tXpX) + cute.arch.barrier() + if const_expr(mdB is not None): + sdB = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=cute.Float32), + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + ) + tXsdB = thr_copy_X.partition_D(sdB) + cute.arch.barrier() + row0 = tXcX[None, None, None, 0][0][0] + if row0 > 0: + cute.autovec_copy(tXrdB, tXsdB) + cute.arch.barrier() + if row0 == 0: + for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])): + tXrdB_other = cute.make_fragment_like(tXrdB) + tXsdB_other = cute.make_tensor( + tXsdB.iterator + i * sdB.stride[0], + tXsdB.layout, + ) + cute.autovec_copy(tXsdB_other, tXrdB_other) + tXrdB.store(tXrdB.load() + tXrdB_other.load()) + copy_fn(tXrdB, tXgdB, pred=tXpX) + else: + if const_expr(mdW is not None): + if const_expr(self.atomic_dw): + atomic_add_tensor_f32(tXrdW, tXgdW, pred=tXpX) + else: + copy_fn(tXrdW, tXgdW, pred=tXpX) + if const_expr(mdB is not None): + copy_fn(tXrdB, tXgdB, pred=tXpX) + + if const_expr(self.cluster_n > 1): + stage ^= 1 + if stage == 0: 
+ producer_phase ^= 1 + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + self._kernel_impl( + mX, + mW, + mdO, + mdResO, + mRstd, + mdX, + mdW, + mdB, + mdRes, + tv_layout, + tiler_mn, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + ): + largest_dtype_width = const_expr( + max( + mX.element_type.width, + mdO.element_type.width, + mdX.element_type.width, + mdResO.element_type.width if mdResO is not None else 0, + mdRes.element_type.width if mdRes is not None else 0, + ) + ) + tiler_mn, tv_layout = self._get_tv_layout( + num_copy_bits=128 // largest_dtype_width * mX.element_type.width + ) + self._kernel_impl( + mX, + mW, + mdO, + mdResO, + mRstd, + mdX, + mdW, + mdB, + mdRes, + tv_layout, + tiler_mn, + ) + + +# ------------------------- +# SM count helper (from quack.rmsnorm._get_sm_count) +# ------------------------- + + +def get_sm_count( + N: int, + device: torch.device, + M: Optional[int] = None, + dtype: Optional[torch.dtype] = None, +) -> int: + """ + SM count heuristic for reduction-style kernels. + + This starts from Quack's _get_sm_count policy and layers on SM100 / + DSv3-specific tuning so that: + - For DSv3-style shapes (large-M, N in {6144, 8192}, fp16/bf16), + sm_count is reduced for very large M to cut down the number of + dw_partial/db_partial rows that ever hit HBM. + - For Quack-suite hidden=4096, small-M shapes, sm_count is modestly + increased to improve SM occupancy, matching the existing SM100 + tuning used by both RMSNorm and LayerNorm. + """ + num_sms = get_num_sms(device) + + sm_count_multiple = ( + 16 + if N <= 256 + else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) + ) + sm_count = num_sms + if N <= 8192: + sm_count = sm_count * sm_count_multiple + elif N <= 16384: + sm_count = sm_count // 2 + else: + sm_count = sm_count * 2 + + # Quack-suite tuning: for small-M, hidden=4096 shapes (M<=8192) and + # 16-bit dtypes, increase sm_count to improve occupancy. This mirrors + # the existing SM100 RMSNorm/LayerNorm heuristics. + if ( + dtype in (torch.float16, torch.bfloat16) + and M is not None + and M <= 8192 + and N == 4096 + ): + sm_count = min(sm_count * 2, num_sms * 4) + + return sm_count diff --git a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py new file mode 100644 index 0000000..a96a4c7 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Torch custom ops wrapping Oink's Blackwell RMSNorm kernels. + +These ops are designed to be: +- Architecture-aware (use CuTeDSL SM100 kernels when available, fall back + to a safe reference elsewhere). +- Layout-preserving for 2D row-major inputs, including padded MLA-style + layouts where stride(0) > N and stride(1) == 1. +- torch.compile-friendly via proper fake implementations that mirror + runtime shapes and strides. + +Public ops (Python signatures): + + torch.ops.oink.rmsnorm(x: Tensor, weight: Tensor, eps: float) -> Tensor + Functional RMSNorm. Returns a new tensor with the same shape and + stride as x when using the fast CuTeDSL path. + + torch.ops.oink.fused_add_rms_norm( + x: Tensor, residual: Tensor, weight: Tensor, eps: float + ) -> None + In-place fused residual-add + RMSNorm matching vLLM semantics: + residual = x + residual (stored into `residual`) + x = RMSNorm(residual, w) (stored into `x`) + Mutates `x` and `residual` in-place and returns None. +""" + +from __future__ import annotations + +import importlib +import threading + +import torch +from torch.library import custom_op + +_RMSNORM_MOD: object | None = None +_RMSNORM_MOD_LOCK = threading.Lock() + + +def _get_rmsnorm_mod(): + """Lazy import to keep plugin registration lightweight. + + Importing the CuTeDSL kernel stack can be expensive and may require a CUDA + context. We defer it until the first actual execution of the custom op. + """ + global _RMSNORM_MOD + + cached = _RMSNORM_MOD + if cached is not None: + return cached + + with _RMSNORM_MOD_LOCK: + if _RMSNORM_MOD is None: + _RMSNORM_MOD = importlib.import_module("kernelagent_oink.blackwell.rmsnorm") + return _RMSNORM_MOD + + +def _get_sm(device: torch.device | None = None) -> int: + """Return SM version as an int (e.g., 100 for SM100 / Blackwell).""" + if device is None: + device = torch.device("cuda") + major, minor = torch.cuda.get_device_capability(device) + return 10 * major + minor + + +# +# RMSNorm (functional) +# + + +@custom_op("oink::rmsnorm", mutates_args=()) +def oink_rmsnorm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> torch.Tensor: + """ + Functional RMSNorm entrypoint. + + This op is model-agnostic. It expects a 2D [M, N] view of the input + where the last dimension is contiguous (stride(1) == 1). The leading + dimension stride(0) may be larger than N (padded-row layouts), and + will be preserved on the fast CuTeDSL path. + + On SM100 (and newer), this dispatches to the tuned CuTeDSL Blackwell + RMSNorm kernel in rmsnorm.rmsnorm_forward, which in turn selects the + best internal schedule (including DSv3-specific stage-2 kernels where + applicable) and preserves the input's 2D stride when using the + pointer-based path. + + On older architectures it falls back to a safe PyTorch reference + implementation for correctness. + """ + assert x.is_cuda, "oink::rmsnorm requires CUDA tensors" + assert x.dim() == 2, "oink::rmsnorm expects a 2D [M, N] tensor view" + assert weight.dim() == 1, "weight must be 1D [N]" + + sm = _get_sm(x.device) + if sm >= 100: + # Use the tuned CuTeDSL SM100 kernel. 
The public API already + # contains all necessary gating and layout checks internally. + _rms = _get_rmsnorm_mod() + y, _rstd, _res = _rms.rmsnorm_forward( + x, + weight=weight, + bias=None, + residual=None, + eps=eps, + store_rstd=False, + ) + return y + + # Fallback: reference implementation (correctness-first). + _rms = _get_rmsnorm_mod() + return _rms.rmsnorm_ref( + x, + w=weight, + b=None, + residual=None, + eps=eps, + ) + + +@oink_rmsnorm.register_fake +def oink_rmsnorm_fake( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> torch.Tensor: + """ + Fake (meta) implementation for oink::rmsnorm. + + We must preserve x's logical layout (shape + stride) so that Inductor's + CUDA graph capture sees the same stride contract as the real kernel. + """ + # x is a FakeTensor here; x.shape/x.stride()/x.device/x.dtype are defined. + return torch.empty_strided( + x.shape, + x.stride(), + device=x.device, + dtype=x.dtype, + ) + + +# +# Fused residual-add + RMSNorm (in-place, vLLM semantics) +# + + +@custom_op("oink::fused_add_rms_norm", mutates_args=("x", "residual")) +def oink_fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> None: + """ + In-place fused residual-add + RMSNorm: + + residual <- x + residual + x <- RMSNorm(residual, weight, eps) + + Returns: + None (mutates `x` and `residual` in-place). + """ + assert x.is_cuda and residual.is_cuda, ( + "oink::fused_add_rms_norm requires CUDA tensors" + ) + assert x.shape == residual.shape, "x and residual must have the same shape" + assert x.dtype == residual.dtype, "x and residual must have the same dtype" + assert weight.dim() == 1, "weight must be 1D [N]" + + sm = _get_sm(x.device) + if sm >= 100: + _rms = _get_rmsnorm_mod() + # Prefer the lowest-overhead in-place entrypoint (returns None). + if hasattr(_rms, "fused_add_rmsnorm_inplace_"): + _rms.fused_add_rmsnorm_inplace_( # type: ignore[misc] + x, + residual, + weight, + eps=eps, + ) + return None + # Backward-compatible wrapper (returns (x, residual)). + if hasattr(_rms, "fused_add_rmsnorm_forward_inplace"): + _rms.fused_add_rmsnorm_forward_inplace( # type: ignore[misc] + x, + residual, + weight, + eps=eps, + ) + return None + + # Extremely defensive fallback if the Oink module doesn't provide + # the in-place entrypoint. + y, z = _rms.fused_add_rmsnorm_forward(x, residual, weight, eps=eps) + x.copy_(y) + residual.copy_(z) + return None + + # Non-SM100 fallback: keep semantics in-place (correctness-first). + residual.add_(x) + _rms = _get_rmsnorm_mod() + y = _rms.rmsnorm_ref(residual, w=weight, b=None, residual=None, eps=eps) + x.copy_(y) + return None + + +@oink_fused_add_rms_norm.register_fake +def oink_fused_add_rms_norm_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> None: + """ + Fake (meta) implementation for oink::fused_add_rms_norm. + + Because this op mutates its inputs in-place, the outputs alias the input + buffers and therefore have the same shapes and strides. + """ + return None + + +__all__ = [ + "oink_rmsnorm", + "oink_fused_add_rms_norm", +] diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py new file mode 100644 index 0000000..252df6a --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py @@ -0,0 +1,4418 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +RMSNorm kernel for SM100 (Blackwell) in CuteDSL. + +This implementation targets Blackwell with: +- A stride-preserving pointer path for padded-row layouts (e.g. MLA stride0> N). +- A one-pass fused-add RMSNorm schedule for bf16/fp16 (DSv3 N=7168) that keeps + `x + residual` in registers (avoids re-reading gmem) and uses FP32 accumulation. +- Optional experimental schedule knobs (env vars) to explore copy widths and + stage-2 cp.async variants. + +Note: This file expects the local CuTeDSL (cutlass) and SM100 helper modules +to be available in the Python environment (e.g., `nvidia-cutlass-dsl` and +`cuda-python`). It is shipped as part of the KernelAgent Oink vLLM plugin. +""" + +from __future__ import annotations + +import ctypes +import importlib.metadata +import os +import re +import subprocess +import sys +import threading +from typing import Optional, Tuple + +_HERE = os.path.dirname(__file__) + +# CuTeDSL caches generated MLIR into a tempdir under a global default +# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across +# `nvidia-cutlass-dsl` versions (e.g. 4.3.2 vs 4.3.4), and cross-version cache +# sharing causes noisy "invalid section ID" warnings (and disables cache reuse). +# +# If the user has not pinned `CUTE_DSL_CACHE_DIR`, isolate by version so multiple +# CuTeDSL envs can coexist on the same machine without stepping on each other. +if "CUTE_DSL_CACHE_DIR" not in os.environ: + try: + _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl") + except Exception: + _dsl_ver = "unknown" + _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver) + _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user" + _tmp = os.environ.get("TMPDIR") or "/tmp" + os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join( + _tmp, _user, f"cutlass_python_cache_{_dsl_ver}" + ) + +try: + import cutlass # type: ignore # noqa: F401 +except Exception as e: + raise ImportError( + "kernelagent_oink.blackwell.rmsnorm requires CuTeDSL's Python package " + "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)." + ) from e + +import torch # noqa: E402 +from torch import Tensor # noqa: E402 + +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python # noqa: E402 + +import cutlass # noqa: E402 +import cutlass.cute as cute # noqa: E402 +from cutlass import Float32, Int32, const_expr # noqa: E402 +from cutlass.cute import runtime as rt # noqa: E402 +from cutlass.cute.runtime import from_dlpack # noqa: E402 + +# Simple compile cache declared early so direct execution works +_PTR_COMPILE_CACHE = {} + +# Thread-local cache for the fast-launch path. We keep per-thread packed args and +# pointer/scalar storage so concurrent callers don't race on in-place updates. +_PTR_FAST_LAUNCH_TLS = threading.local() + + +# Cache a (1, sm_count) fp32 ones row used for GEMM-based dw/db partial reductions. +# +# On SM100, `dw_partial.sum(dim=0)` can be a double-digit microsecond tail for +# Quack-suite small shapes (e.g. M=8192, N=4096). 
A cached GEMM-based reduction +# is consistently faster and avoids per-call allocation overhead. +_DW_REDUCE_ONES_CACHE: dict[tuple[int, int], Tensor] = {} + + +def _get_dw_reduce_ones(device_index: int, sm_count: int) -> Tensor: + key = (int(device_index), int(sm_count)) + ones = _DW_REDUCE_ONES_CACHE.get(key) + if ones is None or ones.shape != (1, sm_count) or ones.device.index != device_index: + ones = torch.ones( + (1, sm_count), + device=torch.device("cuda", device_index), + dtype=torch.float32, + ) + _DW_REDUCE_ONES_CACHE[key] = ones + return ones + + +def _reduce_partial_sum_fp32(partial: Tensor, *, device_index: int) -> Tensor: + """Reduce a (sm_count, N) fp32 partial buffer into an (N,) fp32 result.""" + assert partial.dtype is torch.float32 + assert partial.dim() == 2 + ones = _get_dw_reduce_ones(device_index, int(partial.shape[0])) + return torch.mm(ones, partial).squeeze(0) + + +def _env_flag(name: str, default: bool) -> bool: + val = os.environ.get(name) + if val is None: + return default + return val.strip().lower() not in {"0", "false", "no", "off", ""} + + +# Fast-launch uses a few private-ish CuTeDSL internals (packed args plumbing and +# runtime pointer descriptors). Keep it enabled by default for our pinned CuTeDSL +# environment, but allow disabling it via env var and auto-disable it if those +# internals are not present in a future upgrade. +_ENABLE_FAST_LAUNCH = _env_flag("OINK_CUTEDSL_FAST_LAUNCH", default=True) +_FAST_LAUNCH_SUPPORTED = True + +# Fused-add RMSNorm schedule knobs (read once at import time; set env vars before +# importing this module if you want to override). +_DIRECT_GMEM_POLICY = ( + os.environ.get("OINK_RMSNORM_DIRECT_GMEM", "auto").strip().lower() or "auto" +) +_COPY_BITS_POLICY = ( + os.environ.get("OINK_RMSNORM_COPY_BITS", "auto").strip().lower() or "auto" +) +_ENABLE_CLUSTER_ILP = _env_flag("OINK_RMSNORM_ENABLE_CLUSTER_ILP", default=False) +_ENABLE_CLUSTER_ILP_UNSAFE = _env_flag( + "OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE", default=False +) +_ENABLE_TPR256 = _env_flag("OINK_RMSNORM_ENABLE_TPR256", default=False) +_ENABLE_STAGE2 = _env_flag("OINK_RMSNORM_ENABLE_STAGE2", default=False) + +# Forward dispatch control: +# - Default behavior: use the pointer-based path when safe, otherwise fall back +# to the stage-2 module (then the torch reference). +# - If you want to force stage-2 even when the pointer path is available (for +# experimentation / A-B testing), set this env var **before** importing this +# module. +_FORCE_RMSNORM_STAGE2_FWD = _env_flag( + "KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2", default=False +) + +# CuTeDSL stability probe for the experimental cluster_n>1 + direct-GMEM schedule. +# +# Some CuTeDSL builds segfault during JIT compilation when combining: +# - cluster launches (cluster_n>1) and +# - direct-GMEM loads/stores (no staging SMEM tiles). +# +# We keep the schedule gated behind `OINK_RMSNORM_ENABLE_CLUSTER_ILP=1` + +# `OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1`, and additionally run a one-time +# out-of-process compile probe so we can safely fall back to the staged SMEM +# path instead of crashing the parent process. +# +# This is (currently) sensitive to the vector width: we have observed +# reproducible segfaults for the 256b universal-copy path, while the 128b path +# can succeed. Cache the maximum supported copy width (0 = unsupported). 
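+# Cached probe result: None = not probed yet, 0 = schedule unsupported,
+# 128 / 256 = widest universal-copy width that compiled in the subprocess probe.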
+_CLUSTER_DIRECT_GMEM_MAX_COPY_BITS: Optional[int] = None +_CLUSTER_DIRECT_GMEM_PROBE_LOCK = threading.Lock() +_CLUSTER_DIRECT_GMEM_PROBE_WARNED = False + + +def _probe_cluster_direct_gmem_max_copy_bits() -> int: + global _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS + global _CLUSTER_DIRECT_GMEM_PROBE_WARNED + + override = os.environ.get("OINK_RMSNORM_CLUSTER_DIRECT_GMEM_MAX_COPY_BITS") + if override is not None and override.strip() != "": + try: + value = int(override) + except ValueError: + value = 0 + value = 256 if value >= 256 else 128 if value >= 128 else 0 + _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS = value + return value + + if _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS is not None: + return _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS + + with _CLUSTER_DIRECT_GMEM_PROBE_LOCK: + if _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS is not None: + return _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS + + script_template = r""" +import os + +os.environ["OINK_CUTEDSL_FAST_LAUNCH"] = "0" + +import cutlass +import cutlass.cute as cute +import cuda.bindings.driver as cuda +from cutlass import Float32, Int32 +from cutlass.cute import runtime as rt + +from kernelagent_oink.blackwell import rmsnorm + +N = 7168 +dtype = cutlass.BFloat16 + +copy_bits = int(os.environ["OINK_PROBE_COPY_BITS"]) +assumed_align = int(os.environ["OINK_PROBE_ASSUMED_ALIGN"]) + +op = rmsnorm.RMSNormSM100( + N, + dtype, + stage=1, + copy_bits=copy_bits, + use_async=False, + direct_gmem=True, +) +op._cluster_n_override = 2 # 2 CTAs per row + +ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) +ptr_res = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) +ptr_w = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) + +_ = cute.compile( + op.launch_from_ptrs_fused_add_inplace, + ptr_x, + ptr_w, + ptr_res, + Int32(4096), + Int32(N), + Int32(N), + cuda.CUstream(0), + Float32(1e-6), +) +print(f"ok {copy_bits}") +""" + + env = os.environ.copy() + # The probe runs in a fresh subprocess, so it won't inherit any + # benchmark-harness sys.path tweaks. Ensure the in-tree Oink source is + # importable so `import kernelagent_oink...` works reliably. 
+ oink_src = os.path.abspath(os.path.join(_HERE, "..", "..")) + if os.path.isdir(oink_src): + py_path = env.get("PYTHONPATH") + env["PYTHONPATH"] = oink_src + (os.pathsep + py_path if py_path else "") + env["PYTHONNOUSERSITE"] = "1" + + def run_probe(copy_bits: int, assumed_align: int): + probe_env = env.copy() + probe_env["OINK_PROBE_COPY_BITS"] = str(copy_bits) + probe_env["OINK_PROBE_ASSUMED_ALIGN"] = str(assumed_align) + return subprocess.run( + [sys.executable, "-c", script_template], + env=probe_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=120.0, + ) + + proc_256 = None + proc_128 = None + try: + proc_256 = run_probe(256, 32) + if proc_256.returncode == 0: + max_bits = 256 + else: + proc_128 = run_probe(128, 16) + max_bits = 128 if proc_128.returncode == 0 else 0 + except Exception: + max_bits = 0 + + if not _CLUSTER_DIRECT_GMEM_PROBE_WARNED and max_bits != 256: + _CLUSTER_DIRECT_GMEM_PROBE_WARNED = True + if max_bits == 128: + print( + "Oink: cluster_n>1 + direct_gmem 256b compile probe failed; " + "using 128b copies for the cluster ILP schedule.", + file=sys.stderr, + ) + if proc_256 is not None and proc_256.stderr: + tail = "\n".join(proc_256.stderr.splitlines()[-12:]) + print(f"Oink: probe stderr tail:\n{tail}", file=sys.stderr) + else: + rc = ( + proc_128.returncode + if proc_128 is not None + else (proc_256.returncode if proc_256 is not None else "unknown") + ) + print( + "Oink: cluster_n>1 + direct_gmem compile probe failed; " + f"falling back to staged SMEM path (returncode={rc}).", + file=sys.stderr, + ) + failing_proc = proc_128 if proc_128 is not None else proc_256 + if failing_proc is not None and failing_proc.stderr: + tail = "\n".join(failing_proc.stderr.splitlines()[-12:]) + print(f"Oink: probe stderr tail:\n{tail}", file=sys.stderr) + + _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS = max_bits + return max_bits + + +def _parse_version_tuple(version: str) -> Tuple[int, int, int]: + parts = version.split(".") + nums: list[int] = [] + for part in parts[:3]: + match = re.match(r"^(\d+)", part) + nums.append(int(match.group(1)) if match is not None else 0) + while len(nums) < 3: + nums.append(0) + return nums[0], nums[1], nums[2] + + +def _cutlass_dsl_version() -> Optional[Tuple[int, int, int]]: + try: + return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl")) + except Exception: + return None + + +_CUTLASS_DSL_VERSION = _cutlass_dsl_version() +# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around +# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the +# older signature for 4.3.2, but switch to a 4.3.4-compatible signature when we +# detect 4.3.4+ (or when version detection is unavailable). +_KERNEL_ACCEPTS_LAYOUT_ARGS = ( + _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4) +) + +if _ENABLE_CLUSTER_ILP and not _ENABLE_CLUSTER_ILP_UNSAFE: + # We have observed reproducible segfaults in some CuTeDSL builds when using + # cluster launches for this schedule. Require an explicit UNSAFE opt-in to + # avoid accidental crashes. 
+ _ENABLE_CLUSTER_ILP = False + print( + "Oink: OINK_RMSNORM_ENABLE_CLUSTER_ILP requested but disabled by default due to " + "known instability; set OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1 to force-enable.", + file=sys.stderr, + ) + + +def _fast_launch_enabled() -> bool: + return _ENABLE_FAST_LAUNCH and _FAST_LAUNCH_SUPPORTED + + +def _direct_gmem_from_policy(*, default: bool) -> bool: + """Resolve the direct-GMEM schedule flag from the (import-time) policy string.""" + if _DIRECT_GMEM_POLICY in {"0", "false", "no", "off"}: + return False + if _DIRECT_GMEM_POLICY in {"1", "true", "yes", "on"}: + return True + return default + + +def _copy_bits_from_policy(*, default: int, can_use_256: bool) -> int: + """Resolve copy width (in bits) from the (import-time) policy string.""" + if _COPY_BITS_POLICY in {"64"}: + return 64 + if _COPY_BITS_POLICY in {"128"}: + return 128 + if _COPY_BITS_POLICY in {"256"} and can_use_256: + return 256 + return default + + +class _StableI32Arg: + """A stable Int32 runtime arg (avoids per-call Int32().__c_pointers__ allocations).""" + + def __init__(self, value: int): + self._c_value = ctypes.c_int32(int(value)) + self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p) + + def set(self, value: int) -> None: + self._c_value.value = int(value) + + def __c_pointers__(self): + return [self._c_pointer] + + +class _StableF32Arg: + """A stable Float32 runtime arg (avoids per-call Float32().__c_pointers__ allocations).""" + + def __init__(self, value: float): + self._c_value = ctypes.c_float(float(value)) + self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p) + + def set(self, value: float) -> None: + self._c_value.value = float(value) + + def __c_pointers__(self): + return [self._c_pointer] + + +def _tls_fast_launch_cache() -> dict[tuple[object, ...], object]: + cache = getattr(_PTR_FAST_LAUNCH_TLS, "cache", None) + if cache is None: + cache = {} + _PTR_FAST_LAUNCH_TLS.cache = cache + return cache + + +def _set_runtime_ptr(ptr: object, device_ptr: int) -> None: + # Runtime pointer objects cache a `ctypes.c_void_p` descriptor and pass + # its address to the compiled function. Updating `_desc.value` updates + # the device pointer without changing the address of the descriptor. + # + # This relies on internal CuTeDSL runtime pointer fields (`_desc`, `_pointer`, + # etc.). If these internals change in a future CuTeDSL upgrade, callers + # should catch AttributeError and fall back to the regular launch path. 
+ device_ptr = int(device_ptr) + ptr._pointer = device_ptr # type: ignore[attr-defined] + if getattr(ptr, "_c_pointer", None) is None: + ptr.__c_pointers__() # type: ignore[attr-defined] + ptr._desc.value = device_ptr # type: ignore[attr-defined] + + +class _PtrRmsnormFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_x: object, + ptr_w: Optional[object], + ptr_out: object, + arg_m: _StableI32Arg, + arg_n: _StableI32Arg, + arg_ld: _StableI32Arg, + arg_eps: _StableF32Arg, + stream: cuda.CUstream, + assumed_align: int, + weight_dtype: Optional[type[cutlass.Numeric]], + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_x = ptr_x + self._ptr_w = ptr_w + self._ptr_out = ptr_out + self._arg_m = arg_m + self._arg_n = arg_n + self._arg_ld = arg_ld + self._arg_eps = arg_eps + self._stream = stream + self._assumed_align = int(assumed_align) + self._weight_dtype = weight_dtype + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_x_ptr = -1 + self._last_w_ptr = -1 + self._last_out_ptr = -1 + self._last_m = -1 + self._last_ld = -1 + self._last_eps = float("nan") + + def launch( + self, + *, + x: Tensor, + weight: Optional[Tensor], + out: Tensor, + M: int, + N: int, + ld: int, + eps: float, + ) -> None: + if not _fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps) + return + + x_ptr = x.data_ptr() + if x_ptr != self._last_x_ptr: + try: + _set_runtime_ptr(self._ptr_x, x_ptr) + self._last_x_ptr = x_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps + ) + return + + if self._ptr_w is not None: + w_ptr = weight.data_ptr() # type: ignore[union-attr] + if w_ptr != self._last_w_ptr: + try: + _set_runtime_ptr(self._ptr_w, w_ptr) + self._last_w_ptr = w_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps + ) + return + + out_ptr = out.data_ptr() + if out_ptr != self._last_out_ptr: + try: + _set_runtime_ptr(self._ptr_out, out_ptr) + self._last_out_ptr = out_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + if eps != self._last_eps: + self._arg_eps.set(eps) + self._last_eps = eps + + # Clear the error slot before launch (mirrors JitExecutor behavior). 
+ if self._cuda_result is not None: + self._cuda_result.value = 0 + + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + global _FAST_LAUNCH_SUPPORTED + self._use_fast_launch = False + _FAST_LAUNCH_SUPPORTED = False + + def _fallback_launch( + self, + *, + x: Tensor, + weight: Optional[Tensor], + out: Tensor, + M: int, + N: int, + ld: int, + eps: float, + ) -> None: + # If the packed-args or runtime pointer mutation path stops working + # (e.g. due to a CuTeDSL upgrade), fall back to the regular call path. + dtype = TORCH2CUTE_DTYPE[x.dtype] + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + ptr_out = rt.make_ptr( + dtype, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + ptr_w = ( + rt.make_ptr( + self._weight_dtype or dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + if weight is not None + else None + ) + self._compiled( + ptr_x, + ptr_w, + None, # ptr_b + None, # ptr_res + ptr_out, + None, # ptr_res_out + None, # ptr_rstd + Int32(M), + Int32(N), + Int32(ld), + self._stream, + Float32(eps), + ) + + +class _PtrFusedAddRmsnormFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_x: object, + ptr_w: object, + ptr_res: object, + arg_m: _StableI32Arg, + arg_n: _StableI32Arg, + arg_ld_x: _StableI32Arg, + arg_eps: _StableF32Arg, + stream: cuda.CUstream, + assumed_align: int, + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_x = ptr_x + self._ptr_w = ptr_w + self._ptr_res = ptr_res + self._arg_m = arg_m + self._arg_n = arg_n + self._arg_ld_x = arg_ld_x + self._arg_eps = arg_eps + self._stream = stream + self._assumed_align = int(assumed_align) + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_x_ptr = -1 + self._last_w_ptr = -1 + self._last_res_ptr = -1 + self._last_m = -1 + self._last_ld_x = -1 + self._last_eps = float("nan") + + def launch( + self, + *, + x: Tensor, + weight: Tensor, + residual: Tensor, + M: int, + N: int, + ld_x: int, + eps: float, + ) -> None: + if not _fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps + ) + return + + x_ptr = x.data_ptr() + if x_ptr != self._last_x_ptr: + try: + _set_runtime_ptr(self._ptr_x, x_ptr) + self._last_x_ptr = x_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps + ) + return + + w_ptr = weight.data_ptr() + if w_ptr != self._last_w_ptr: + try: + _set_runtime_ptr(self._ptr_w, w_ptr) + self._last_w_ptr = w_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps + ) + return + + res_ptr = residual.data_ptr() + if res_ptr != self._last_res_ptr: + try: + _set_runtime_ptr(self._ptr_res, res_ptr) + 
self._last_res_ptr = res_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld_x != self._last_ld_x: + self._arg_ld_x.set(ld_x) + self._last_ld_x = ld_x + if eps != self._last_eps: + self._arg_eps.set(eps) + self._last_eps = eps + + if self._cuda_result is not None: + self._cuda_result.value = 0 + + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + global _FAST_LAUNCH_SUPPORTED + self._use_fast_launch = False + _FAST_LAUNCH_SUPPORTED = False + + def _fallback_launch( + self, + *, + x: Tensor, + weight: Tensor, + residual: Tensor, + M: int, + N: int, + ld_x: int, + eps: float, + ) -> None: + dtype = TORCH2CUTE_DTYPE[x.dtype] + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + ptr_res = rt.make_ptr( + dtype, + residual.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + ptr_w = rt.make_ptr( + dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + self._compiled( + ptr_x, + ptr_w, + ptr_res, + Int32(M), + Int32(N), + Int32(ld_x), + self._stream, + Float32(eps), + ) + + +class _PtrRmsnormBwdFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_x: object, + ptr_w: Optional[object], + ptr_dout: object, + ptr_rstd: object, + ptr_dx: object, + ptr_dw_partial: Optional[object], + arg_m: _StableI32Arg, + arg_n: _StableI32Arg, + arg_ld: _StableI32Arg, + arg_sm_count: _StableI32Arg, + stream: cuda.CUstream, + assumed_align_x: int, + assumed_align_w: int, + assumed_align_dw: int, + weight_dtype: Optional[type[cutlass.Numeric]], + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_x = ptr_x + self._ptr_w = ptr_w + self._ptr_dout = ptr_dout + self._ptr_rstd = ptr_rstd + self._ptr_dx = ptr_dx + self._ptr_dw_partial = ptr_dw_partial + self._arg_m = arg_m + self._arg_n = arg_n + self._arg_ld = arg_ld + self._arg_sm_count = arg_sm_count + self._stream = stream + self._assumed_align_x = int(assumed_align_x) + self._assumed_align_w = int(assumed_align_w) + self._assumed_align_dw = int(assumed_align_dw) + self._weight_dtype = weight_dtype + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_x_ptr = -1 + self._last_w_ptr = -1 + self._last_dout_ptr = -1 + self._last_rstd_ptr = -1 + self._last_dx_ptr = -1 + self._last_dw_ptr = -1 + self._last_m = -1 + self._last_ld = -1 + self._last_sm_count = -1 + + def launch( + self, + *, + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dx: Tensor, + dw_partial: Optional[Tensor], + M: int, + N: int, + ld: int, + sm_count: int, + ) -> None: + if not _fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + 
sm_count=sm_count, + ) + return + + x_ptr = x.data_ptr() + if x_ptr != self._last_x_ptr: + try: + _set_runtime_ptr(self._ptr_x, x_ptr) + self._last_x_ptr = x_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + if self._ptr_w is not None: + w_ptr = weight.data_ptr() # type: ignore[union-attr] + if w_ptr != self._last_w_ptr: + try: + _set_runtime_ptr(self._ptr_w, w_ptr) + self._last_w_ptr = w_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + dout_ptr = dout.data_ptr() + if dout_ptr != self._last_dout_ptr: + try: + _set_runtime_ptr(self._ptr_dout, dout_ptr) + self._last_dout_ptr = dout_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + rstd_ptr = rstd.data_ptr() + if rstd_ptr != self._last_rstd_ptr: + try: + _set_runtime_ptr(self._ptr_rstd, rstd_ptr) + self._last_rstd_ptr = rstd_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + dx_ptr = dx.data_ptr() + if dx_ptr != self._last_dx_ptr: + try: + _set_runtime_ptr(self._ptr_dx, dx_ptr) + self._last_dx_ptr = dx_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + if self._ptr_dw_partial is not None: + dw_ptr = dw_partial.data_ptr() # type: ignore[union-attr] + if dw_ptr != self._last_dw_ptr: + try: + _set_runtime_ptr(self._ptr_dw_partial, dw_ptr) + self._last_dw_ptr = dw_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + if sm_count != self._last_sm_count: + self._arg_sm_count.set(sm_count) + self._last_sm_count = sm_count + + if self._cuda_result is not None: + self._cuda_result.value = 0 + + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + global _FAST_LAUNCH_SUPPORTED + self._use_fast_launch = False + _FAST_LAUNCH_SUPPORTED = False + + def _fallback_launch( + self, + *, + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dx: Tensor, + dw_partial: Optional[Tensor], + M: int, + N: int, + ld: int, + sm_count: int, + ) -> None: + dtype = TORCH2CUTE_DTYPE[x.dtype] + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_dout = rt.make_ptr( + dtype, + dout.data_ptr(), + 
mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_dx = rt.make_ptr( + dtype, + dx.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_rstd = rt.make_ptr( + TORCH2CUTE_DTYPE[rstd.dtype], + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_w = ( + rt.make_ptr( + self._weight_dtype or dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_w, + ) + if weight is not None + else None + ) + ptr_dw_partial = ( + rt.make_ptr( + TORCH2CUTE_DTYPE[dw_partial.dtype], + dw_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_dw, + ) + if dw_partial is not None + else None + ) + self._compiled( + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_dx, + ptr_dw_partial, + Int32(M), + Int32(N), + Int32(ld), + Int32(sm_count), + self._stream, + ) + + +def _get_fast_ptr_rmsnorm_bwd_launcher( + *, + compiled: object, + dtype: type[cutlass.Numeric], + weight_dtype: Optional[type[cutlass.Numeric]], + N: int, + device_index: int, + stream_handle: int, + has_weight: bool, + has_dw_partial: bool, + assumed_align_x: int, + assumed_align_w: int, + assumed_align_dw: int, +) -> Optional[_PtrRmsnormBwdFastLaunch]: + if not _fast_launch_enabled(): + return None + key = ( + "ptr_bwd_fast", + id(compiled), + N, + dtype, + weight_dtype, + device_index, + int(stream_handle), + has_weight, + has_dw_partial, + int(assumed_align_x), + int(assumed_align_w), + int(assumed_align_dw), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + assumed_align_x = int(assumed_align_x) + assumed_align_w = int(assumed_align_w) + assumed_align_dw = int(assumed_align_dw) + + ptr_x = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align_x + ) + ptr_w = ( + rt.make_ptr( + weight_dtype or dtype, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_w, + ) + if has_weight + else None + ) + ptr_dout = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align_x + ) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dx = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align_x + ) + ptr_dw_partial = ( + rt.make_ptr( + cutlass.Float32, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + if has_dw_partial + else None + ) + + arg_m = _StableI32Arg(0) + arg_n = _StableI32Arg(N) + arg_ld = _StableI32Arg(N) + arg_sm_count = _StableI32Arg(0) + stream = cuda.CUstream(int(stream_handle)) + + executor = compiled.to(device_index) # type: ignore[attr-defined] + try: + exe_args, adapted_args = executor.generate_execution_args( + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_dx, + ptr_dw_partial, + arg_m, + arg_n, + arg_ld, + arg_sm_count, + stream, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + global _FAST_LAUNCH_SUPPORTED + _FAST_LAUNCH_SUPPORTED = False + return None + + keepalive: tuple[object, ...] 
= ( + executor, + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_dx, + ptr_dw_partial, + arg_m, + arg_n, + arg_ld, + arg_sm_count, + stream, + *adapted_args, + ) + + launcher = _PtrRmsnormBwdFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_x=ptr_x, + ptr_w=ptr_w, + ptr_dout=ptr_dout, + ptr_rstd=ptr_rstd, + ptr_dx=ptr_dx, + ptr_dw_partial=ptr_dw_partial, + arg_m=arg_m, + arg_n=arg_n, + arg_ld=arg_ld, + arg_sm_count=arg_sm_count, + stream=stream, + assumed_align_x=assumed_align_x, + assumed_align_w=assumed_align_w, + assumed_align_dw=assumed_align_dw, + weight_dtype=weight_dtype if has_weight else None, + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher + + +def _get_fast_ptr_rmsnorm_launcher( + *, + compiled: object, + dtype: type[cutlass.Numeric], + weight_dtype: Optional[type[cutlass.Numeric]] = None, + N: int, + device_index: int, + stream_handle: int, + has_weight: bool, + assumed_align: int = 16, + eps: float, +) -> Optional[_PtrRmsnormFastLaunch]: + if not _fast_launch_enabled(): + return None + # Keyed by the compiled object identity so schedule changes (e.g. copy width, + # async/staged variants, etc.) never alias in the fast-launch cache. + key = ( + "ptr_fast", + id(compiled), + N, + dtype, + weight_dtype, + device_index, + int(stream_handle), + has_weight, + int(assumed_align), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + # Create stable runtime args and pointer descriptors once. + assumed_align = int(assumed_align) + ptr_x = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_out = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_w = ( + rt.make_ptr( + weight_dtype or dtype, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if has_weight + else None + ) + + arg_m = _StableI32Arg(0) + arg_n = _StableI32Arg(N) + arg_ld = _StableI32Arg(N) + arg_eps = _StableF32Arg(eps) + + stream = cuda.CUstream(int(stream_handle)) + + # Create an executor (loads the CUDA library once). + executor = compiled.to(device_index) # type: ignore[attr-defined] + + # Use generate_execution_args once to build the packed args array, and keep + # any adapted args alive for the lifetime of the cache entry. + try: + exe_args, adapted_args = executor.generate_execution_args( + ptr_x, + ptr_w, + None, # ptr_b + None, # ptr_res + ptr_out, + None, # ptr_res_out + None, # ptr_rstd + arg_m, + arg_n, + arg_ld, + stream, + arg_eps, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + global _FAST_LAUNCH_SUPPORTED + _FAST_LAUNCH_SUPPORTED = False + return None + + keepalive: tuple[object, ...] 
= ( + executor, + ptr_x, + ptr_w, + ptr_out, + arg_m, + arg_n, + arg_ld, + arg_eps, + stream, + *adapted_args, + ) + + launcher = _PtrRmsnormFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_x=ptr_x, + ptr_w=ptr_w, + ptr_out=ptr_out, + arg_m=arg_m, + arg_n=arg_n, + arg_ld=arg_ld, + arg_eps=arg_eps, + stream=stream, + assumed_align=assumed_align, + weight_dtype=weight_dtype if has_weight else None, + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher + + +def _get_fast_ptr_fused_add_rmsnorm_launcher( + *, + compiled: object, + dtype: type[cutlass.Numeric], + N: int, + device_index: int, + stream_handle: int, + copy_bits: int, + use_async: bool, + tpr: int, + direct_gmem: bool, + assumed_align: int, + eps: float, +) -> Optional[_PtrFusedAddRmsnormFastLaunch]: + if not _fast_launch_enabled(): + return None + key = ( + "ptr_fused_add_fast", + id(compiled), + N, + dtype, + device_index, + int(stream_handle), + int(copy_bits), + bool(use_async), + int(tpr), + bool(direct_gmem), + int(assumed_align), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + ptr_x = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_res = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_w = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + + arg_m = _StableI32Arg(0) + arg_n = _StableI32Arg(N) + arg_ld_x = _StableI32Arg(N) + arg_eps = _StableF32Arg(eps) + + stream = cuda.CUstream(int(stream_handle)) + + executor = compiled.to(device_index) # type: ignore[attr-defined] + + try: + exe_args, adapted_args = executor.generate_execution_args( + ptr_x, + ptr_w, + ptr_res, + arg_m, + arg_n, + arg_ld_x, + stream, + arg_eps, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + global _FAST_LAUNCH_SUPPORTED + _FAST_LAUNCH_SUPPORTED = False + return None + + keepalive: tuple[object, ...] = ( + executor, + ptr_x, + ptr_w, + ptr_res, + arg_m, + arg_n, + arg_ld_x, + arg_eps, + stream, + *adapted_args, + ) + + launcher = _PtrFusedAddRmsnormFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_x=ptr_x, + ptr_w=ptr_w, + ptr_res=ptr_res, + arg_m=arg_m, + arg_n=arg_n, + arg_ld_x=arg_ld_x, + arg_eps=arg_eps, + stream=stream, + assumed_align=assumed_align, + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher + + +# Local helpers for reduction, dtype mapping, and coordinate/predicate utilities. +# +# NOTE: Avoid `from . import ...` imports here: CuTeDSL's AST preprocessor may +# mishandle that form (module=None in the AST). Use fully-qualified imports. 
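+#
+# Illustrative example of the two forms (the first is what we avoid):
+#   from . import lite_quack as qutils                     # relative: may trip the preprocessor
+#   from kernelagent_oink.blackwell import lite_quack ...  # fully qualified: used below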
+from kernelagent_oink.blackwell import lite_quack as qutils # noqa: E402 +from kernelagent_oink.blackwell.lite_quack import ( # noqa: E402 + TORCH2CUTE_DTYPE, + RMSNormBackward as BaseRMSNormBackward, + convert_from_dlpack as convert_from_dlpack_cute, + get_sm_count, + row_reduce, +) + + +# ------------------------- +# Copy helpers (allow up to 256b) +# ------------------------- + + +@cute.jit +def get_copy_atom_bw( + dtype: type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False +) -> cute.CopyAtom: + # cp.async (SIMT) supports up to 128b per op; use 256b for sync when possible + max_bits = const_expr(128 if is_async else 256) + num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width)) + from cutlass.cute.nvgpu import cpasync + + # Prefer GLOBAL cache policy for bulk streaming reads at large M. + copy_op = ( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL) + if is_async + else cute.nvgpu.CopyUniversalOp() + ) + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@cute.jit +def copy_tiled( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, +) -> None: + atom = get_copy_atom_bw(src.element_type, num_copy_elems, is_async) + cute.copy(atom, src, dst, pred=pred) + + +# ------------------------- +# RMSNorm Kernel (SM100) +# ------------------------- + + +class RMSNormSM100: + def __init__( + self, + N: int, + dtype: type[cutlass.Numeric], + stage: Optional[int] = None, + *, + copy_bits: int = 128, + use_async: bool = True, + direct_gmem: bool = False, + ): + self.N = N + self.dtype = dtype + # Match Quack default for RMSNorm: stage = 1 unless explicitly overridden + self.stage = 1 if stage is None else stage + self.reduction_dtype = cutlass.Float32 + self.copy_bits = int(copy_bits) + self.use_async = bool(use_async) + self.direct_gmem = bool(direct_gmem) + + def _threads_per_row(self) -> int: + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: + pass + # Tune mid-size buckets for large-M rows. + N = self.N + # DSv3 MLA (padded/strided) hot shape. Prefer a threads-per-row that + # makes the tile width exactly match N with 128b vectors (bf16/fp16), + # avoiding the ~33% padded work from rounding 1536 -> 2048. + if N == 1536 and self.dtype.width == 16: + return 96 + # DSv3 default hidden size (7168). Choose a threads-per-row that matches + # the selected vector width to avoid padded work. Using 224 threads/row + # yields exact tiles for all supported copy widths we use on SM100: + # - 64b copies (vec=4 for bf16/fp16): 7168/4 = 1792 = 8 * 224 + # - 128b copies (vec=8 for bf16/fp16): 7168/8 = 896 = 4 * 224 + # - 256b copies (vec=16 for bf16/fp16): 7168/16 = 448 = 2 * 224 + # + if N == 7168 and self.dtype.width == 16: + return 224 + # DSv3-ish N buckets (6144/8192): use larger threads/row so each thread + # holds fewer elements in registers. For 256b vectors, pick a threads/row + # that yields an exact tile without padding. + if self.dtype.width == 16: + if N == 6144: + if self.copy_bits >= 256: + return 192 + if self.copy_bits <= 128: + return 256 + if N == 8192: + return 256 + # For small-N, use at least one full warp per row. The kernel + # implementation assumes one row per CTA; returning <32 here can + # produce multi-row tiles (cols_per_block > 1) which is not supported. 
+ if N <= 1024: + return 32 + elif N <= 4096: + return 128 + elif N <= 8192: + # Allow an override (used by 2-rows/CTA path for N≈6k/8k) + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: + return 128 + elif N <= 16384: + return 256 + else: + return 256 + + def _cluster_n(self) -> int: + try: + return self._cluster_n_override # type: ignore[attr-defined] + except Exception: + pass + N = self.N + # Default policy + if N <= 8192: + return 1 + if const_expr(self.dtype.width == 16): + if N <= 16 * 1024: + return 2 + elif N <= 32 * 1024: + return 2 + elif N <= 64 * 1024: + return 4 + elif N <= 128 * 1024: + return 8 + else: + return 16 + else: + if N <= 32 * 1024: + return 1 + elif N <= 64 * 1024: + return 2 + elif N <= 128 * 1024: + return 4 + elif N <= 256 * 1024: + return 8 + else: + return 16 + + def _num_threads(self) -> int: + # Favor 128 threads up to N=16k to reduce per-row partitioning overhead. + # This keeps cols_per_block=1 at N=8192 (bf16), which benchmarks faster for large-M. + try: + return self._nt_override # type: ignore[attr-defined] + except Exception: + if self.N == 1536 and self.dtype.width == 16: + return 96 + if self.N == 7168 and self.dtype.width == 16: + return 224 + if self.dtype.width == 16: + if self.N == 6144: + if self.copy_bits >= 256: + return 192 + if self.copy_bits <= 128: + return 256 + if self.N == 8192: + return 256 + if self.N <= 1024: + return 32 + return 128 if self.N <= 16384 else 256 + + def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout]: + vecsize = num_copy_bits // self.dtype.width + num_threads = self._num_threads() + assert num_threads % cute.arch.WARP_SIZE == 0 + tpr = self._threads_per_row() + cluster_n = self._cluster_n() + # Allow tails: compute number of vector columns with ceil + num_cols_vec = cute.ceil_div(self.N, vecsize) + num_blocks_N = cute.ceil_div(num_cols_vec, tpr * cluster_n) + cols_per_block = num_threads // tpr + tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr) + tv_layout = cute.make_layout( + ((tpr, cols_per_block), (vecsize, num_blocks_N)), + stride=( + (vecsize * cols_per_block, 1), + (cols_per_block, cols_per_block * vecsize * tpr), + ), + ) + return tiler_mn, tv_layout + + def _smem_bytes(self, tiler_mn, num_warps) -> int: + # smem for X tile (+ residual if present) + reduction buffers + mbar(s) + return ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) + + self.stage + * num_warps + * self._cluster_n() + * (self.reduction_dtype.width // 8) + + self.stage * (cutlass.Int64.width // 8) + ) + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + # Make last dim static (N) + semistatic_shape = (*mX.shape[:-1], self.N) + + def new_stride(t): + return ( + cute.assume(t.stride[0], divby=256 // t.element_type.width), + t.stride[1], + ) + + mX, mRes, mO, mResO = [ + cute.make_tensor( + t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)) + ) + if const_expr(t is not None) + else None + for t in (mX, mRes, mO, mResO) + ] + assert mX.element_type == self.dtype + assert mO.element_type == self.dtype + + copy_bits = int(self.copy_bits) + tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits) + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._num_threads() + ) + 
num_warps = num_threads // cute.arch.WARP_SIZE + threads_per_row = ( + tv_layout.shape[0][0] + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._threads_per_row() + ) + warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1) + cluster_n = self._cluster_n() + + if const_expr(mW is not None): + mW = cute.make_tensor( + mW.iterator, + cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), + ) + if const_expr(mB is not None): + mB = cute.make_tensor( + mB.iterator, + cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), + ) + if const_expr(mRstd is not None): + mRstd = cute.make_tensor( + mRstd.iterator, + cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))), + ) + + # No SMEM reload mode switch; overlap is controlled in the K-loop path + + # Compute smem usage considering staged buffers. + # + # In direct-gmem mode, we skip the gmem->smem tiles entirely and only + # keep the reduction buffers in shared memory. + stage_bufs = 2 if self.stage > 1 else 1 + tile_bytes_x = ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs + if const_expr(not self.direct_gmem) + else 0 + ) + tile_bytes_res = ( + cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) + * stage_bufs + if const_expr(mRes is not None and not self.direct_gmem) + else 0 + ) + red_bytes = ( + self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8) + ) + # mbarriers are only allocated/used for cluster_n>1. Some CuTeDSL builds + # require mbarrier state to be 16B-aligned in shared memory; account for + # the alignment padding when computing dynamic smem bytes. + smem_bytes = tile_bytes_x + tile_bytes_res + red_bytes + if cluster_n > 1: + # Align up to 16B before placing the mbarrier array. + smem_bytes = ((smem_bytes + 15) // 16) * 16 + smem_bytes += self.stage * (cutlass.Int64.width // 8) + + kernel = ( + self.kernel( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + const_expr(cluster_n), + const_expr(num_warps), + const_expr(warps_per_row), + const_expr(threads_per_row), + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + ) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1], + block=[num_threads, 1, 1], + cluster=([1, cluster_n, 1] if cluster_n > 1 else None), + smem=smem_bytes, + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_w: Optional[cute.Pointer], + ptr_b: Optional[cute.Pointer], + ptr_res: Optional[cute.Pointer], + ptr_out: cute.Pointer, + ptr_res_out: Optional[cute.Pointer], + ptr_rstd: Optional[cute.Pointer], + M: Int32, + N_dyn: Int32, + ld: Int32, + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + """Pointer-based entrypoint to reuse the existing RMSNorm schedule. + + This reconstructs cute.Tensor views from raw pointers plus sizes, + avoiding any DLPack conversions at the Python boundary. + """ + # Use a dynamic N for the leading-dimension stride so that the + # subsequent cute.assume(...) in __call__ sees a dynamic expression + # rather than a plain Python int. + # The compile-time N for the kernel (self.N) is still used to + # specialize the schedule. + # Assume row-major [M, N] with an arbitrary leading-dimension stride + # (common for padded-row / packed-attention layouts). 
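+        # Illustrative example: a bf16 [M, 1536] view carved out of a buffer
+        # with row stride 2048 would be launched with N_dyn = 1536, ld = 2048.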
+ layout_mn = cute.make_layout((M, N_dyn), stride=(ld, 1)) + layout_n = cute.make_layout((N_dyn,), stride=(1,)) + layout_m = cute.make_layout((M,), stride=(1,)) + + mX = cute.make_tensor(ptr_x, layout_mn) + mO = cute.make_tensor(ptr_out, layout_mn) + + mRes = ( + cute.make_tensor(ptr_res, layout_mn) + if const_expr(ptr_res is not None) + else None + ) + mResO = ( + cute.make_tensor(ptr_res_out, layout_mn) + if const_expr(ptr_res_out is not None) + else None + ) + mW = ( + cute.make_tensor(ptr_w, layout_n) if const_expr(ptr_w is not None) else None + ) + mB = ( + cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None + ) + mRstd = ( + cute.make_tensor(ptr_rstd, layout_m) + if const_expr(ptr_rstd is not None) + else None + ) + + # Reuse the main JIT entry to launch the scheduled kernel. + self.__call__(mX, mW, mB, mRes, mO, mResO, mRstd, stream, eps) + + @cute.jit + def launch_from_ptrs_fused_add_inplace( + self, + ptr_x: cute.Pointer, + ptr_w: cute.Pointer, + ptr_res: cute.Pointer, + M: Int32, + N_dyn: Int32, + ld_x: Int32, + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + """Pointer-based entrypoint for vLLM-style fused_add_rms_norm semantics. + + This specialized entrypoint supports: + - `x` / output with an arbitrary leading-dimension stride (`ld_x`), and + - `residual` / residual-out as a contiguous [M, N] tensor (ld_res = N). + + Both `x` and `residual` are updated in-place: + residual <- x + residual + x <- RMSNorm(residual) * weight + """ + layout_x = cute.make_layout((M, N_dyn), stride=(ld_x, 1)) + layout_res = cute.make_layout((M, N_dyn), stride=(N_dyn, 1)) + layout_n = cute.make_layout((N_dyn,), stride=(1,)) + + mX = cute.make_tensor(ptr_x, layout_x) + mO = cute.make_tensor(ptr_x, layout_x) + mRes = cute.make_tensor(ptr_res, layout_res) + mResO = cute.make_tensor(ptr_res, layout_res) + mW = cute.make_tensor(ptr_w, layout_n) + + self.__call__( + mX, + mW, + None, # bias + mRes, + mO, + mResO, + None, # rstd + stream, + eps, + ) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + cluster_n: cutlass.Constexpr[int], + num_warps: cutlass.Constexpr[int], + warps_per_row: cutlass.Constexpr[int], + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + if const_expr(cluster_n > 1): + cta_rank_in_cluster = cute.arch.block_idx_in_cluster() + else: + cta_rank_in_cluster = const_expr(0) + n_off = cta_rank_in_cluster * tiler_mn[1] + + smem = cutlass.utils.SmemAllocator() + # Allocate one or two SMEM buffers depending on stage depth + sX0 = ( + smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if const_expr(not self.direct_gmem) + else None + ) + sX1 = ( + smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if const_expr(self.stage > 1 and not self.direct_gmem) + else None + ) + sRes0 = ( + smem.allocate_tensor( + mRes.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if const_expr(mRes is not None and not self.direct_gmem) + else None + ) + sRes1 = ( + smem.allocate_tensor( + mRes.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if 
const_expr(mRes is not None and self.stage > 1 and not self.direct_gmem) + else None + ) + + # Reduction buffers + mbar for cluster reduce (reused by row_reduce helper) + red_layout = cute.make_ordered_layout( + (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), + order=(1, 0, 2), + ) + reduction_buffer = smem.allocate_tensor( + self.reduction_dtype, red_layout, byte_alignment=4 + ) + if const_expr(cluster_n > 1): + # Some CuTeDSL builds appear sensitive to the shared-memory alignment of + # mbarrier state. `SmemAllocator.allocate_array` does not currently + # expose an alignment parameter, so allocate an Int64 tensor with an + # explicit alignment and pass its iterator as the pointer. + mbar_tensor = smem.allocate_tensor( + cutlass.Int64, + cute.make_layout((self.stage,), stride=(1,)), + byte_alignment=16, + ) + mbar_ptr = mbar_tensor.iterator + else: + mbar_ptr = None + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + limit_k = shape[1] - n_off + + # Tiled copy setup + num_copy_elems_X = tv_layout.shape[1][0] + use_async = const_expr( + self.use_async and self.N >= 1024 and not self.direct_gmem + ) + copy_atom = get_copy_atom_bw( + mX.element_type, num_copy_elems_X, is_async=use_async + ) + thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx) + + # Tail predicate for the N dimension (when tile width > N). Reuse this + # for W/B loads so we never read past the end of those 1D tensors. + is_even_N_wb = const_expr(shape[1] == tiler_mn[1] * cluster_n) + if const_expr(not is_even_N_wb): + cX0 = cute.local_tile(idX, tiler_mn, (0, 0)) + tXp_wb = qutils.predicate_k(thr_copy.partition_S(cX0), limit=limit_k) + else: + tXp_wb = None + + # Weight/bias loads: + # + # - Direct-GMEM schedule: load weight/bias up front to hide latency. + # - Staged SMEM schedule: loading after the reduction reduces register + # pressure during the long-scoreboard reduction phase (better for large-M), + # but it measurably hurts small-M latency for the non-fused (no residual, + # no bias) case. For that specific case, prefetch weight up front as well. 
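+        #
+        # Resulting policy (matches `prefetch_w_early` and the late-load block
+        # further down):
+        #   direct_gmem                   -> W (and B) prefetched here
+        #   staged smem, no res, no bias  -> W prefetched here
+        #   staged smem, res or bias      -> W/B loaded after the reduction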
+ tXrW = None + tXrB = None + prefetch_w_early = bool( + mW is not None and (self.direct_gmem or (mRes is None and mB is None)) + ) + if const_expr(prefetch_w_early): + gW = cute.local_tile( + qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0) + ) + tXgW = thr_copy.partition_S(gW) + tXrW = cute.make_fragment_like(tXgW) + if const_expr(not is_even_N_wb): + tXrW.fill(0) + cute.copy( + get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False), + tXgW, + tXrW, + pred=tXp_wb, + ) + if const_expr(self.direct_gmem and mB is not None): + gB = cute.local_tile( + qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0) + ) + tXgB = thr_copy.partition_S(gB) + tXrB = cute.make_fragment_like(tXgB) + if const_expr(not is_even_N_wb): + tXrB.fill(0) + cute.copy( + get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False), + tXgB, + tXrB, + pred=tXp_wb, + ) + + # Non-persistent per-CTA execution (one tile in M) + self._init_cluster(tidx, mbar_ptr) + + mX_i, mRes_i, mO_i, mResO_i = [ + qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) + if t is not None + else None + for t in (mX, mRes, mO, mResO) + ] + mX_i, mRes_i, mO_i, mResO_i = [ + qutils.domain_offset_i64((0, n_off), t) if t is not None else None + for t in (mX_i, mRes_i, mO_i, mResO_i) + ] + gX_i = cute.local_tile(mX_i, tiler_mn, (0, 0)) + gO_i = cute.local_tile(mO_i, tiler_mn, (0, 0)) + gRes_i = ( + cute.local_tile(mRes_i, tiler_mn, (0, 0)) + if const_expr(mRes is not None) + else None + ) + gResO_i = ( + cute.local_tile(mResO_i, tiler_mn, (0, 0)) + if const_expr(mResO is not None) + else None + ) + gRstd_i = ( + cute.local_tile(mRstd, tiler_mn, (bidx, 0)) + if const_expr(mRstd is not None) + else None + ) + cX_i = cute.local_tile(idX, tiler_mn, (bidx, 0)) + + # Common identity/row index partitions reused by both default and K-loop paths + tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None] + row_i = tXcX_i[0][0] + tXgRstd_i = ( + thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None + ) + + # Stage-2 intra-row K-loop cp.async ping-pong (two tiles). This reduces + # per-thread fragment size and can improve memory-latency hiding for + # N=7168 at large M. It is enabled by setting `stage=2` when constructing + # the RMSNormSM100 op (see `_fused_add_rmsnorm_forward_ptr_inplace`). + if const_expr( + self.stage > 1 + and not self.direct_gmem + and use_async + and cluster_n == 1 + and shape[1] == 7168 + ): + vecsize = tv_layout.shape[1][0] + tpr = threads_per_row + target_tile_n = const_expr(4096) + tile_factor = const_expr(target_tile_n // (vecsize * tpr)) + if const_expr(tile_factor > 0): + tile_n = vecsize * tpr * tile_factor + num_tiles = cute.ceil_div(shape[1], tile_n) + + tiler_mn_tile = (tiler_mn[0], tile_n) + sX0_tile = cute.local_tile(sX0, tiler_mn_tile, (0, 0)) + sX1_tile = cute.local_tile(sX1, tiler_mn_tile, (0, 0)) + sRes0_tile = ( + cute.local_tile(sRes0, tiler_mn_tile, (0, 0)) + if const_expr(mRes is not None) + else None + ) + sRes1_tile = ( + cute.local_tile(sRes1, tiler_mn_tile, (0, 0)) + if const_expr(mRes is not None) + else None + ) + + tv_layout_tile = cute.make_layout( + ((tpr, tiler_mn[0]), (vecsize, tile_factor)), + stride=( + (vecsize * tiler_mn[0], 1), + (tiler_mn[0], tiler_mn[0] * vecsize * tpr), + ), + ) + thr_copy_tile = cute.make_tiled_copy( + copy_atom, tv_layout_tile, tiler_mn_tile + ).get_slice(tidx) + + # Accumulate per-thread partial sums across tiles; reduce once. + sum_sq_thread = cute.Float32(0.0) + + # Preload tile 0 into sX0/sRes0. 
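+                    # Pipeline sketch (two SMEM buffers, one in-flight cp.async
+                    # group): prefetch(t+1) -> wait -> reduce(t), with even
+                    # tiles landing in sX0/sRes0 and odd tiles in sX1/sRes1.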
+ k_off0 = const_expr(0) * tile_n + gX_0 = cute.local_tile( + qutils.domain_offset_i64((0, k_off0), mX_i), tiler_mn_tile, (0, 0) + ) + tXgX_0 = thr_copy_tile.partition_S(gX_0) + tXsX_0 = thr_copy_tile.partition_D(sX0_tile) + cX_0 = cute.local_tile( + cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, 0) + ) + tXc_0 = thr_copy_tile.partition_S(cX_0) + tXp_0 = qutils.predicate_k(tXc_0, limit=limit_k) + + tXp_ping = tXp_0 + tXp_pong = tXp_0 + + if row_i < shape[0]: + copy_tiled( + tXgX_0, + tXsX_0, + num_copy_elems=vecsize, + is_async=True, + pred=tXp_0, + ) + if const_expr(mRes is not None): + gRes_0 = cute.local_tile( + qutils.domain_offset_i64((0, k_off0), mRes_i), + tiler_mn_tile, + (0, 0), + ) + tXgRes_0 = thr_copy_tile.partition_S(gRes_0) + tXsRes_0 = thr_copy_tile.partition_D(sRes0_tile) + copy_tiled( + tXgRes_0, + tXsRes_0, + num_copy_elems=vecsize, + is_async=True, + pred=tXp_0, + ) + cute.arch.cp_async_commit_group() + + for t in cutlass.range_constexpr(num_tiles): + next_t = t + 1 + if next_t < num_tiles: + k_off_n = next_t * tile_n + gX_n = cute.local_tile( + qutils.domain_offset_i64((0, k_off_n), mX_i), + tiler_mn_tile, + (0, 0), + ) + tXgX_n = thr_copy_tile.partition_S(gX_n) + cX_n = cute.local_tile( + cute.domain_offset((0, k_off_n), cX_i), + tiler_mn_tile, + (0, 0), + ) + tXc_n = thr_copy_tile.partition_S(cX_n) + tXp_n = qutils.predicate_k(tXc_n, limit=limit_k) + + if const_expr((t % 2) == 0): + tXsX_n = thr_copy_tile.partition_D(sX1_tile) + tXsRes_n = ( + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None + ) + tXp_pong = tXp_n + else: + tXsX_n = thr_copy_tile.partition_D(sX0_tile) + tXsRes_n = ( + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None + ) + tXp_ping = tXp_n + + if row_i < shape[0]: + copy_tiled( + tXgX_n, + tXsX_n, + num_copy_elems=vecsize, + is_async=True, + pred=tXp_n, + ) + if const_expr(mRes is not None): + gRes_n = cute.local_tile( + qutils.domain_offset_i64((0, k_off_n), mRes_i), + tiler_mn_tile, + (0, 0), + ) + tXgRes_n = thr_copy_tile.partition_S(gRes_n) + copy_tiled( + tXgRes_n, + tXsRes_n, + num_copy_elems=vecsize, + is_async=True, + pred=tXp_n, + ) + cute.arch.cp_async_commit_group() + + cute.arch.cp_async_wait_group(1 if next_t < num_tiles else 0) + + # Current tile buffer (ping/pong). 
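+                        # Buffer parity mirrors the prefetch above: even t was
+                        # staged into sX0/sRes0, odd t into sX1/sRes1, and
+                        # pred_cur carries the matching tail predicate.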
+ if const_expr((t % 2) == 0): + tXsX_cur = thr_copy_tile.partition_D(sX0_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None + ) + pred_cur = tXp_ping + else: + tXsX_cur = thr_copy_tile.partition_D(sX1_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None + ) + pred_cur = tXp_pong + + k_off = t * tile_n + gX_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mX_i), + tiler_mn_tile, + (0, 0), + ) + tXgX_t = thr_copy_tile.partition_S(gX_t) + tXrX_t = cute.make_fragment_like(tXgX_t) + cute.autovec_copy(tXsX_cur, tXrX_t) + x_t = tXrX_t.load().to(cute.Float32) + if const_expr(mRes is not None): + gRes_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mRes_i), + tiler_mn_tile, + (0, 0), + ) + tXgRes_t = thr_copy_tile.partition_S(gRes_t) + tXrRes_t = cute.make_fragment_like(tXgRes_t) + cute.autovec_copy(tXsRes_cur, tXrRes_t) + x_t += tXrRes_t.load().to(cute.Float32) + + if const_expr(mResO is not None): + gResO_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mResO_i), + tiler_mn_tile, + (0, 0), + ) + tXgResO_t = thr_copy_tile.partition_D(gResO_t) + tXrResO_t = cute.make_fragment_like(tXgResO_t) + tXrResO_t.store(x_t.to(tXrResO_t.element_type)) + if row_i < shape[0]: + copy_tiled( + tXrResO_t, + tXgResO_t, + num_copy_elems=vecsize, + is_async=False, + pred=pred_cur, + ) + + sum_sq_thread = sum_sq_thread + (x_t * x_t).reduce( + cute.ReductionOp.ADD, + init_val=0.0, + reduction_profile=0, + ) + + sum_sq = row_reduce( + sum_sq_thread, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + init_val=0.0, + ) + rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True) + + if const_expr(mRstd is not None): + if tXcX_i[0][1] == 0 and row_i < shape[0]: + tXgRstd_i[0] = rstd + + for t in cutlass.range_constexpr(num_tiles): + k_off = t * tile_n + cX_t = cute.local_tile( + cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, 0) + ) + tXc_t = thr_copy_tile.partition_S(cX_t) + tXp_t = qutils.predicate_k(tXc_t, limit=limit_k) + + if const_expr((t % 2) == 0): + tXsX_cur = thr_copy_tile.partition_D(sX0_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None + ) + else: + tXsX_cur = thr_copy_tile.partition_D(sX1_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None + ) + + gX_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mX_i), + tiler_mn_tile, + (0, 0), + ) + tXgX_t = thr_copy_tile.partition_S(gX_t) + tXrX_t = cute.make_fragment_like(tXgX_t) + cute.autovec_copy(tXsX_cur, tXrX_t) + x_t = tXrX_t.load().to(cute.Float32) + if const_expr(mRes is not None): + gRes_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mRes_i), + tiler_mn_tile, + (0, 0), + ) + tXgRes_t = thr_copy_tile.partition_S(gRes_t) + tXrRes_t = cute.make_fragment_like(tXgRes_t) + cute.autovec_copy(tXsRes_cur, tXrRes_t) + x_t += tXrRes_t.load().to(cute.Float32) + + y_t = x_t * rstd + if const_expr(mW is not None): + gW_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mW), + tiler_mn_tile, + (0, 0), + ) + tWgW_t = thr_copy_tile.partition_S(gW_t) + tWrW_t = cute.make_fragment_like(tWgW_t) + copy_tiled( + tWgW_t, + tWrW_t, + num_copy_elems=vecsize, + is_async=False, + pred=tXp_t, + ) + y_t = y_t * tWrW_t.load().to(cute.Float32) + if const_expr(mB is not None): + gB_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mB), + 
tiler_mn_tile, + (0, 0), + ) + tWgB_t = thr_copy_tile.partition_S(gB_t) + tWrB_t = cute.make_fragment_like(tWgB_t) + copy_tiled( + tWgB_t, + tWrB_t, + num_copy_elems=vecsize, + is_async=False, + pred=tXp_t, + ) + y_t = y_t + tWrB_t.load().to(cute.Float32) + + gO_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mO_i), + tiler_mn_tile, + (0, 0), + ) + tXgO_t = thr_copy_tile.partition_D(gO_t) + tXrO_t = cute.make_fragment_like(tXgO_t) + tXrO_t.store(y_t.to(tXrO_t.element_type)) + if row_i < shape[0]: + copy_tiled( + tXrO_t, + tXgO_t, + num_copy_elems=vecsize, + is_async=False, + pred=tXp_t, + ) + + return + + # Single-stage path: one-row-per-CTA + tXgX_i = thr_copy.partition_S(gX_i) + tXgRes_i = ( + thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None + ) + tXgO_i = thr_copy.partition_D(gO_i) + tXgResO_i = ( + thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None + ) + # tXgRstd_i / tXcX_i / row_i prepared above + is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n) + tXpX_i = ( + qutils.predicate_k(thr_copy.partition_S(cX_i), limit=limit_k) + if not is_even_N_i + else None + ) + + tXrX = cute.make_fragment_like(tXgX_i) + tXrRes = ( + cute.make_fragment_like(tXgRes_i) if const_expr(mRes is not None) else None + ) + if const_expr(self.direct_gmem): + if const_expr(not is_even_N_i): + tXrX.fill(0) + if const_expr(tXrRes is not None): + tXrRes.fill(0) + if row_i < shape[0]: + cute.copy(copy_atom, tXgX_i, tXrX, pred=tXpX_i) + if const_expr(tXrRes is not None): + cute.copy(copy_atom, tXgRes_i, tXrRes, pred=tXpX_i) + else: + # If N is not a multiple of the tile width, the predicated gmem->smem + # copies leave out-of-bounds lanes uninitialized. Clear the SMEM tile + # so masked lanes read as 0 for reduction/output. + if const_expr(not is_even_N_i): + thr_copy.partition_D(sX0).fill(0) + if const_expr(mRes is not None): + thr_copy.partition_D(sRes0).fill(0) + + if row_i < shape[0]: + cute.copy(copy_atom, tXgX_i, thr_copy.partition_D(sX0), pred=tXpX_i) + if const_expr(mRes is not None): + cute.copy( + copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i + ) + if const_expr(use_async): + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + cute.autovec_copy(thr_copy.partition_D(sX0), tXrX) + if const_expr(tXrRes is not None): + cute.autovec_copy(thr_copy.partition_D(sRes0), tXrRes) + x_red = tXrX.load().to(cute.Float32) + if const_expr(tXrRes is not None): + x_red += tXrRes.load().to(cute.Float32) + + if const_expr(mResO is not None): + tXrResO = cute.make_fragment_like(tXgResO_i) + tXrResO.store(x_red.to(tXrResO.element_type)) + if row_i < shape[0]: + cute.copy( + get_copy_atom_bw( + tXrResO.element_type, num_copy_elems_X, is_async=False + ), + tXrResO, + tXgResO_i, + pred=tXpX_i, + ) + + sum_sq = row_reduce( + x_red * x_red, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + init_val=0.0, + ) + rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True) + + if const_expr(mRstd is not None): + if ( + tXcX_i[0][1] == 0 + and row_i < shape[0] + and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + tXgRstd_i[0] = rstd + + if const_expr(not self.direct_gmem and (mRes is not None or mB is not None)): + # Load weight/bias after the reduction so they don't inflate register + # pressure during the long-scoreboard reduction phase (helping occupancy + # when registers are the limiting factor). 
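+            # Counterpart of `prefetch_w_early`: this block only runs for the
+            # staged-SMEM schedule when a residual or bias is present, so the
+            # W/B fragments are not live across the reduction above.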
+ if const_expr(mW is not None): + gW = cute.local_tile( + qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0) + ) + tXgW = thr_copy.partition_S(gW) + tXrW = cute.make_fragment_like(tXgW) + if const_expr(not is_even_N_wb): + tXrW.fill(0) + cute.copy( + get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False), + tXgW, + tXrW, + pred=tXp_wb, + ) + if const_expr(mB is not None): + gB = cute.local_tile( + qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0) + ) + tXgB = thr_copy.partition_S(gB) + tXrB = cute.make_fragment_like(tXgB) + if const_expr(not is_even_N_wb): + tXrB.fill(0) + cute.copy( + get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False), + tXgB, + tXrB, + pred=tXp_wb, + ) + + # Reuse `x_red` (x + residual, in fp32) for the output path so we don't + # keep both `tXrX` and `tXrRes` fragments live across the reduction. + y = x_red * rstd + if const_expr(mW is not None): + y = y * tXrW.load().to(cute.Float32) + if const_expr(mB is not None): + y = y + tXrB.load().to(cute.Float32) + + tXrO = cute.make_fragment_like(tXgO_i) + tXrO.store(y.to(tXrO.element_type)) + if row_i < shape[0]: + cute.copy( + get_copy_atom_bw(tXrO.element_type, num_copy_elems_X, is_async=False), + tXrO, + tXgO_i, + pred=tXpX_i, + ) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + cluster_n: cutlass.Constexpr[int], + num_warps: cutlass.Constexpr[int], + warps_per_row: cutlass.Constexpr[int], + threads_per_row: cutlass.Constexpr[int], + ): + self._kernel_impl( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + cluster_n, + num_warps, + warps_per_row, + threads_per_row, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + ): + copy_bits = int(self.copy_bits) + tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits) + num_threads = self._num_threads() + num_warps = num_threads // cute.arch.WARP_SIZE + threads_per_row = self._threads_per_row() + warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1) + cluster_n = self._cluster_n() + self._kernel_impl( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + const_expr(cluster_n), + const_expr(num_warps), + const_expr(warps_per_row), + const_expr(threads_per_row), + ) + + @cute.jit + def _init_cluster(self, tidx: cutlass.Int32, mbar_ptr: Optional[cute.Pointer]): + if const_expr(mbar_ptr is not None): + if tidx < self.stage: + cute.arch.mbarrier_init(mbar_ptr + tidx, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + + +def _can_use_ptr_path( + x: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + residual: Optional[Tensor], +) -> bool: + """Fast path precondition for the pointer-based CuTeDSL entry. + + We require a row-major 2D layout where the last dimension is + contiguous (stride(1) == 1). The leading dimension (stride(0)) + may be larger than N (padded-row / packed-attention layouts), + and is passed to the kernel as `ld`. 
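+
+    Note (illustrative): contiguous [M, N] tensors trivially qualify
+    (stride(0) == N), and padded layouts with stride(0) > N are accepted as
+    long as the divisibility and alignment checks below pass.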
+ """ + if x.stride(1) != 1: + return False + # All participating tensors are interpreted as the same element type + # (derived from x.dtype) in the pointer-based path. If dtypes differ, + # we'd read the wrong bit patterns and silently produce incorrect output. + if residual is not None and residual.dtype != x.dtype: + return False + if weight is not None and weight.dtype != x.dtype: + # Allow the common "Quack-style" API where weights are fp32 even when + # activations are bf16/fp16. The pointer path constructs a weight tensor + # view with the correct element type (fp32) inside the compiled graph. + if weight.dtype is not torch.float32: + return False + if x.dtype not in (torch.float16, torch.bfloat16): + return False + if bias is not None and bias.dtype != x.dtype: + return False + # The kernel assumes `ld` satisfies a divisibility constraint used by + # cute.assume(..., divby=...) for vectorization. + elem_bits = TORCH2CUTE_DTYPE[x.dtype].width + divby = 256 // elem_bits + if (x.stride(0) % divby) != 0: + return False + # The kernel uses 128-bit vectorized copies (16B). Require at least 16B + # alignment on all participating tensors to avoid misaligned global loads. + if (x.data_ptr() % 16) != 0: + return False + if residual is not None and residual.stride(1) != 1: + return False + if residual is not None and residual.stride(0) != x.stride(0): + return False + if residual is not None and (residual.data_ptr() % 16) != 0: + return False + if weight is not None and not weight.is_contiguous(): + return False + if bias is not None and not bias.is_contiguous(): + return False + if weight is not None: + # For fp32 weights we use 256b universal copies (32B) by default. + # Require 32B alignment so the compiler can safely vectorize loads. + if weight.dtype is torch.float32: + if (weight.data_ptr() % 32) != 0: + return False + else: + if (weight.data_ptr() % 16) != 0: + return False + if bias is not None and (bias.data_ptr() % 16) != 0: + return False + return True + + +def _can_use_ptr_path_fused_add_inplace( + x: Tensor, + weight: Tensor, + residual: Tensor, +) -> bool: + """Fast-path precondition for fused_add_rmsnorm_forward_inplace. + + We allow the common vLLM layout where: + - `x` is strided/padded row-major (stride(1) == 1, stride(0) >= N) + - `residual` is contiguous row-major (stride(0) == N) + """ + if x.stride(1) != 1: + return False + if residual.dtype != x.dtype: + return False + if weight.dtype != x.dtype: + return False + if residual.stride(1) != 1: + return False + if not residual.is_contiguous(): + return False + if not weight.is_contiguous(): + return False + + dtype = TORCH2CUTE_DTYPE[x.dtype] + divby = 256 // dtype.width + if (x.stride(0) % divby) != 0: + return False + if (residual.stride(0) % divby) != 0: + return False + + if (x.data_ptr() % 16) != 0: + return False + if (residual.data_ptr() % 16) != 0: + return False + if (weight.data_ptr() % 16) != 0: + return False + return True + + +def _can_use_ptr_path_bwd( + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, +) -> bool: + """Fast-path precondition for the pointer-based RMSNorm backward entry. 
+ + This path is only used for the common Quack-style signature: + - no bias gradient + - no residual / dresidual_out + - weight is either the same dtype as x, or fp32 for bf16/fp16 activations + """ + if x.dim() != 2 or dout.dim() != 2: + return False + if rstd.dim() != 1: + return False + if x.shape != dout.shape: + return False + if rstd.numel() != x.shape[0]: + return False + # SM100 backward kernel assumes N is divisible by 8 (for 256b fp32 stores + # into dw_partial rows). + if (x.shape[1] % 8) != 0: + return False + if x.stride(1) != 1 or dout.stride(1) != 1: + return False + if dout.stride(0) != x.stride(0): + return False + if dout.dtype != x.dtype: + return False + if rstd.dtype != torch.float32 or not rstd.is_contiguous(): + return False + if weight is None: + return False + if weight.dim() != 1 or weight.shape[0] != x.shape[1]: + return False + if not weight.is_contiguous(): + return False + if weight.dtype != x.dtype: + if weight.dtype is not torch.float32: + return False + if x.dtype not in (torch.float16, torch.bfloat16): + return False + + dtype = TORCH2CUTE_DTYPE[x.dtype] + divby = 256 // dtype.width + if (x.stride(0) % divby) != 0: + return False + + if (x.data_ptr() % 16) != 0: + return False + if (dout.data_ptr() % 16) != 0: + return False + # Torch CUDA allocations are typically >=256B aligned, but keep the check + # explicit so we never assume tighter alignment than is true. + if (rstd.data_ptr() % 4) != 0: + return False + if (weight.data_ptr() % (32 if weight.dtype is torch.float32 else 16)) != 0: + return False + return True + + +def _rmsnorm_forward_ptr( + x: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + residual: Optional[Tensor], + eps: float, + store_rstd: bool, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Pointer-based RMSNorm forward that bypasses DLPack entirely. + + This path reconstructs cute.Tensor views from raw device pointers + and explicit layouts inside the JIT graph, avoiding any runtime + DLPack conversions while reusing the tuned RMSNormSM100 schedule. + """ + assert x.is_cuda + assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand." + M, N = x.shape + + # Preserve the input's 2D stride so downstream users that rely on + # padded-row layouts (stride0 > N) continue to see the expected layout. + out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + residual_out: Optional[Tensor] = None + rstd: Optional[Tensor] = None + + if residual is not None: + residual_out = torch.empty_strided( + residual.shape, + residual.stride(), + device=residual.device, + dtype=residual.dtype, + ) + if store_rstd: + rstd = torch.empty(M, device=x.device, dtype=torch.float32) + + _rmsnorm_forward_ptr_into( + x=x, + weight=weight, + bias=bias, + residual=residual, + out=out, + residual_out=residual_out, + rstd=rstd, + eps=eps, + ) + return out, rstd, residual_out + + +def _rmsnorm_forward_ptr_into( + x: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + residual: Optional[Tensor], + out: Tensor, + residual_out: Optional[Tensor], + rstd: Optional[Tensor], + eps: float, +) -> None: + """Internal helper that launches the pointer-based kernel into preallocated outputs. + + This enables integration into frameworks like vLLM that manage their + own buffers and prefer in-place or out-parameter semantics. + """ + assert x.is_cuda + assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand." 
+ M, N = x.shape + device_index = x.get_device() + dtype = TORCH2CUTE_DTYPE[x.dtype] + + if bias is None and residual is None and residual_out is None and rstd is None: + # Fast-launch path: cache packed args and update pointers/scalars in-place to + # avoid Python-side argument marshalling overhead that dominates small-batch cases. + # + # If fast-launch is disabled (or CuTeDSL internals changed), we fall back + # to calling the compiled function directly. + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + has_weight = weight is not None + + weight_dtype = TORCH2CUTE_DTYPE[weight.dtype] if has_weight else None + + # Schedule selection (pointer fast path). + # + # Goals: + # - Keep vLLM inference fast path (contiguous/padded row-major) fast. + # - Enable higher vector widths when all participating pointers are 32B-aligned. + # - Prefer direct-GMEM for SM100-friendly hidden sizes to reduce SMEM/barrier + # overhead, especially for small/medium-M cases. + direct_gmem = _direct_gmem_from_policy( + default=bool(dtype.width == 16 and N in {4096, 6144, 7168, 8192}) + ) + use_async = not direct_gmem + + can_use_256 = bool( + dtype.width == 16 + and (x.data_ptr() % 32) == 0 + and (out.data_ptr() % 32) == 0 + and (not has_weight or (weight.data_ptr() % 32) == 0) # type: ignore[union-attr] + ) + default_copy_bits = 256 if can_use_256 else 128 + # Quack-style fp32-weight policy: cap the *widest* dtype to 128b, so when + # weights are fp32 we use 64b activation vectors (helps register pressure). + if dtype.width == 16 and weight_dtype is not None and weight_dtype.width == 32: + default_copy_bits = 64 + copy_bits = _copy_bits_from_policy( + default=default_copy_bits, can_use_256=can_use_256 + ) + assumed_align = 32 if copy_bits >= 256 else 16 + + stage = 1 + if ( + _ENABLE_STAGE2 + and dtype.width == 16 + and N == 7168 + and (not direct_gmem) + and M >= 4096 + ): + stage = 2 + + compiled_key = ( + "ptr", + N, + dtype, + weight_dtype, + False, # residual + has_weight, + False, # bias + False, # residual_out + False, # rstd + stage, + int(copy_bits), + bool(use_async), + bool(direct_gmem), + int(assumed_align), + device_index, + ) + compiled = _PTR_COMPILE_CACHE.get(compiled_key) + if compiled is None: + op = RMSNormSM100( + N, + dtype, + stage=stage, + copy_bits=int(copy_bits), + use_async=bool(use_async), + direct_gmem=bool(direct_gmem), + ) + ld_val = int(x.stride(0)) + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + ptr_out = rt.make_ptr( + dtype, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + ptr_w = ( + rt.make_ptr( + weight_dtype or dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if has_weight + else None + ) + stream = cuda.CUstream(stream_handle) + ld = Int32(ld_val) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_w, + None, # ptr_b + None, # ptr_res + ptr_out, + None, # ptr_res_out + None, # ptr_rstd + Int32(M), + Int32(N), + ld, + stream, + Float32(eps), + ) + _PTR_COMPILE_CACHE[compiled_key] = compiled + + launcher = _get_fast_ptr_rmsnorm_launcher( + compiled=compiled, + dtype=dtype, + weight_dtype=weight_dtype, + N=N, + device_index=device_index, + stream_handle=stream_handle, + has_weight=has_weight, + assumed_align=assumed_align, + eps=eps, + ) + ld_val = int(x.stride(0)) + if launcher is not None: + 
launcher.launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld_val, eps=eps) + return + + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + ptr_out = rt.make_ptr( + dtype, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + ptr_w = ( + rt.make_ptr( + weight_dtype or dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if has_weight + else None + ) + stream = cuda.CUstream(stream_handle) + ld = Int32(ld_val) + compiled( + ptr_x, + ptr_w, + None, # ptr_b + None, # ptr_res + ptr_out, + None, # ptr_res_out + None, # ptr_rstd + Int32(M), + Int32(N), + ld, + stream, + Float32(eps), + ) + return + + # General path (supports bias/residual/rstd, but is slower to launch). + # + # Keep the same schedule-selection policy as the fast path so correctness-only + # features (bias/residual/rstd) don't accidentally fall off a performance cliff. + weight_dtype = TORCH2CUTE_DTYPE[weight.dtype] if weight is not None else None + direct_gmem = _direct_gmem_from_policy( + default=bool(dtype.width == 16 and N in {4096, 6144, 7168, 8192}) + ) + use_async = not direct_gmem + can_use_256 = bool( + dtype.width == 16 + and (x.data_ptr() % 32) == 0 + and (out.data_ptr() % 32) == 0 + and (weight is None or (weight.data_ptr() % 32) == 0) + and (bias is None or (bias.data_ptr() % 32) == 0) + and (residual is None or (residual.data_ptr() % 32) == 0) + and (residual_out is None or (residual_out.data_ptr() % 32) == 0) + ) + default_copy_bits = 256 if can_use_256 else 128 + if dtype.width == 16 and weight_dtype is not None and weight_dtype.width == 32: + default_copy_bits = 64 + copy_bits = _copy_bits_from_policy( + default=default_copy_bits, can_use_256=can_use_256 + ) + assumed_align = 32 if copy_bits >= 256 else 16 + + stage = 1 + if ( + _ENABLE_STAGE2 + and dtype.width == 16 + and N == 7168 + and (not direct_gmem) + and M >= 4096 + ): + stage = 2 + + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + key = ( + "ptr", + N, + dtype, + weight_dtype, + residual is not None, + weight is not None, + bias is not None, + residual_out is not None, + rstd is not None, + stage, + int(copy_bits), + bool(use_async), + bool(direct_gmem), + int(assumed_align), + device_index, + ) + compiled = _PTR_COMPILE_CACHE.get(key) + if compiled is None: + op = RMSNormSM100( + N, + dtype, + stage=stage, + copy_bits=int(copy_bits), + use_async=bool(use_async), + direct_gmem=bool(direct_gmem), + ) + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + ptr_out = rt.make_ptr( + dtype, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + ptr_res = ( + rt.make_ptr( + dtype, + residual.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if residual is not None + else None + ) + ptr_res_out = ( + rt.make_ptr( + dtype, + residual_out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if residual_out is not None + else None + ) + ptr_w = ( + rt.make_ptr( + weight_dtype or dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if weight is not None + else None + ) + ptr_b = ( + rt.make_ptr( + dtype, + bias.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if bias is not None + else 
None + ) + ptr_rstd = ( + rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if rstd is not None + else None + ) + stream = cuda.CUstream(stream_handle) + ld = Int32(int(x.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_w, + ptr_b, + ptr_res, + ptr_out, + ptr_res_out, + ptr_rstd, + Int32(M), + Int32(N), + ld, + stream, + Float32(eps), + ) + _PTR_COMPILE_CACHE[key] = compiled + ptr_x = rt.make_ptr( + dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_out = rt.make_ptr( + dtype, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + ptr_res = ( + rt.make_ptr( + dtype, + residual.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if residual is not None + else None + ) + ptr_res_out = ( + rt.make_ptr( + dtype, + residual_out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if residual_out is not None + else None + ) + ptr_w = ( + rt.make_ptr( + weight_dtype or dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if weight is not None + else None + ) + ptr_b = ( + rt.make_ptr( + dtype, + bias.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + if bias is not None + else None + ) + ptr_rstd = ( + rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if rstd is not None + else None + ) + stream = cuda.CUstream(stream_handle) + ld = Int32(int(x.stride(0))) + compiled( + ptr_x, + ptr_w, + ptr_b, + ptr_res, + ptr_out, + ptr_res_out, + ptr_rstd, + Int32(M), + Int32(N), + ld, + stream, + Float32(eps), + ) + + +def _fused_add_rmsnorm_forward_ptr_inplace( + x: Tensor, + residual: Tensor, + weight: Tensor, + eps: float, +) -> None: + """Pointer-based fused_add_rmsnorm that updates `x` and `residual` in-place.""" + assert x.is_cuda + assert x.dim() == 2 + assert residual.is_cuda + assert residual.dim() == 2 + assert x.shape == residual.shape + + M, N = x.shape + device_index = x.get_device() + dtype = TORCH2CUTE_DTYPE[x.dtype] + stage = 1 + + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + + # Latency-optimized schedule for small-M cases: avoid the gmem->smem + # staging path (large dynamic smem + extra barriers) and load directly + # from gmem into registers. + copy_bits = 128 + # Use a direct-GMEM schedule (no staging SMEM tiles) for DSv3 hidden size + # (7168, bf16/fp16). This improves both: + # - small-M latency (fewer barriers + less dynamic shared memory), and + # - large-M bandwidth (lower overhead, better vectorization when 32B-aligned). + # + # This is a policy decision: it is tuned for DSv3's N=7168. 
If you want to + # benchmark other models/shapes, you can override it with: + # - OINK_RMSNORM_DIRECT_GMEM=0 (force staging/cp.async path) + # - OINK_RMSNORM_DIRECT_GMEM=1 (force direct-gmem path) + # Default direct-GMEM policy: + # - small/medium M: direct-GMEM reduces staging/barrier overhead + # - large M: staged cp.async tends to win on sustained bandwidth + direct_gmem = _direct_gmem_from_policy( + default=bool(dtype.width == 16 and N == 7168 and M <= 16384) + ) + use_async = not direct_gmem + tpr_override: Optional[int] = None + nt_override: Optional[int] = None + cluster_n_override: Optional[int] = None + direct_gmem_max_copy_bits: Optional[int] = None + + # Experimental stage-2 cp.async path (2-tile ping-pong) for N=7168. This is + # primarily about improving memory-latency hiding / reducing long-scoreboard + # stalls for large-M workloads. + if _ENABLE_STAGE2 and dtype.width == 16 and N == 7168 and M >= 4096: + stage = 2 + direct_gmem = False + use_async = True + + # Experimental ILP variant (clusters): split each row across 2 CTAs. + # + # NOTE: This is currently opt-in because some CuTeDSL builds exhibit + # instability with cluster launches for this specific schedule. To reduce + # the chance of accidental crashes, we require an additional explicit + # opt-in via `OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1`. + if _ENABLE_CLUSTER_ILP and not _ENABLE_STAGE2: + if dtype.width == 16 and N == 7168 and M >= 4096: + cluster_n_override = 2 + if direct_gmem: + # Cluster launches + direct-GMEM has exhibited reproducible compiler + # instability (segfaults) in some CuTeDSL builds, especially for the + # 256b vector path. Probe it out-of-process once so we can safely + # select a working copy width (or fall back to the staged SMEM path) + # instead of crashing the parent process. + max_bits = _probe_cluster_direct_gmem_max_copy_bits() + if max_bits == 0: + direct_gmem = False + use_async = True + else: + direct_gmem_max_copy_bits = max_bits + + # Experimental per-row partitioning: use 256 threads/row for N=7168 to + # increase concurrency/ILP (accepts a small tail-predicate region). 
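+    # For example (roughly, on the 128-bit bf16 vector path): N=7168 is 896
+    # 8-element vectors per row; 256 threads/row rounds that up to 4 iterations
+    # covering 1024 vector slots, so the last 128 slots of each row are simply
+    # predicated off instead of specializing the tail.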
+ if _ENABLE_TPR256 and cluster_n_override is None and not _ENABLE_STAGE2: + if dtype.width == 16 and N == 7168 and M >= 4096: + tpr_override = 256 + nt_override = 256 + + can_use_256 = bool( + direct_gmem + and (direct_gmem_max_copy_bits is None or direct_gmem_max_copy_bits >= 256) + and dtype.width == 16 + and (x.data_ptr() % 32) == 0 + and (residual.data_ptr() % 32) == 0 + and (weight.data_ptr() % 32) == 0 + ) + assumed_align = 32 if can_use_256 else 16 + if can_use_256: + copy_bits = 256 + + copy_bits = _copy_bits_from_policy(default=copy_bits, can_use_256=can_use_256) + if copy_bits == 128: + assumed_align = 16 + elif copy_bits == 256 and can_use_256: + assumed_align = 32 + else: + copy_bits = 128 + assumed_align = 16 + + key = ( + "ptr_fused_add_inplace", + N, + dtype, + stage, + device_index, + copy_bits, + use_async, + tpr_override, + nt_override, + direct_gmem, + cluster_n_override, + ) + compiled = _PTR_COMPILE_CACHE.get(key) + if compiled is None: + op = RMSNormSM100( + N, + dtype, + stage=stage, + copy_bits=copy_bits, + use_async=use_async, + direct_gmem=direct_gmem, + ) + if tpr_override is not None: + op._tpr_override = tpr_override # type: ignore[attr-defined] + if nt_override is not None: + op._nt_override = nt_override # type: ignore[attr-defined] + if cluster_n_override is not None: + op._cluster_n_override = cluster_n_override # type: ignore[attr-defined] + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + ptr_res = rt.make_ptr( + dtype, + residual.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + ptr_w = rt.make_ptr( + dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + stream = cuda.CUstream(stream_handle) + ld_x = Int32(int(x.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs_fused_add_inplace, + ptr_x, + ptr_w, + ptr_res, + Int32(M), + Int32(N), + ld_x, + stream, + Float32(eps), + ) + _PTR_COMPILE_CACHE[key] = compiled + launcher = _get_fast_ptr_fused_add_rmsnorm_launcher( + compiled=compiled, + dtype=dtype, + N=N, + device_index=device_index, + stream_handle=stream_handle, + copy_bits=copy_bits, + use_async=use_async, + tpr=tpr_override or 0, + direct_gmem=direct_gmem, + assumed_align=assumed_align, + eps=eps, + ) + if launcher is not None: + launcher.launch( + x=x, + weight=weight, + residual=residual, + M=M, + N=N, + ld_x=int(x.stride(0)), + eps=eps, + ) + return + + # Fast-launch is disabled/unavailable (or CuTeDSL internals changed). Fall back + # to calling the compiled function directly. 
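+    # (The direct call below still bypasses DLPack; it only pays the per-call
+    # Python argument-marshalling cost that the cached fast launcher avoids.)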
+ ptr_x = rt.make_ptr( + dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_res = rt.make_ptr( + dtype, + residual.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + ptr_w = rt.make_ptr( + dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) + stream = cuda.CUstream(stream_handle) + ld_x = Int32(int(x.stride(0))) + compiled(ptr_x, ptr_w, ptr_res, Int32(M), Int32(N), ld_x, stream, Float32(eps)) + + +# ------------------------- +# Public API (forward + verify) +# ------------------------- + + +def rmsnorm_forward( + x: Tensor, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + eps: float = 1e-6, + store_rstd: bool = False, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + assert x.is_cuda + assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand." + M, N = x.shape + + # Fast path: use the pointer-based entry whenever we can represent the + # inputs as a row-major [M, N] view with stride(1) == 1 and dtype contracts + # are satisfied (vLLM uses this in inference). + # + # When the pointer path can't be used (e.g. float32 weights for Quack-style + # APIs, or non-standard layouts), fall back to the CuTeDSL stage-2 module + # (ported from `/tmp/oink_main/Blackwell`) before using the slow torch + # reference implementation. + force_stage2 = _FORCE_RMSNORM_STAGE2_FWD + + use_ptr = (not force_stage2) and _can_use_ptr_path(x, weight, bias, residual) + + if use_ptr: + return _rmsnorm_forward_ptr(x, weight, bias, residual, eps, store_rstd) + + # CuTeDSL fallback for cases that aren't safe for the pointer path. + # Import lazily to keep vLLM plugin startup and common inference fast paths + # lightweight. + try: + import importlib + + rms2 = importlib.import_module( + ".rmsnorm_with_stage2", + package=__package__ or "kernelagent_oink.blackwell", + ) + except Exception: + rms2 = None # type: ignore[assignment] + if rms2 is not None: + y, rstd, residual_out = rms2.rmsnorm_forward_with_stage2( + x, + weight=weight, + bias=bias, + residual=residual, + eps=eps, + store_rstd=store_rstd, + ) + # Preserve stride contracts for torch.compile consistency, even + # when using the optional stage-2 implementation. + if y.stride() != x.stride(): + y_strided = torch.empty_strided( + x.shape, x.stride(), device=x.device, dtype=x.dtype + ) + y_strided.copy_(y) + y = y_strided + if residual is not None and residual_out is not None: + if residual_out.stride() != residual.stride(): + residual_out_strided = torch.empty_strided( + residual.shape, + residual.stride(), + device=residual.device, + dtype=residual.dtype, + ) + residual_out_strided.copy_(residual_out) + residual_out = residual_out_strided + return y, rstd, residual_out + + # Safe fallback (correctness-first). This is expected to be rare in vLLM. + y = rmsnorm_ref(x, weight, bias, residual, eps) + # Preserve the input stride contract even on the fallback path so + # torch.compile sees a consistent output layout across all branches. 
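+    # (For example, vLLM may hand us a padded row-major view with x.stride(0) > N;
+    # the fallback output must carry the same strides as the fast paths do.)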
+ if y.stride() != x.stride(): + y_strided = torch.empty_strided( + x.shape, x.stride(), device=x.device, dtype=x.dtype + ) + y_strided.copy_(y) + y = y_strided + rstd = None + if store_rstd: + xf = x.float() + if residual is not None: + xf = xf + residual.float() + rstd = torch.rsqrt(xf.square().mean(dim=-1) + eps).to(torch.float32) + residual_out = None + if residual is not None: + residual_out = (x.float() + residual.float()).to(x.dtype) + if residual_out.stride() != residual.stride(): + residual_out_strided = torch.empty_strided( + residual.shape, + residual.stride(), + device=residual.device, + dtype=residual.dtype, + ) + residual_out_strided.copy_(residual_out) + residual_out = residual_out_strided + return y, rstd, residual_out + + +def rmsnorm_ref( + x: Tensor, + w: Optional[Tensor] = None, + b: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + eps: float = 1e-6, +) -> Tensor: + xf = x.float() + if residual is not None: + xf = xf + residual.float() + rstd = torch.rsqrt(xf.square().mean(dim=-1, keepdim=True) + eps) + y = xf * rstd + if w is not None: + y = y * w.float() + if b is not None: + y = y + b.float() + return y.to(x.dtype) + + +def fused_add_rmsnorm_forward( + x: Tensor, + residual: Tensor, + weight: Tensor, + eps: float = 1e-6, +) -> Tuple[Tensor, Tensor]: + """Fused residual-add + RMSNorm for SM100 in CuteDSL. + + This is a convenience wrapper around ``rmsnorm_forward`` that matches the + semantics of vLLM's ``fused_add_rms_norm``: + + z = x + residual + y = RMSNorm(z, weight, eps) + + It returns ``(y, z)`` where ``z`` has the same dtype/shape as the inputs. + """ + assert x.is_cuda and residual.is_cuda + assert x.shape == residual.shape + assert x.dtype == residual.dtype + + orig_shape = x.shape + N = orig_shape[-1] + + x_2d = x.view(-1, N) + res_2d = residual.view(-1, N) + + y_2d, _rstd, z_2d = rmsnorm_forward( + x_2d, + weight=weight, + bias=None, + residual=res_2d, + eps=eps, + store_rstd=False, + ) + + y = y_2d.view(orig_shape) + z = z_2d.view(orig_shape) + return y, z + + +def fused_add_rmsnorm_forward_inplace( + x: Tensor, + residual: Tensor, + weight: Tensor, + eps: float = 1e-6, +) -> Tuple[Tensor, Tensor]: + """In-place fused residual-add + RMSNorm matching vLLM semantics. + + This variant writes: + + z = x + residual (stored into ``residual``) + y = RMSNorm(z, w) (stored into ``x``) + + i.e., it uses ``x`` as the normalized output buffer and ``residual`` as + the residual-out buffer, mirroring vLLM's fused_add_rms_norm kernel. + """ + fused_add_rmsnorm_inplace_(x, residual, weight, eps=eps) + return x, residual + + +def fused_add_rmsnorm_inplace_( + x: Tensor, + residual: Tensor, + weight: Tensor, + eps: float = 1e-6, +) -> None: + """In-place fused residual-add + RMSNorm matching vLLM semantics. + + This is the lowest-overhead Python entrypoint (returns `None`) intended + for performance-critical call sites like `torch.ops.oink.fused_add_rms_norm`. + """ + assert x.is_cuda and residual.is_cuda + assert x.shape == residual.shape + assert x.dtype == residual.dtype + + N = x.shape[-1] + x_2d = x if x.dim() == 2 else x.view(-1, N) + res_2d = residual if residual.dim() == 2 else residual.view(-1, N) + + # Fast path: vLLM-compatible layout where x may be strided/padded but + # residual is contiguous. This updates both tensors in-place without + # additional allocations. 
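+    # A minimal usage sketch of this in-place entrypoint (shapes and dtype are
+    # illustrative only):
+    #
+    #   x = torch.randn(1024, 7168, device="cuda", dtype=torch.bfloat16)
+    #   res = torch.randn_like(x)
+    #   w = torch.randn(7168, device="cuda", dtype=torch.bfloat16)
+    #   fused_add_rmsnorm_inplace_(x, res, w, eps=1e-6)
+    #   # afterwards: res holds x_old + res_old, and x holds RMSNorm(res, w)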
+ if _can_use_ptr_path_fused_add_inplace(x_2d, weight, res_2d): + _fused_add_rmsnorm_forward_ptr_inplace(x_2d, res_2d, weight, eps) + return None + + # Fallback: allocate via the regular fused path, then copy results into + # the user-provided buffers so that semantics remain identical. + y, z = fused_add_rmsnorm_forward(x, residual, weight, eps) + x.copy_(y) + residual.copy_(z) + return None + + +# ------------------------- +# Backward kernel (SM100) +# ------------------------- + + +class RMSNormBackwardSM100(BaseRMSNormBackward): + """SM100-tuned RMSNorm backward. + + This is a thin wrapper around the generic `lite_quack.RMSNormBackward` + base implementation, with SM100-friendly tiling heuristics that mirror + the forward policy used by Oink. + """ + + def __init__(self, dtype: cutlass.Numeric, N: int): + super().__init__(dtype, N) + + def _get_num_threads(self) -> int: + nt = getattr(self, "_nt_override", None) + if nt is not None: + return int(nt) + return super()._get_num_threads() + + def _calculate_threads_per_row(self) -> int: + tpr = getattr(self, "_tpr_override", None) + if tpr is not None: + return int(tpr) + return super()._calculate_threads_per_row() + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_w: cute.Pointer, + ptr_dout: cute.Pointer, + ptr_rstd: cute.Pointer, + ptr_dx: cute.Pointer, + ptr_dw_partial: cute.Pointer, + M: Int32, + N_dyn: Int32, + ld: Int32, + sm_count: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions. + + This is the performance-critical path used by the benchmark harness + (and any future training integrations) for the common case: + - weight gradient enabled (dw_partial is provided) + - no bias/residual gradients + """ + # Weight-grad stores use vectorized float32 copies. For the SM100 + # schedule we want to allow up to 256b (8x f32) stores, which requires + # the leading dimension to be divisible by 8 to prove 32B alignment for + # every row in `dw_partial`. + N_assumed = cute.assume(N_dyn, divby=8) + + layout_mn = cute.make_layout((M, N_assumed), stride=(ld, 1)) + layout_n = cute.make_layout((N_assumed,), stride=(1,)) + layout_m = cute.make_layout((M,), stride=(1,)) + # Default: write a full (sm_count, N) partial buffer (Quack-style), + # then reduce on the host with `torch.sum(dim=0)`. + # + # Optional: atomic-reduce directly into a single (N,) buffer by using + # a broadcasted leading dimension (stride0 = 0). This avoids the extra + # reduction kernel launch and is primarily used for tiny-M regimes. + if const_expr(self.atomic_dw): + layout_partial = cute.make_layout((sm_count, N_assumed), stride=(0, 1)) + else: + layout_partial = cute.make_layout( + (sm_count, N_assumed), stride=(N_assumed, 1) + ) + + mX = cute.make_tensor(ptr_x, layout_mn) + mW = cute.make_tensor(ptr_w, layout_n) + mdO = cute.make_tensor(ptr_dout, layout_mn) + mRstd = cute.make_tensor(ptr_rstd, layout_m) + mdX = cute.make_tensor(ptr_dx, layout_mn) + mdW = cute.make_tensor(ptr_dw_partial, layout_partial) + + self.__call__( + mX, + mW, + mdO, + None, # dresidual_out + mRstd, + mdX, + mdW, + None, # dresidual + None, # db_partial + sm_count, + stream, + ) + + def _get_num_threads(self) -> int: + # Keep 128 threads only up to N=4k; use 256 for larger rows to ensure + # threads_per_row <= num_threads across buckets. 
+ try: + return self._nt_override # type: ignore[attr-defined] + except Exception: + return 128 if self.N <= 4096 else 256 + + def _calculate_threads_per_row(self) -> int: + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: + pass + # Match Quack's backward tiling: use 256 threads/row for N > 4096. + # + # The earlier "mirror forward" policy (128 threads/row for N<=8192) + # regresses DSv3 backward at N=6144/7168/8192 on SM100. + N = self.N + for limit, threads in [(64, 8), (128, 16), (256, 32), (512, 64), (4096, 128)]: + if N <= limit: + return threads + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: + return 256 + + def _set_cluster_n(self) -> None: + # Reuse the SM100 forward cluster growth policy so large-N shapes can + # fan out across multiple CTAs in the same row. + try: + self.cluster_n = self._cluster_n_override # type: ignore[attr-defined] + return + except Exception: + pass + + N = self.N + if N <= 8192: + cluster_n = 1 + elif self.dtype.width == 16: + if N <= 16 * 1024: + cluster_n = 2 + elif N <= 32 * 1024: + cluster_n = 2 + elif N <= 64 * 1024: + cluster_n = 4 + elif N <= 128 * 1024: + cluster_n = 8 + else: + cluster_n = 16 + else: + if N <= 32 * 1024: + cluster_n = 1 + elif N <= 64 * 1024: + cluster_n = 2 + elif N <= 128 * 1024: + cluster_n = 4 + elif N <= 256 * 1024: + cluster_n = 8 + else: + cluster_n = 16 + self.cluster_n = cluster_n + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + sm_count: Int32, + stream: cuda.CUstream, + ): + # Match forward's 32B alignment on the leading dimension to unlock + # wider vectorization when legal. 
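+        # (divby = 256 // element width: 16 elements for bf16/fp16, 8 for fp32,
+        # i.e. a 32B-divisible leading dimension in both cases.)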
+ semistatic_shape = (*mX.shape[:-1], self.N) + + def new_stride(t): + return ( + cute.assume(t.stride[0], divby=256 // t.element_type.width), + t.stride[1], + ) + + mX, mdO, mdResO, mdX, mdRes = [ + cute.make_tensor( + t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)) + ) + if const_expr(t is not None) + else None + for t in (mX, mdO, mdResO, mdX, mdRes) + ] + + self._set_cluster_n() + largest_dtype_width = const_expr( + max( + mX.element_type.width, + mdO.element_type.width, + mdX.element_type.width, + mdResO.element_type.width if mdResO is not None else 0, + mdRes.element_type.width if mdRes is not None else 0, + ) + ) + tiler_mn, tv_layout = self._get_tv_layout( + num_copy_bits=128 // largest_dtype_width * mX.element_type.width + ) + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + if const_expr(mW is not None): + mW_expanded_layout = cute.prepend( + mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)) + ) + mW = cute.make_tensor(mW.iterator, mW_expanded_layout) + + num_blocks = sm_count + kernel = ( + self.kernel( + mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes) + ) + kernel.launch( + grid=[num_blocks, self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None, + smem=self._smem_size_in_bytes( + tiler_mn, num_warps, do_dtype=mdO.element_type + ), + stream=stream, + ) + + +_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} +_BWD_PTR_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +def _rmsnorm_bwd_sm100( + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dx: Tensor, + dw_partial: Optional[Tensor], + db_partial: Optional[Tensor] = None, + dresidual_out: Optional[Tensor] = None, + dresidual: Optional[Tensor] = None, + sm_count: Optional[int] = None, +) -> None: + """SM100-specific RMSNorm backward dispatch. + + Mirrors Quack's `quack.rmsnorm._rmsnorm_bwd`, but instantiates + `RMSNormBackwardSM100` (SM100-tuned heuristics). + """ + assert x.dim() == 2, "Input must be 2D" + assert x.is_cuda, "Input tensor must be on CUDA device" + assert x.dtype in (torch.float16, torch.bfloat16, torch.float32) + + if weight is not None: + assert weight.dim() == 1 + assert x.shape[-1] == weight.shape[0] + assert weight.is_cuda + assert weight.dtype in (torch.float32, torch.bfloat16, torch.float16) + if dresidual_out is not None: + assert dresidual_out.shape == x.shape + assert dresidual_out.is_cuda + assert dresidual_out.dtype in (torch.float16, torch.bfloat16, torch.float32) + if dresidual is not None: + assert dresidual.shape == x.shape + assert dresidual.is_cuda + assert dresidual.dtype in (torch.float16, torch.bfloat16, torch.float32) + + M, N = x.size(0), x.size(1) + if dw_partial is None and db_partial is None: + assert sm_count is not None + else: + sm_count = ( + dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0] + ) + + # Match Quack's conversion strategy for activations/gradients: keep the + # (M, N) layout dynamic without enforcing additional compact-shape + # constraints. This reduces per-call Python overhead for small-M shapes. 
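+    # (`leading_dim=1` marks dim 1 as the unit-stride mode; the row stride is
+    # left dynamic, which is what lets strided/padded inputs share this path.)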
+ def _convert_layout_dynamic(t: Tensor) -> cute.Tensor: + return from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + + x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [ + _convert_layout_dynamic(t) if t is not None else None + for t in (x, dout, dresidual_out, dx, dresidual) + ] + + if weight is not None: + weight_dtype = TORCH2CUTE_DTYPE[weight.dtype] + weight_tensor = convert_from_dlpack_cute( + weight.detach(), + leading_dim=0, + divisibility=128 // weight_dtype.width, + ) + else: + weight_tensor = None + + dw_partial_tensor = ( + from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + if dw_partial is not None + else None + ) + db_partial_tensor = ( + from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + if db_partial is not None + else None + ) + rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compile_key = ( + M, + N, + x_tensor.element_type, + weight_tensor.element_type if weight is not None else None, + db_partial.dtype if db_partial is not None else None, + dresidual.dtype if dresidual is not None else None, + dresidual_out.dtype if dresidual_out is not None else None, + ) + kernel = _BWD_COMPILE_CACHE.get(compile_key) + if kernel is None: + op = RMSNormBackwardSM100(x_tensor.element_type, N) + + # Shape-specific tuning overrides for DSv3-style N=8192 rows. + if isinstance(op, RMSNormBackwardSM100) and N == 8192: + if M >= 65536: + op._tpr_override = 256 # type: ignore[attr-defined] + op._nt_override = 256 # type: ignore[attr-defined] + elif M >= 16384: + op._tpr_override = 256 # type: ignore[attr-defined] + + kernel = cute.compile( + op, + x_tensor, + weight_tensor, + dout_tensor, + dres_out_tensor, + rstd_tensor, + dx_tensor, + dw_partial_tensor, + dres_tensor, + db_partial_tensor, + Int32(sm_count if sm_count is not None else 0), + current_stream, + ) + _BWD_COMPILE_CACHE[compile_key] = kernel + + kernel( + x_tensor, + weight_tensor, + dout_tensor, + dres_out_tensor, + rstd_tensor, + dx_tensor, + dw_partial_tensor, + dres_tensor, + db_partial_tensor, + Int32(sm_count if sm_count is not None else 0), + current_stream, + ) + + +def _rmsnorm_bwd_sm100_ptr( + x: Tensor, + weight: Tensor, + dout: Tensor, + rstd: Tensor, + dx: Tensor, + dw_partial: Tensor, + sm_count: int, + *, + atomic_dw: bool = False, +) -> None: + """Pointer-based SM100 RMSNorm backward launch (no DLPack conversions). + + When `atomic_dw=True`, `dw_partial` is treated as a single (N,) fp32 buffer + and the kernel atomically accumulates weight gradients into it (avoids the + extra `dw_partial.sum(dim=0)` reduction kernel). 
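+    Illustrative call (a sketch; buffer shapes mirror `rmsnorm_backward` below):
+
+        dw_partial = torch.empty(sm_count, N, device="cuda", dtype=torch.float32)
+        _rmsnorm_bwd_sm100_ptr(x, weight, dout, rstd, dx, dw_partial, sm_count)
+        dw = dw_partial.sum(dim=0)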
+ """ + assert _can_use_ptr_path_bwd(x, weight, dout, rstd) + assert dx.shape == x.shape + assert dx.dtype == x.dtype + assert dw_partial.dtype == torch.float32 + + M, N = x.size(0), x.size(1) + if atomic_dw: + assert dw_partial.dim() == 1 and dw_partial.numel() == N + assert dw_partial.is_contiguous() + else: + assert dw_partial.dim() == 2 and dw_partial.shape[1] == N + device_index = x.get_device() + dtype = TORCH2CUTE_DTYPE[x.dtype] + weight_dtype = TORCH2CUTE_DTYPE[weight.dtype] + assumed_align_x = 16 + assumed_align_w = 32 if weight.dtype is torch.float32 else 16 + assumed_align_dw = 32 + assert (dw_partial.data_ptr() % assumed_align_dw) == 0 + + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + ld_val = int(x.stride(0)) + key = ( + "bwd_ptr", + N, + dtype, + weight_dtype, + int(assumed_align_x), + int(assumed_align_w), + int(assumed_align_dw), + device_index, + bool(atomic_dw), + ) + compiled = _BWD_PTR_COMPILE_CACHE.get(key) + if compiled is None: + op = RMSNormBackwardSM100(dtype, N) + op.atomic_dw = bool(atomic_dw) + # 16-bit activations + 16-bit weights (vLLM-style) backward at N=4096: + # Use a 1-row/CTA schedule with 256 threads/row. This reduces per-thread + # work and improves bandwidth on large-M shapes on SM100. + if ( + (not atomic_dw) + and N == 4096 + and dtype.width == 16 + and weight_dtype.width == 16 + ): + op._tpr_override = 256 # type: ignore[attr-defined] + op._nt_override = 256 # type: ignore[attr-defined] + # 16-bit activations + fp32 weights backward at N=4096: + # Use a 256-thread schedule (tpr=256) to improve throughput. + if ( + (not atomic_dw) + and N == 4096 + and dtype.width == 16 + and weight_dtype is cutlass.Float32 + ): + op._tpr_override = 256 # type: ignore[attr-defined] + op._nt_override = 256 # type: ignore[attr-defined] + # FP16 + fp32-weight DSv3 backward: Quack's default (1 row/CTA with + # 256 threads/row) underperforms. Use a 2-rows/CTA schedule (256 threads + # total, 128 threads/row) to improve memory-level parallelism. 
+ if ( + (not atomic_dw) + and N == 6144 + and dtype is cutlass.Float16 + and weight_dtype is cutlass.Float32 + ): + op._tpr_override = 128 # type: ignore[attr-defined] + op._nt_override = 256 # type: ignore[attr-defined] + + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_w = rt.make_ptr( + weight_dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_w, + ) + ptr_dout = rt.make_ptr( + dtype, + dout.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dx = rt.make_ptr( + dtype, + dx.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dw = rt.make_ptr( + cutlass.Float32, + dw_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_dx, + ptr_dw, + Int32(M), + Int32(N), + Int32(ld_val), + Int32(int(sm_count)), + stream, + ) + _BWD_PTR_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_rmsnorm_bwd_launcher( + compiled=compiled, + dtype=dtype, + weight_dtype=weight_dtype, + N=N, + device_index=device_index, + stream_handle=stream_handle, + has_weight=True, + has_dw_partial=True, + assumed_align_x=assumed_align_x, + assumed_align_w=assumed_align_w, + assumed_align_dw=assumed_align_dw, + ) + if launcher is not None: + launcher.launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld_val, + sm_count=int(sm_count), + ) + return + + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_w = rt.make_ptr( + weight_dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_w, + ) + ptr_dout = rt.make_ptr( + dtype, + dout.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dx = rt.make_ptr( + dtype, + dx.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dw = rt.make_ptr( + cutlass.Float32, + dw_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + compiled( + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_dx, + ptr_dw, + Int32(M), + Int32(N), + Int32(ld_val), + Int32(int(sm_count)), + stream, + ) + + +def rmsnorm_backward( + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dresidual_out: Optional[Tensor] = None, + has_bias: bool = False, + has_residual: bool = False, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + """Public SM100 RMSNorm backward entry point. + + Signature mirrors `quack.rmsnorm.rmsnorm_bwd` for easy comparisons. + """ + device = x.device + M, N = x.size(0), x.size(1) + dx = torch.empty_like(x) + if dresidual_out is not None and dresidual_out.dtype != dx.dtype: + dresidual = torch.empty_like(x, dtype=dresidual_out.dtype) + else: + dresidual = None + + # Shared SM100 tuning policy (used by both RMSNorm and LayerNorm). 
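+    # `sm_count` becomes both the launch grid size of the backward kernel and,
+    # in the non-atomic path, the number of rows of the fp32 `dw_partial`
+    # buffer that is reduced further down.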
+ sm_count = get_sm_count(N, device, M=M, dtype=x.dtype) + + # Quack-suite smallest case (M=8192, N=4096) is extremely sensitive to + # Python/allocator overhead because the kernel itself is very fast. + # + # The default `lite_quack.get_sm_count` adds a small-M occupancy boost for + # N=4096, which increases `dw_partial` size and can amplify allocator + # pressure in benchmark/verify loops. Clamp to Quack's baseline policy + # (`sm_count = num_sms * 2` for N=4096) for this regime. + if N == 4096 and M <= 8192 and x.dtype in (torch.float16, torch.bfloat16): + num_sms = qutils.get_num_sms(device) + sm_count = min(int(sm_count), int(num_sms) * 2) + + use_atomic_dw = False + # DSv3 backward (N=6144/7168/8192) is dominated by the (sm_count, N) partial + # write + reduction for dW. Use the atomic-dW path to accumulate directly + # into a single (N,) fp32 buffer (no separate reduction kernel). + if ( + weight is not None + and (not has_bias) + and (not has_residual) + and dresidual_out is None + and dresidual is None + and N == 8192 + and weight.dtype is torch.float32 + and M >= 65536 + and x.dtype in (torch.float16, torch.bfloat16) + and _can_use_ptr_path_bwd(x, weight, dout, rstd) + ): + use_atomic_dw = True + + if weight is not None: + if use_atomic_dw: + dw_partial = torch.zeros(N, device=device, dtype=torch.float32) + else: + dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) + else: + dw_partial = None + db_partial = ( + torch.empty(sm_count, N, device=device, dtype=torch.float32) + if has_bias + else None + ) + + if ( + weight is not None + and dw_partial is not None + and (not has_bias) + and (not has_residual) + and dresidual_out is None + and dresidual is None + and _can_use_ptr_path_bwd(x, weight, dout, rstd) + ): + _rmsnorm_bwd_sm100_ptr( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + sm_count=int(sm_count), + atomic_dw=bool(use_atomic_dw), + ) + else: + _rmsnorm_bwd_sm100( + x, + weight, + dout, + rstd, + dx, + dw_partial, + db_partial, + dresidual_out, + dresidual, + sm_count, + ) + + if weight is not None and dw_partial is not None: + if use_atomic_dw: + dw_fp32 = dw_partial + else: + dw_fp32 = _reduce_partial_sum_fp32(dw_partial, device_index=x.get_device()) + dw = dw_fp32 if weight.dtype is torch.float32 else dw_fp32.to(weight.dtype) + else: + dw = None + db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None + if has_residual and dresidual is None: + dresidual = dx + return dx, dw, db, dresidual + + +# Quack-style alias for benchmarks +rmsnorm_bwd = rmsnorm_backward + + +if __name__ == "__main__": + # Minimal ad-hoc test (functionality only). For performance comparisons, use the benchmark harness. + if not torch.cuda.is_available(): + print("CUDA not available; functional test skipped.") + sys.exit(0) + M, N = 1024, 8192 + dtype = torch.bfloat16 + x = torch.randn(M, N, device="cuda", dtype=dtype) + w = torch.randn(N, device="cuda", dtype=dtype) + y_ref = rmsnorm_ref(x, w) + y, _, _ = rmsnorm_forward(x, w) + torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3) + print("RMSNormSM100 correctness check passed.") + +# (compile cache moved to top) diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py new file mode 100644 index 0000000..2b5b36d --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py @@ -0,0 +1,1007 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +RMSNorm kernel for SM100 (Blackwell) in CuteDSL, with the experimental +stage-2 cp.async ping-pong path preserved for N≈6k/8k. + +This file is a fork of rmsnorm.py that keeps the K-loop cp.async path +behind `self.stage > 1` while the main implementation has been simplified +to a single-stage schedule. +""" + +from __future__ import annotations + +import importlib.metadata +import re +from typing import Optional, Tuple + +import torch +from torch import Tensor + +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, const_expr +from cutlass.cute.runtime import from_dlpack + +from kernelagent_oink.blackwell import lite_quack as qutils +from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce + +_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +def _parse_version_tuple(version: str) -> tuple[int, int, int]: + parts = version.split(".") + nums: list[int] = [] + for part in parts[:3]: + match = re.match(r"^(\d+)", part) + nums.append(int(match.group(1)) if match is not None else 0) + while len(nums) < 3: + nums.append(0) + return nums[0], nums[1], nums[2] + + +def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]: + try: + return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl")) + except Exception: + return None + + +_CUTLASS_DSL_VERSION = _cutlass_dsl_version() +# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around +# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the +# older signature for <4.3.4, but switch to a 4.3.4+ compatible signature when +# we detect 4.3.4+ (or when version detection is unavailable). 
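+# Concretely, the flag below controls whether `tv_layout`/`tiler_mn` (and the
+# thread counts derived from them) are passed into the kernel as arguments or
+# recomputed from the schedule helpers; see the `self.kernel(...)` call sites.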
+_KERNEL_ACCEPTS_LAYOUT_ARGS = ( + _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4) +) + + +@cute.jit +def get_copy_atom_bw( + dtype: type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False +) -> cute.CopyAtom: + max_bits = const_expr(128 if is_async else 256) + num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width)) + from cutlass.cute.nvgpu import cpasync + + copy_op = ( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL) + if is_async + else cute.nvgpu.CopyUniversalOp() + ) + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@cute.jit +def copy_tiled( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, +) -> None: + atom = get_copy_atom_bw(src.element_type, num_copy_elems, is_async) + cute.copy(atom, src, dst, pred=pred) + + +class RMSNormSM100WithStage2: + def __init__( + self, N: int, dtype: type[cutlass.Numeric], stage: Optional[int] = None + ): + self.N = N + self.dtype = dtype + self.stage = 1 if stage is None else stage + self.reduction_dtype = cutlass.Float32 + + def _threads_per_row(self) -> int: + N = self.N + if N <= 64: + return 8 + elif N <= 128: + return 16 + elif N <= 1024: + return 32 + elif N <= 4096: + return 128 + elif N <= 8192: + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: + return 128 + elif N <= 16384: + return 256 + else: + return 256 + + def _cluster_n(self) -> int: + N = self.N + if N <= 8192: + return 1 + if const_expr(self.dtype.width == 16): + if N <= 16 * 1024: + return 2 + elif N <= 32 * 1024: + return 2 + elif N <= 64 * 1024: + return 4 + elif N <= 128 * 1024: + return 8 + else: + return 16 + else: + if N <= 32 * 1024: + return 1 + elif N <= 64 * 1024: + return 2 + elif N <= 128 * 1024: + return 4 + elif N <= 256 * 1024: + return 8 + else: + return 16 + + def _num_threads(self) -> int: + try: + return self._nt_override # type: ignore[attr-defined] + except Exception: + return 128 if self.N <= 16384 else 256 + + def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout]: + vecsize = num_copy_bits // self.dtype.width + num_threads = self._num_threads() + assert num_threads % cute.arch.WARP_SIZE == 0 + tpr = self._threads_per_row() + cluster_n = self._cluster_n() + num_cols_vec = cute.ceil_div(self.N, vecsize) + num_blocks_N = cute.ceil_div(num_cols_vec, tpr * cluster_n) + cols_per_block = num_threads // tpr + tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr) + tv_layout = cute.make_layout( + ((tpr, cols_per_block), (vecsize, num_blocks_N)), + stride=( + (vecsize * cols_per_block, 1), + (cols_per_block, cols_per_block * vecsize * tpr), + ), + ) + return tiler_mn, tv_layout + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + semistatic_shape = (*mX.shape[:-1], self.N) + + def new_stride(t): + return ( + cute.assume(t.stride[0], divby=256 // t.element_type.width), + t.stride[1], + ) + + mX, mRes, mO, mResO = [ + cute.make_tensor( + t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)) + ) + if const_expr(t is not None) + else None + for t in (mX, mRes, mO, mResO) + ] + assert mX.element_type == self.dtype + assert mO.element_type == self.dtype + + copy_bits = const_expr(128) + 
tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits) + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + threads_per_row = ( + tv_layout.shape[0][0] + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._threads_per_row() + ) + warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1) + cluster_n = self._cluster_n() + + if const_expr(mW is not None): + mW = cute.make_tensor( + mW.iterator, + cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), + ) + if const_expr(mB is not None): + mB = cute.make_tensor( + mB.iterator, + cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), + ) + if const_expr(mRstd is not None): + mRstd = cute.make_tensor( + mRstd.iterator, + cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))), + ) + + stage_bufs = 2 if self.stage > 1 else 1 + tile_bytes_x = ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs + ) + tile_bytes_res = ( + cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) + * stage_bufs + if const_expr(mRes is not None) + else 0 + ) + red_bytes = ( + self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8) + ) + mbar_bytes = self.stage * (cutlass.Int64.width // 8) + smem_bytes = tile_bytes_x + tile_bytes_res + red_bytes + mbar_bytes + + kernel = ( + self.kernel( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + const_expr(num_warps), + const_expr(warps_per_row), + const_expr(threads_per_row), + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + ) + ) + + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1], + block=[num_threads, 1, 1], + cluster=([1, cluster_n, 1] if const_expr(cluster_n > 1) else None), + smem=smem_bytes, + stream=stream, + ) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + num_warps: cutlass.Constexpr[int], + warps_per_row: cutlass.Constexpr[int], + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + cluster_n = self._cluster_n() + cluster_y = ( + const_expr(0) if const_expr(cluster_n == 1) else cute.arch.block_idx()[1] + ) + + smem = cutlass.utils.SmemAllocator() + sX0 = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + sX1 = ( + smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if const_expr(self.stage > 1) + else None + ) + sRes0 = ( + smem.allocate_tensor( + mRes.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if const_expr(mRes is not None) + else None + ) + sRes1 = ( + smem.allocate_tensor( + mRes.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if const_expr(mRes is not None and self.stage > 1) + else None + ) + + reduction_buffer, mbar_ptr = self._alloc_reduction_and_mbar( + smem, num_warps, warps_per_row + ) + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + num_copy_elems_X = tv_layout.shape[1][0] + use_async = const_expr(self.N >= 
1024) + copy_atom = get_copy_atom_bw( + mX.element_type, num_copy_elems_X, is_async=use_async + ) + thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx) + + gW, gB = [ + cute.local_tile(t, tiler_mn, (0, cluster_y)) + if const_expr(t is not None) + else None + for t in (mW, mB) + ] + tXgW = thr_copy.partition_S(gW) if const_expr(mW is not None) else None + tXgB = thr_copy.partition_S(gB) if const_expr(mB is not None) else None + tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None + tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None + if const_expr(mW is not None): + cute.copy( + get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False), + tXgW, + tXrW, + ) + if const_expr(mB is not None): + cute.copy( + get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False), + tXgB, + tXrB, + ) + + self._init_cluster(tidx, mbar_ptr) + + mX_i, mRes_i, mO_i, mResO_i = [ + qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) + if t is not None + else None + for t in (mX, mRes, mO, mResO) + ] + gX_i = cute.local_tile(mX_i, tiler_mn, (0, cluster_y)) + gO_i = cute.local_tile(mO_i, tiler_mn, (0, cluster_y)) + gRes_i = ( + cute.local_tile(mRes_i, tiler_mn, (0, cluster_y)) + if const_expr(mRes is not None) + else None + ) + gResO_i = ( + cute.local_tile(mResO_i, tiler_mn, (0, cluster_y)) + if const_expr(mResO is not None) + else None + ) + gRstd_i = ( + cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y)) + if const_expr(mRstd is not None) + else None + ) + cX_i = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) + + tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None] + row_i = tXcX_i[0][0] + tXgRstd_i = ( + thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None + ) + + # Intra-row K-loop cp.async ping-pong (two-pass) for N≈6k/8k (stage=2) + if const_expr(self.stage > 1 and (shape[1] == 6144 or shape[1] == 8192)): + vecsize = tv_layout.shape[1][0] + tpr = threads_per_row + target_tile_n = const_expr(4096 if shape[1] == 6144 else 8192) + tile_factor = const_expr(target_tile_n // (vecsize * tpr)) + tile_n = vecsize * tpr * tile_factor + num_tiles = cute.ceil_div(shape[1], tile_n) + + tiler_mn_tile = (tiler_mn[0], tile_n) + sX0_tile = cute.local_tile(sX0, tiler_mn_tile, (0, 0)) + sX1_tile = ( + cute.local_tile(sX1, tiler_mn_tile, (0, 0)) + if const_expr(self.stage > 1) + else None + ) + sRes0_tile = ( + cute.local_tile(sRes0, tiler_mn_tile, (0, 0)) + if const_expr(mRes is not None) + else None + ) + sRes1_tile = ( + cute.local_tile(sRes1, tiler_mn_tile, (0, 0)) + if const_expr(mRes is not None and self.stage > 1) + else None + ) + + tv_layout_tile = cute.make_layout( + ((tpr, tiler_mn[0]), (vecsize, tile_factor)), + stride=( + (vecsize * tiler_mn[0], 1), + (tiler_mn[0], tiler_mn[0] * vecsize * tpr), + ), + ) + thr_copy_tile = cute.make_tiled_copy( + copy_atom, tv_layout_tile, tiler_mn_tile + ).get_slice(tidx) + + sum_sq_acc = cute.Float32(0.0) + k_off0 = const_expr(0) * tile_n + gX_0 = cute.local_tile( + qutils.domain_offset_i64((0, k_off0), mX_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXgX_0 = thr_copy_tile.partition_S(gX_0) + tXsX_0 = thr_copy_tile.partition_D(sX0_tile) + cX_0 = cute.local_tile( + cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, cluster_y) + ) + tXc_0 = thr_copy_tile.partition_S(cX_0) + tXp_0 = qutils.predicate_k(tXc_0, limit=shape[1]) + tXp_ping = tXp_0 + tXp_pong = tXp_0 + if row_i < shape[0]: + copy_tiled( + tXgX_0, + tXsX_0, + num_copy_elems=vecsize, + 
is_async=use_async, + pred=tXp_0, + ) + if const_expr(mRes is not None): + gRes_0 = cute.local_tile( + qutils.domain_offset_i64((0, k_off0), mRes_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXgRes_0 = thr_copy_tile.partition_S(gRes_0) + tXsRes_0 = thr_copy_tile.partition_D(sRes0_tile) + copy_tiled( + tXgRes_0, + tXsRes_0, + num_copy_elems=vecsize, + is_async=use_async, + pred=tXp_0, + ) + if const_expr(use_async): + cute.arch.cp_async_commit_group() + + for t in cutlass.range_constexpr(num_tiles): + next_t = t + 1 + if next_t < num_tiles: + k_off_n = next_t * tile_n + gX_n = cute.local_tile( + qutils.domain_offset_i64((0, k_off_n), mX_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXgX_n = thr_copy_tile.partition_S(gX_n) + cX_n = cute.local_tile( + cute.domain_offset((0, k_off_n), cX_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXc_n = thr_copy_tile.partition_S(cX_n) + tXp_n = qutils.predicate_k(tXc_n, limit=shape[1]) + if const_expr((t % 2) == 0): + tXsX_n = thr_copy_tile.partition_D(sX1_tile) + tXsRes_n = ( + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None + ) + tXp_pong = tXp_n + else: + tXsX_n = thr_copy_tile.partition_D(sX0_tile) + tXsRes_n = ( + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None + ) + tXp_ping = tXp_n + if row_i < shape[0]: + copy_tiled( + tXgX_n, + tXsX_n, + num_copy_elems=vecsize, + is_async=use_async, + pred=tXp_n, + ) + if const_expr(mRes is not None): + gRes_n = cute.local_tile( + qutils.domain_offset_i64((0, k_off_n), mRes_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXgRes_n = thr_copy_tile.partition_S(gRes_n) + copy_tiled( + tXgRes_n, + tXsRes_n, + num_copy_elems=vecsize, + is_async=use_async, + pred=tXp_n, + ) + if const_expr(use_async): + cute.arch.cp_async_commit_group() + if const_expr(use_async): + cute.arch.cp_async_wait_group(1 if next_t < num_tiles else 0) + + if const_expr((t % 2) == 0): + tXsX_cur = thr_copy_tile.partition_D(sX0_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None + ) + pred_cur = tXp_ping + else: + tXsX_cur = thr_copy_tile.partition_D(sX1_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None + ) + pred_cur = tXp_pong + qutils.fill_oob(tXsX_cur, pred_cur, mX.element_type.zero) + if const_expr(mRes is not None): + qutils.fill_oob(tXsRes_cur, pred_cur, mRes.element_type.zero) + + k_off = t * tile_n + gX_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mX_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXgX_t = thr_copy_tile.partition_S(gX_t) + tXrX = cute.make_fragment_like(tXgX_t) + cute.autovec_copy(tXsX_cur, tXrX) + x = tXrX.load().to(cute.Float32) + if const_expr(mRes is not None): + gRes_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mRes_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXgRes_t = thr_copy_tile.partition_S(gRes_t) + tXrRes = cute.make_fragment_like(tXgRes_t) + cute.autovec_copy(tXsRes_cur, tXrRes) + x += tXrRes.load().to(cute.Float32) + + if const_expr(mResO is not None): + gResO_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mResO_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXgResO_t = thr_copy_tile.partition_D(gResO_t) + tXrResO = cute.make_fragment_like(tXgResO_t) + tXrResO.store(x.to(tXrResO.element_type)) + if row_i < shape[0]: + copy_tiled( + tXrResO, + tXgResO_t, + num_copy_elems=vecsize, + is_async=False, + pred=pred_cur, + ) + + sum_sq_tile = row_reduce( + x * x, + cute.ReductionOp.ADD, + 
threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + init_val=0.0, + hook_fn=( + cute.arch.cluster_wait if const_expr(cluster_n > 1) else None + ), + ) + sum_sq_acc = sum_sq_acc + sum_sq_tile + + rstd = cute.math.rsqrt(sum_sq_acc / shape[1] + eps, fastmath=True) + if const_expr(mRstd is not None): + if ( + tXcX_i[0][1] == 0 + and row_i < shape[0] + and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + tXgRstd_i[0] = rstd + + for t in cutlass.range_constexpr(num_tiles): + k_off = t * tile_n + cX_t = cute.local_tile( + cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, cluster_y) + ) + tXc_t = thr_copy_tile.partition_S(cX_t) + tXp_t = qutils.predicate_k(tXc_t, limit=shape[1]) + + if const_expr((t % 2) == 0): + tXsX_cur = thr_copy_tile.partition_D(sX0_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None + ) + else: + tXsX_cur = thr_copy_tile.partition_D(sX1_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None + ) + + qutils.fill_oob(tXsX_cur, tXp_t, mX.element_type.zero) + if const_expr(mRes is not None): + qutils.fill_oob(tXsRes_cur, tXp_t, mRes.element_type.zero) + + gX_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mX_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXgX_t = thr_copy_tile.partition_S(gX_t) + tXrX = cute.make_fragment_like(tXgX_t) + cute.autovec_copy(tXsX_cur, tXrX) + x = tXrX.load().to(cute.Float32) + if const_expr(mRes is not None): + gRes_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mRes_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXgRes_t = thr_copy_tile.partition_S(gRes_t) + tXrRes = cute.make_fragment_like(tXgRes_t) + cute.autovec_copy(tXsRes_cur, tXrRes) + x += tXrRes.load().to(cute.Float32) + + y = x * rstd + if const_expr(mW is not None): + gW_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mW), + tiler_mn_tile, + (0, cluster_y), + ) + tWgW_t = thr_copy_tile.partition_S(gW_t) + tWrW_t = cute.make_fragment_like(tWgW_t) + copy_tiled( + tWgW_t, + tWrW_t, + num_copy_elems=vecsize, + is_async=False, + pred=tXp_t, + ) + y = y * tWrW_t.load().to(cute.Float32) + if const_expr(mB is not None): + gB_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mB), + tiler_mn_tile, + (0, cluster_y), + ) + tWgB_t = thr_copy_tile.partition_S(gB_t) + tWrB_t = cute.make_fragment_like(tWgB_t) + copy_tiled( + tWgB_t, + tWrB_t, + num_copy_elems=vecsize, + is_async=False, + pred=tXp_t, + ) + y = y + tWrB_t.load().to(cute.Float32) + + gO_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mO_i), + tiler_mn_tile, + (0, cluster_y), + ) + tXgO_t = thr_copy_tile.partition_D(gO_t) + tXrO = cute.make_fragment_like(tXgO_t) + tXrO.store(y.to(tXrO.element_type)) + if row_i < shape[0]: + copy_tiled( + tXrO, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t + ) + + return + + # Fallback: single-stage path identical to current rmsnorm.py + tXgX_i = thr_copy.partition_S(gX_i) + tXgRes_i = ( + thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None + ) + tXgO_i = thr_copy.partition_D(gO_i) + tXgResO_i = ( + thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None + ) + is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n) + tXpX_i = ( + qutils.predicate_k(thr_copy.partition_S(cX_i), limit=shape[1]) + if not is_even_N_i + else None + ) + + if row_i < shape[0]: + cute.copy(copy_atom, tXgX_i, thr_copy.partition_D(sX0), pred=tXpX_i) + if const_expr(mRes is not None): + 
cute.copy(copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i) + if const_expr(use_async): + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + tXrX = cute.make_fragment_like(tXgX_i) + cute.autovec_copy(thr_copy.partition_D(sX0), tXrX) + x = tXrX.load().to(cute.Float32) + if const_expr(mRes is not None): + tXrRes = cute.make_fragment_like(tXgRes_i) + cute.autovec_copy(thr_copy.partition_D(sRes0), tXrRes) + x += tXrRes.load().to(cute.Float32) + + if const_expr(mResO is not None): + tXrResO = cute.make_fragment_like(tXgResO_i) + tXrResO.store(x.to(tXrResO.element_type)) + if row_i < shape[0]: + cute.copy( + get_copy_atom_bw( + tXrResO.element_type, num_copy_elems_X, is_async=False + ), + tXrResO, + tXgResO_i, + ) + + sum_sq = row_reduce( + x * x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + init_val=0.0, + hook_fn=(cute.arch.cluster_wait if const_expr(cluster_n > 1) else None), + ) + rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True) + + if const_expr(mRstd is not None): + if ( + tXcX_i[0][1] == 0 + and row_i < shape[0] + and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + tXgRstd_i[0] = rstd + + y = x * rstd + if const_expr(mW is not None): + y = y * tXrW.load().to(cute.Float32) + if const_expr(mB is not None): + y = y + tXrB.load().to(cute.Float32) + + tXrO = cute.make_fragment_like(tXgO_i) + tXrO.store(y.to(tXrO.element_type)) + if row_i < shape[0]: + cute.copy( + get_copy_atom_bw(tXrO.element_type, num_copy_elems_X, is_async=False), + tXrO, + tXgO_i, + ) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + num_warps: cutlass.Constexpr[int], + warps_per_row: cutlass.Constexpr[int], + threads_per_row: cutlass.Constexpr[int], + ): + self._kernel_impl( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + num_warps, + warps_per_row, + threads_per_row, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + ): + copy_bits = const_expr(128) + tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits) + num_threads = self._num_threads() + num_warps = num_threads // cute.arch.WARP_SIZE + threads_per_row = self._threads_per_row() + warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1) + self._kernel_impl( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + const_expr(num_warps), + const_expr(warps_per_row), + const_expr(threads_per_row), + ) + + @cute.jit + def _alloc_reduction_and_mbar( + self, + smem: cutlass.utils.SmemAllocator, + num_warps: cutlass.Constexpr[int], + warps_per_row: cutlass.Constexpr[int], + ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]: + cluster_n = self._cluster_n() + red_layout = cute.make_ordered_layout( + (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), + order=(1, 0, 2), + ) + reduction_buffer = smem.allocate_tensor( + self.reduction_dtype, red_layout, byte_alignment=4 + ) + if const_expr(cluster_n > 1): + mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage) + else: + mbar_ptr = 
None + return reduction_buffer, mbar_ptr + + @cute.jit + def _init_cluster(self, tidx: cutlass.Int32, mbar_ptr: Optional[cute.Pointer]): + if const_expr(mbar_ptr is not None): + if tidx < self.stage: + cute.arch.mbarrier_init(mbar_ptr + tidx, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + + +def rmsnorm_forward_with_stage2( + x: Tensor, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + eps: float = 1e-6, + store_rstd: bool = False, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + assert x.is_cuda + assert x.dim() == 2 + M, N = x.shape + dtype = TORCH2CUTE_DTYPE[x.dtype] + + def _convert_x(t: Tensor) -> cute.Tensor: + return from_dlpack(t.detach(), assumed_align=32).mark_layout_dynamic( + leading_dim=1 + ) + + mX = _convert_x(x) + mRes = _convert_x(residual) if residual is not None else None + out = torch.empty_like(x, dtype=x.dtype) + mO = from_dlpack(out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1) + + mW = ( + from_dlpack(weight.detach(), assumed_align=32).mark_layout_dynamic( + leading_dim=0 + ) + if weight is not None + else None + ) + mB = ( + from_dlpack(bias.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=0) + if bias is not None + else None + ) + if store_rstd: + rstd = torch.empty(M, device=x.device, dtype=torch.float32) + mRstd = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + else: + rstd = None + mRstd = None + + residual_out = None + mResO = None + if residual is not None: + residual_out = torch.empty_like(residual) + mResO = from_dlpack( + residual_out.detach(), assumed_align=32 + ).mark_layout_dynamic(leading_dim=1) + + # Enable the intra-row cp.async K-loop only for DSv3-style large-N rows + # with very large M, where there is enough work per row to amortize the + # pipeline start-up cost. Mid-size M shapes are better served by the + # simpler single-stage schedule. + use_kloop = bool(M >= 65536 and N in (6144, 8192)) + stage = 2 if use_kloop else 1 + op = RMSNormSM100WithStage2(N, dtype, stage=stage) + if use_kloop: + op._tpr_override = 128 # type: ignore[attr-defined] + # Prefer 1 row/CTA at N=6144; keep 2 rows/CTA at N=8192 to match + # the original tuning there. + op._nt_override = 128 if N == 6144 else 256 # type: ignore[attr-defined] + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + key = ( + N, + dtype, + mRes is not None, + mW is not None, + mB is not None, + mResO is not None, + mRstd is not None, + stage, + ) + compiled = _COMPILE_CACHE.get(key) + if compiled is None: + compiled = cute.compile( + op, mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps) + ) + _COMPILE_CACHE[key] = compiled + compiled(mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps)) + return out, rstd, residual_out diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py new file mode 100644 index 0000000..394ab48 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/softmax.py @@ -0,0 +1,1582 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Softmax forward + backward kernels for SM100 (Blackwell) in CuteDSL. + +This module implements numerically stable softmax over the last dimension of +2D tensors (M, N) and its backward pass, targeting SM100 with Quack-style +tiling, cp.async pipelines, and cluster reductions, but without depending on +the `quack` package at runtime. + +The kernels are self-contained and use only local helpers in +`kernelagent_oink.blackwell.lite_quack` plus CuTeDSL/CUTLASS. +""" + +from __future__ import annotations + +import importlib.metadata +import os +import re +from typing import Type + +import torch +from torch import Tensor + +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python + +# CuTeDSL caches generated MLIR into a tempdir under a global default +# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across +# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy +# warnings (and disables cache reuse). +if "CUTE_DSL_CACHE_DIR" not in os.environ: + try: + _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl") + except Exception: + _dsl_ver = "unknown" + _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver) + _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user" + _tmp = os.environ.get("TMPDIR") or "/tmp" + os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join( + _tmp, _user, f"cutlass_python_cache_{_dsl_ver}" + ) + +try: + import cutlass # type: ignore # noqa: F401 +except Exception as e: + raise ImportError( + "kernelagent_oink.blackwell.softmax requires CuTeDSL's Python package " + "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)." 
+ ) from e + +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute import runtime as rt +from cutlass.cute.runtime import from_dlpack + +from kernelagent_oink.blackwell.fast_launch import ( + StableI32Arg, + disable_fast_launch, + fast_launch_enabled, + set_runtime_ptr, + tls_cache as _tls_fast_launch_cache, +) +from kernelagent_oink.blackwell.lite_quack import ( + _KERNEL_ACCEPTS_LAYOUT_ARGS, + TORCH2CUTE_DTYPE, + ReductionBase, + fill_oob, + online_softmax_reduce, + predicate_k, + row_reduce, +) + +_FWD_COMPILE_CACHE: dict[tuple[Type[cutlass.Numeric], int], object] = {} +_BWD_COMPILE_CACHE: dict[tuple[Type[cutlass.Numeric], int], object] = {} +_PTR_FWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} +_PTR_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} +_PTR_FWDBWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +class _PtrSoftmaxFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_a: object, + ptr_b: object, + ptr_c: object | None, + arg_m: StableI32Arg, + arg_ld: StableI32Arg, + stream: cuda.CUstream, + assumed_align: int, + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_a = ptr_a + self._ptr_b = ptr_b + self._ptr_c = ptr_c + self._arg_m = arg_m + self._arg_ld = arg_ld + self._stream = stream + self._assumed_align = int(assumed_align) + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_a_ptr = -1 + self._last_b_ptr = -1 + self._last_c_ptr = -1 + self._last_m = -1 + self._last_ld = -1 + + def launch( + self, + *, + a_ptr: int, + b_ptr: int, + c_ptr: int | None, + M: int, + ld: int, + stream_handle: int, + dtype: type[cutlass.Numeric], + ) -> None: + if not fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + a_ptr=a_ptr, + b_ptr=b_ptr, + c_ptr=c_ptr, + M=M, + ld=ld, + stream_handle=stream_handle, + dtype=dtype, + ) + return + + if a_ptr != self._last_a_ptr: + try: + set_runtime_ptr(self._ptr_a, a_ptr) + self._last_a_ptr = a_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + a_ptr=a_ptr, + b_ptr=b_ptr, + c_ptr=c_ptr, + M=M, + ld=ld, + stream_handle=stream_handle, + dtype=dtype, + ) + return + + if b_ptr != self._last_b_ptr: + try: + set_runtime_ptr(self._ptr_b, b_ptr) + self._last_b_ptr = b_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + a_ptr=a_ptr, + b_ptr=b_ptr, + c_ptr=c_ptr, + M=M, + ld=ld, + stream_handle=stream_handle, + dtype=dtype, + ) + return + + if self._ptr_c is not None and c_ptr is not None: + if c_ptr != self._last_c_ptr: + try: + set_runtime_ptr(self._ptr_c, c_ptr) + self._last_c_ptr = c_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + a_ptr=a_ptr, + b_ptr=b_ptr, + c_ptr=c_ptr, + M=M, + ld=ld, + stream_handle=stream_handle, + dtype=dtype, + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + + if self._cuda_result is not None: + self._cuda_result.value = 0 + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise 
RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + self._use_fast_launch = False + disable_fast_launch() + + def _fallback_launch( + self, + *, + a_ptr: int, + b_ptr: int, + c_ptr: int | None, + M: int, + ld: int, + stream_handle: int, + dtype: type[cutlass.Numeric], + ) -> None: + stream = cuda.CUstream(int(stream_handle)) + ptr_a = rt.make_ptr( + dtype, + a_ptr, + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + ptr_b = rt.make_ptr( + dtype, + b_ptr, + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + if self._ptr_c is not None and c_ptr is not None: + ptr_c = rt.make_ptr( + dtype, + c_ptr, + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + self._compiled(ptr_a, ptr_b, ptr_c, Int32(int(M)), Int32(int(ld)), stream) + else: + self._compiled(ptr_a, ptr_b, Int32(int(M)), Int32(int(ld)), stream) + + +def _get_fast_ptr_softmax_launcher( + *, + compiled: object, + dtype: type[cutlass.Numeric], + N: int, + device_index: int, + stream_handle: int, + assumed_align: int, + is_bwd: bool, +) -> _PtrSoftmaxFastLaunch | None: + if not fast_launch_enabled(): + return None + key = ( + "ptr_fast_bwd" if is_bwd else "ptr_fast_fwd", + id(compiled), + int(N), + dtype, + int(device_index), + int(stream_handle), + int(assumed_align), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + assumed_align = int(assumed_align) + ptr_a = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_b = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_c = ( + rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + if is_bwd + else None + ) + + arg_m = StableI32Arg(0) + arg_ld = StableI32Arg(N) + stream = cuda.CUstream(int(stream_handle)) + executor = compiled.to(device_index) # type: ignore[attr-defined] + try: + if ptr_c is not None: + exe_args, adapted_args = executor.generate_execution_args( + ptr_a, + ptr_b, + ptr_c, + arg_m, + arg_ld, + stream, + ) + else: + exe_args, adapted_args = executor.generate_execution_args( + ptr_a, + ptr_b, + arg_m, + arg_ld, + stream, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + disable_fast_launch() + return None + + keepalive: tuple[object, ...] = ( + executor, + ptr_a, + ptr_b, + ptr_c, + arg_m, + arg_ld, + stream, + *adapted_args, + ) + launcher = _PtrSoftmaxFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_a=ptr_a, + ptr_b=ptr_b, + ptr_c=ptr_c, + arg_m=arg_m, + arg_ld=arg_ld, + stream=stream, + assumed_align=assumed_align, + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher + + +class SoftmaxFwdSM100(ReductionBase): + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + # One-stage online reduction: pack (max, sum_exp) into Int64 reduction buffer. + super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64) + + def _get_num_threads(self) -> int: + # SM100 tuning note: + # For N=4096, we use 32 threads per row (1 warp) and run 1 row per CTA + # (32 threads total). 
This keeps the reduction fully warp-local and + # improves throughput on this GB200 versus Quack's default 2-rows-per-CTA + # schedule with 64 threads per row (4 warps total). + if self.N == 4096: + return 32 + return super()._get_num_threads() + + def _calculate_threads_per_row(self) -> int: + # Match Quack's bucketed policy for Softmax. + N = self.N + if N == 4096: + return 32 + if N == 6144: + return 128 + if N <= 64: + return 8 + if N <= 128: + return 16 + if N <= 3072: + return 32 + if N <= 6144: + return 64 + if N <= 16384: + return 128 + return 256 + + def _set_cluster_n(self) -> None: + # Quack-style growth of cluster_n with N and dtype. + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 32 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + @cute.jit + def __call__(self, mX: cute.Tensor, mO: cute.Tensor, stream: cuda.CUstream) -> None: + assert mX.element_type == self.dtype + assert mO.element_type == self.dtype + # Use the generic ReductionBase tiling with 128-bit vectorization. + tiler_mn, tv_layout = self._get_tv_layout() + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + kernel = ( + self.kernel(mX, mO, tv_layout, tiler_mn) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel(mX, mO) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_out: cute.Pointer, + M: Int32, + ld: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions. + + Reconstructs cute.Tensor views from raw pointers + explicit layouts + inside the JIT graph, matching the existing SM100 schedule. + """ + # Mirror Quack/LayerNorm contracts: assume 16B alignment and an LD that + # preserves 128-bit vectorized copies for every row start. + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + mX = cute.make_tensor(ptr_x, layout_mn) + mO = cute.make_tensor(ptr_out, layout_mn) + self.__call__(mX, mO, stream) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mO: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + # Quack-style CTA tiling. + gX, gO, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX) + ] + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) + + # Copy atoms for gmem <-> smem and smem <-> gmem. + # Use 128-bit cp.async for global->shared and 128-bit vectorized stores. 
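+ # (At 128 bits per copy that is 8 elements for 16-bit dtypes and 4 for
+ # fp32; ragged tail columns are handled further down via predicate_k +
+ # fill_oob rather than by narrowing the copy width.)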
+ copy_atom_load = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=128, + ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gO.element_type, + num_bits_per_copy=128, + ) + + num_copy_elems = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems)) + thr_copy_load = cute.make_tiled_copy_tv( + copy_atom_load, thr_layout, val_layout + ).get_slice(tidx) + thr_copy_store = cute.make_tiled_copy_tv( + copy_atom_store, thr_layout, val_layout + ).get_slice(tidx) + + tXgX = thr_copy_load.partition_S(gX) + tXsX = thr_copy_load.partition_D(sX) + tXgO = thr_copy_store.partition_D(gO) + tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None] + + # Register fragments. + tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)] + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + # Predicate and cp.async pipeline for potential tail tiles. + is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n) + tXpX = ( + predicate_k(thr_copy_load.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + if const_expr(not is_even_N): + fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + + # Online softmax reduction: compute max and sum_exp in a single pass, with + # optional cluster-wide aggregation via an Int64 reduction buffer. + max_x, denom, exp_x = online_softmax_reduce( + x, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + phase=None, + return_exp_x=True, + ) + + y = exp_x * cute.arch.rcp_approx(denom) + tXrO.store(y.to(tXrO.element_type)) + + tOpO = ( + predicate_k(thr_copy_store.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_store, tXrO, tXgO, pred=tOpO) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mO: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl(mX, mO, tv_layout, tiler_mn) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mO: cute.Tensor, + ) -> None: + tiler_mn, tv_layout = self._get_tv_layout() + self._kernel_impl(mX, mO, tv_layout, tiler_mn) + + +class SoftmaxBwdSM100(ReductionBase): + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + # One stage for dot(dy, y) per row. + super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32) + + def _calculate_threads_per_row(self) -> int: + # Match Quack backward softmax buckets. 
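+ # (threads_per_row is the number of lanes that cooperate on one row's
+ # dot(dy, y) reduction; wider rows get wider groups so each lane's
+ # register fragment stays small.)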
+ N = self.N + if N in (4096, 6144): + return 128 + if N <= 64: + return 8 + if N <= 128: + return 16 + if N <= 3072: + return 32 + if N <= 6144: + return 64 + if N <= 8192: + return 128 + return 256 + + def _set_cluster_n(self) -> None: + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 32 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + def _get_num_threads(self) -> int: + # Slightly more aggressive threading for large N than the base class. + return 128 if self.N <= 8192 else 256 + + def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int: + # Store both y and dy tiles plus reduction buffers and mbarriers. + return ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 + + self.stage + * num_warps + * self.cluster_n + * (self.reduction_dtype.width // 8) + + self.stage * (cutlass.Int64.width // 8) + ) + + @cute.jit + def __call__( + self, + mdY: cute.Tensor, + mY: cute.Tensor, + mdX: cute.Tensor, + stream: cuda.CUstream, + ) -> None: + assert mdY.element_type == self.dtype + assert mY.element_type == self.dtype + assert mdX.element_type == self.dtype + # Use the generic ReductionBase tiling with 128-bit vectorization. + tiler_mn, tv_layout = self._get_tv_layout() + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + kernel = ( + self.kernel(mdY, mY, mdX, tv_layout, tiler_mn) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel(mdY, mY, mdX) + ) + kernel.launch( + grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_dy: cute.Pointer, + ptr_y: cute.Pointer, + ptr_dx: cute.Pointer, + M: Int32, + ld: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions.""" + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + mdY = cute.make_tensor(ptr_dy, layout_mn) + mY = cute.make_tensor(ptr_y, layout_mn) + mdX = cute.make_tensor(ptr_dx, layout_mn) + self.__call__(mdY, mY, mdX, stream) + + @cute.jit + def _kernel_impl( + self, + mdY: cute.Tensor, + mY: cute.Tensor, + mdX: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) + + shape = mdY.shape + idX = cute.make_identity_tensor(shape) + + gdY, gY, gdX, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) + for mT in (mdY, mY, mdX, idX) + ] + + smem = cutlass.utils.SmemAllocator() + sdY = smem.allocate_tensor( + mdY.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + sY = smem.allocate_tensor( + mY.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) + + copy_atom_load = 
cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mdY.element_type, + num_bits_per_copy=128, + ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gdX.element_type, + num_bits_per_copy=128, + ) + + num_copy_elems = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems)) + thr_copy_load = cute.make_tiled_copy_tv( + copy_atom_load, thr_layout, val_layout + ).get_slice(tidx) + thr_copy_store = cute.make_tiled_copy_tv( + copy_atom_store, thr_layout, val_layout + ).get_slice(tidx) + + tdYgdY = thr_copy_load.partition_S(gdY) + tdYsdY = thr_copy_load.partition_D(sdY) + tYgY = thr_copy_load.partition_S(gY) + tYsY = thr_copy_load.partition_D(sY) + tdXgdX = thr_copy_store.partition_D(gdX) + tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None] + + tdYrdY, tYrY, tdXrdX = [ + cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX) + ] + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n) + tdYpdY = ( + predicate_k(thr_copy_load.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tdYpdY) + cute.copy(copy_atom_load, tYgY, tYsY, pred=tdYpdY) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + cute.autovec_copy(tdYsdY, tdYrdY) + cute.autovec_copy(tYsY, tYrY) + dy = tdYrdY.load().to(Float32) + y = tYrY.load().to(Float32) + dot = row_reduce( + dy * y, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + ) + + dx = y * (dy - dot) + tdXrdX.store(dx.to(tdXrdX.element_type)) + + tdXpdX = ( + predicate_k(thr_copy_store.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mdY: cute.Tensor, + mY: cute.Tensor, + mdX: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl(mdY, mY, mdX, tv_layout, tiler_mn) + else: + + @cute.kernel + def kernel( + self, + mdY: cute.Tensor, + mY: cute.Tensor, + mdX: cute.Tensor, + ) -> None: + tiler_mn, tv_layout = self._get_tv_layout() + self._kernel_impl(mdY, mY, mdX, tv_layout, tiler_mn) + + +class SoftmaxFwdBwdSM100(ReductionBase): + """Fused softmax forward+backward producing dx from (x, dy). + + Computes: + y = softmax(x) + dot = sum(dy * y) + dx = y * (dy - dot) + + This avoids materializing the intermediate `y` in global memory, which is + the dominant overhead in a naive `softmax_backward(dy, softmax_forward(x))` + composition. + """ + + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + # Online softmax reduction uses an Int64 reduction buffer packing + # (max, sum_exp) pairs. We allocate a separate Float32 reduction buffer + # for dot(dy, y). 
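+ # (Both buffers, plus the optional mbarrier array, are sized explicitly
+ # in the _smem_size_in_bytes override below.)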
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64) + + def _calculate_threads_per_row(self) -> int: + # Favor the backward bucket policy (better for the dot reduction). + N = self.N + if N in (4096, 6144): + return 128 + if N <= 64: + return 8 + if N <= 128: + return 16 + if N <= 3072: + return 32 + if N <= 6144: + return 64 + if N <= 8192: + return 128 + return 256 + + def _set_cluster_n(self) -> None: + # Quack-style growth of cluster_n with N and dtype. + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 32 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + def _get_num_threads(self) -> int: + # Keep in sync with _calculate_threads_per_row. + return 128 if self.N <= 8192 else 256 + + def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int: + # Allocation order: + # 1) sX (16B aligned) + # 2) sdY (16B aligned) + # 3) reduction_buffer_stats (8B aligned) + # 4) reduction_buffer_dot (8B aligned) + # 5) optional mbarrier array (8B aligned) + def _align_up(x: int, align: int) -> int: + return ((x + align - 1) // align) * align + + tile_bytes = int(cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))) + reduction_stats_bytes = int( + num_warps * self.cluster_n * (cutlass.Int64.width // 8) + ) + reduction_dot_bytes = int( + num_warps * self.cluster_n * (cutlass.Float32.width // 8) + ) + mbar_bytes = ( + int(2 * (cutlass.Int64.width // 8)) if const_expr(self.cluster_n > 1) else 0 + ) + + offset = _align_up(tile_bytes, 16) + offset = _align_up(offset, 16) + tile_bytes + offset = _align_up(offset, 8) + reduction_stats_bytes + offset = _align_up(offset, 8) + reduction_dot_bytes + offset = _align_up(offset, 8) + mbar_bytes + return int(offset) + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mdY: cute.Tensor, + mdX: cute.Tensor, + stream: cuda.CUstream, + ) -> None: + assert mX.element_type == self.dtype + assert mdY.element_type == self.dtype + assert mdX.element_type == self.dtype + tiler_mn, tv_layout = self._get_tv_layout() + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + kernel = ( + self.kernel(mX, mdY, mdX, tv_layout, tiler_mn) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel(mX, mdY, mdX) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_dy: cute.Pointer, + ptr_dx: cute.Pointer, + M: Int32, + ld: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions.""" + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + mX = cute.make_tensor(ptr_x, layout_mn) + mdY = cute.make_tensor(ptr_dy, layout_mn) + mdX = cute.make_tensor(ptr_dx, layout_mn) + self.__call__(mX, mdY, mdX, stream) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mdY: cute.Tensor, + mdX: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + 
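+ # Per-CTA flow: cp.async the x and dy tiles into smem, fill any
+ # out-of-bounds lanes, run the one-pass online softmax reduction to get
+ # (max, denom), reduce dot(dy, y), then write dx = y * (dy - dot) back
+ # to gmem without ever materializing y globally.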
tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + cluster_y = ( + const_expr(0) + if const_expr(self.cluster_n == 1) + else cute.arch.block_idx()[1] + ) + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + gX, gdY, gdX, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) + for mT in (mX, mdY, mdX, idX) + ] + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + sdY = smem.allocate_tensor( + mdY.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + + reduction_layout = self._get_reduction_buffer_layout(tv_layout, self.cluster_n) + reduction_buffer_stats = smem.allocate_tensor( + cutlass.Int64, reduction_layout, byte_alignment=8 + ) + reduction_buffer_dot = smem.allocate_tensor( + cutlass.Float32, reduction_layout, byte_alignment=8 + ) + + if const_expr(self.cluster_n > 1): + mbar_ptr_base = smem.allocate_array(cutlass.Int64, num_elems=2) + mbar_ptr_stats = mbar_ptr_base + mbar_ptr_dot = mbar_ptr_base + Int32(1) + else: + mbar_ptr_stats = None + mbar_ptr_dot = None + + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=128, + ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gdX.element_type, + num_bits_per_copy=128, + ) + + num_copy_elems = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems)) + thr_copy_load = cute.make_tiled_copy_tv( + copy_atom_load, thr_layout, val_layout + ).get_slice(tidx) + thr_copy_store = cute.make_tiled_copy_tv( + copy_atom_store, thr_layout, val_layout + ).get_slice(tidx) + + tXgX = thr_copy_load.partition_S(gX) + tXsX = thr_copy_load.partition_D(sX) + tdYgdY = thr_copy_load.partition_S(gdY) + tdYsdY = thr_copy_load.partition_D(sdY) + tdXgdX = thr_copy_store.partition_D(gdX) + tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None] + + tXrX, tdYrdY, tdXrdX = [ + cute.make_fragment_like(thr) for thr in (tXgX, tdYgdY, tdXgdX) + ] + + if const_expr( + self.cluster_n > 1 + and mbar_ptr_stats is not None + and mbar_ptr_dot is not None + ): + if tidx < 2: + cute.arch.mbarrier_init(mbar_ptr_stats + tidx, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + + is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n) + tXpX = ( + predicate_k(thr_copy_load.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load, tXgX, tXsX, pred=tXpX) + cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + if const_expr(not is_even_N): + fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + fill_oob(tdYsdY, tXpX, 0.0) + + cute.autovec_copy(tXsX, tXrX) + cute.autovec_copy(tdYsdY, tdYrdY) + x = tXrX.load().to(Float32) + dy = tdYrdY.load().to(Float32) + + _, denom, exp_x = online_softmax_reduce( + x, + threads_per_row, + reduction_buffer_stats[None, None, 0], + mbar_ptr_stats, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + phase=None, + return_exp_x=True, + ) + assert exp_x is not None 
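+ # rcp_approx is an approximate reciprocal; its error is far below the
+ # atol/rtol used by the parity checks at the bottom of this module.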
+ y = exp_x * cute.arch.rcp_approx(denom) + + dot = row_reduce( + dy * y, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer_dot[None, None, 0], + mbar_ptr_dot, + phase=None, + init_val=0.0, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + ) + + dx = y * (dy - dot) + tdXrdX.store(dx.to(tdXrdX.element_type)) + + tOpO = ( + predicate_k(thr_copy_store.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tOpO) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mdY: cute.Tensor, + mdX: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl(mX, mdY, mdX, tv_layout, tiler_mn) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mdY: cute.Tensor, + mdX: cute.Tensor, + ) -> None: + tiler_mn, tv_layout = self._get_tv_layout() + self._kernel_impl(mX, mdY, mdX, tv_layout, tiler_mn) + + +def _convert_2d_tensor(x: Tensor) -> cute.Tensor: + # Match Quack's Softmax conversion exactly: assume 16B alignment and mark + # the shape compact with row-major stride order (0, 1), with mode=0 (batch). + # We intentionally do not call mark_layout_dynamic here to avoid the + # leading_dim stride==1 constraint used in RMSNorm. + return from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic( + mode=0, stride_order=(0, 1) + ) + + +def _can_use_ptr_path_2d(x: Tensor) -> bool: + """Conservative guard for the pointer-based fast path.""" + if not x.is_cuda or x.dim() != 2: + return False + if x.dtype not in TORCH2CUTE_DTYPE: + return False + # Require row-major last-dim contiguous. + if x.stride(1) != 1: + return False + # Require 16B alignment (matches from_dlpack(..., assumed_align=16)). + if (x.data_ptr() % 16) != 0: + return False + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + divby = 128 // dtype_x.width + # Softmax uses ReductionBase default num_copy_bits=128, so N must be divisible. + if (x.shape[1] % divby) != 0: + return False + # Ensure each row start remains aligned for 128-bit vectorized copies. 
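+ # (divby is the element count of one 128-bit vector, so together with the
+ # 16-byte base-pointer check above this keeps every row start 16-byte
+ # aligned.)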
+ if (x.stride(0) % divby) != 0: + return False + return True + + +def _softmax_forward_ptr_into(*, x: Tensor, out: Tensor) -> None: + """Launch the pointer-based Softmax forward kernel into preallocated `out`.""" + assert x.is_cuda and x.dim() == 2 + assert out.is_cuda and out.shape == x.shape and out.dtype == x.dtype + assert out.stride() == x.stride(), "Pointer path expects out to match x strides" + + M, N = x.shape + device_index = x.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + key = ("ptr_fwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_FWD_COMPILE_CACHE.get(key) + if compiled is None: + op = SoftmaxFwdSM100(dtype_x, int(N)) + ptr_x = rt.make_ptr( + dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_out = rt.make_ptr( + dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ld = Int32(int(x.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_out, + Int32(int(M)), + ld, + stream, + ) + _PTR_FWD_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_softmax_launcher( + compiled=compiled, + dtype=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + assumed_align=16, + is_bwd=False, + ) + if launcher is not None: + launcher.launch( + a_ptr=int(x.data_ptr()), + b_ptr=int(out.data_ptr()), + c_ptr=None, + M=int(M), + ld=int(x.stride(0)), + stream_handle=stream_handle, + dtype=dtype_x, + ) + return + + ptr_x = rt.make_ptr( + dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_out = rt.make_ptr( + dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + compiled(ptr_x, ptr_out, Int32(int(M)), Int32(int(x.stride(0))), stream) + + +def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None: + """Launch the pointer-based Softmax backward kernel into preallocated `dx`.""" + assert dy.is_cuda and dy.dim() == 2 + assert y.is_cuda and y.shape == dy.shape and y.dtype == dy.dtype + assert dx.is_cuda and dx.shape == dy.shape and dx.dtype == dy.dtype + assert dy.stride() == y.stride() == dx.stride(), ( + "Pointer path expects matching strides" + ) + + M, N = dy.shape + device_index = dy.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + dtype_x = TORCH2CUTE_DTYPE[dy.dtype] + key = ("ptr_bwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_BWD_COMPILE_CACHE.get(key) + if compiled is None: + op = SoftmaxBwdSM100(dtype_x, int(N)) + ptr_dy = rt.make_ptr( + dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_y = rt.make_ptr( + dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ld = Int32(int(dy.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_dy, + ptr_y, + ptr_dx, + Int32(int(M)), + ld, + stream, + ) + _PTR_BWD_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_softmax_launcher( + compiled=compiled, + dtype=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + assumed_align=16, + is_bwd=True, + ) + if launcher is not None: + 
launcher.launch( + a_ptr=int(dy.data_ptr()), + b_ptr=int(y.data_ptr()), + c_ptr=int(dx.data_ptr()), + M=int(M), + ld=int(dy.stride(0)), + stream_handle=stream_handle, + dtype=dtype_x, + ) + return + + ptr_dy = rt.make_ptr( + dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_y = rt.make_ptr( + dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + compiled(ptr_dy, ptr_y, ptr_dx, Int32(int(M)), Int32(int(dy.stride(0))), stream) + + +def _softmax_fwd_bwd_ptr_into(*, x: Tensor, dy: Tensor, dx: Tensor) -> None: + """Launch the fused pointer-based Softmax fwd+bwd kernel into preallocated `dx`.""" + assert x.is_cuda and x.dim() == 2 + assert dy.is_cuda and dy.shape == x.shape and dy.dtype == x.dtype + assert dx.is_cuda and dx.shape == x.shape and dx.dtype == x.dtype + assert x.stride() == dy.stride() == dx.stride(), ( + "Pointer path expects matching strides" + ) + + M, N = x.shape + device_index = x.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + key = ("ptr_fwd_bwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_FWDBWD_COMPILE_CACHE.get(key) + if compiled is None: + op = SoftmaxFwdBwdSM100(dtype_x, int(N)) + ptr_x = rt.make_ptr( + dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dy = rt.make_ptr( + dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ld = Int32(int(x.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_dy, + ptr_dx, + Int32(int(M)), + ld, + stream, + ) + _PTR_FWDBWD_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_softmax_launcher( + compiled=compiled, + dtype=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + assumed_align=16, + is_bwd=True, + ) + if launcher is not None: + launcher.launch( + a_ptr=int(x.data_ptr()), + b_ptr=int(dy.data_ptr()), + c_ptr=int(dx.data_ptr()), + M=int(M), + ld=int(x.stride(0)), + stream_handle=stream_handle, + dtype=dtype_x, + ) + return + + ptr_x = rt.make_ptr( + dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dy = rt.make_ptr( + dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + compiled(ptr_x, ptr_dy, ptr_dx, Int32(int(M)), Int32(int(x.stride(0))), stream) + + +def softmax_forward(x: Tensor) -> Tensor: + """SM100 CuteDSL softmax forward pass: y = softmax(x, dim=-1).""" + assert x.dim() == 2, "Input must be 2D (M, N)" + assert x.is_cuda, "Input must be on CUDA device" + assert x.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype" + + N = x.size(1) + dtype = TORCH2CUTE_DTYPE[x.dtype] + if _can_use_ptr_path_2d(x): + out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + _softmax_forward_ptr_into(x=x, out=out) + return out + + out = torch.empty_like(x) + + x_tensor = _convert_2d_tensor(x) + out_tensor = _convert_2d_tensor(out) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compile_key = (dtype, N) + kernel = _FWD_COMPILE_CACHE.get(compile_key) + if 
kernel is None: + op = SoftmaxFwdSM100(dtype, N) + kernel = cute.compile(op, x_tensor, out_tensor, current_stream) + _FWD_COMPILE_CACHE[compile_key] = kernel + kernel(x_tensor, out_tensor, current_stream) + return out + + +def softmax_backward(dy: Tensor, y: Tensor) -> Tensor: + """SM100 CuteDSL softmax backward pass.""" + assert dy.dim() == 2 and y.dim() == 2, "dy and y must be 2D (M, N)" + assert dy.shape == y.shape, "dy and y must have the same shape" + assert dy.is_cuda and y.is_cuda, "dy and y must be on CUDA device" + assert dy.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype" + assert y.dtype == dy.dtype, "dy and y must have the same dtype" + + N = dy.size(1) + dtype = TORCH2CUTE_DTYPE[dy.dtype] + if ( + _can_use_ptr_path_2d(dy) + and _can_use_ptr_path_2d(y) + and dy.stride() == y.stride() + ): + dx = torch.empty_strided( + dy.shape, dy.stride(), device=dy.device, dtype=dy.dtype + ) + _softmax_backward_ptr_into(dy=dy, y=y, dx=dx) + return dx + + dx = torch.empty_like(dy) + + dy_tensor = _convert_2d_tensor(dy) + y_tensor = _convert_2d_tensor(y) + dx_tensor = _convert_2d_tensor(dx) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compile_key = (dtype, N) + kernel = _BWD_COMPILE_CACHE.get(compile_key) + if kernel is None: + op = SoftmaxBwdSM100(dtype, N) + kernel = cute.compile(op, dy_tensor, y_tensor, dx_tensor, current_stream) + _BWD_COMPILE_CACHE[compile_key] = kernel + kernel(dy_tensor, y_tensor, dx_tensor, current_stream) + return dx + + +def softmax_fwd_bwd(dy: Tensor, x: Tensor) -> Tensor: + """Fused softmax forward+backward producing ``dx`` from ``(x, dy)``. + + This is intended for benchmarks and training-like use-cases where the + intermediate ``y = softmax(x)`` is not needed outside the backward pass. + """ + assert x.dim() == 2 and dy.dim() == 2, "x and dy must be 2D (M, N)" + assert x.shape == dy.shape, "x and dy must have the same shape" + assert x.is_cuda and dy.is_cuda, "x and dy must be on CUDA device" + assert x.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype" + assert dy.dtype == x.dtype, "x and dy must have the same dtype" + + if ( + _can_use_ptr_path_2d(x) + and _can_use_ptr_path_2d(dy) + and x.stride() == dy.stride() + ): + dx = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + _softmax_fwd_bwd_ptr_into(x=x, dy=dy, dx=dx) + return dx + + with torch.no_grad(): + return softmax_backward(dy, softmax_forward(x)) + + +class SoftmaxFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + y = softmax_forward(x) + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy: Tensor) -> tuple[Tensor]: + (y,) = ctx.saved_tensors + dx = softmax_backward(dy, y) + return dx + + +def softmax(x: Tensor) -> Tensor: + """Autograd-friendly softmax using the SM100 CuteDSL kernel.""" + return SoftmaxFunction.apply(x) + + +def _torch_softmax_reference(x: Tensor) -> Tensor: + return torch.nn.functional.softmax(x, dim=-1) + + +def verify_softmax_parity( + M: int, + N: int, + dtype: torch.dtype = torch.bfloat16, + atol: float = 5e-2, + rtol: float = 5e-2, +) -> None: + """Compare SM100 CuteDSL softmax against PyTorch for a single shape.""" + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + x.requires_grad_(True) + + # Forward parity + y_ref = _torch_softmax_reference(x) + y = softmax(x) + torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) + + # Backward parity + dy = torch.randn_like(y) + (dx_ref,) = torch.autograd.grad(y_ref, x, dy, 
retain_graph=False) + dx = softmax_backward(dy, y) + torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol)
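+
+
+# Usage sketch (illustrative shape; assumes an SM100 device):
+#
+#     if torch.cuda.get_device_capability() == (10, 0):
+#         verify_softmax_parity(M=8192, N=4096, dtype=torch.bfloat16)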