
Commit 32dc1e9

upd
1 parent 7977474 commit 32dc1e9

5 files changed: +38 -42 lines changed

csrc/flashinfer_norm_ops.cu (+4, -5)

@@ -15,17 +15,16 @@
  */
 #include "pytorch_extension_utils.h"
 
-void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl,
-             int64_t cuda_stream);
+void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
 
 void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
-                       bool enable_pdl, int64_t cuda_stream);
+                       bool enable_pdl);
 
 void gemma_rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps,
-                   bool enable_pdl, int64_t cuda_stream);
+                   bool enable_pdl);
 
 void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight,
-                             double eps, bool enable_pdl, int64_t cuda_stream);
+                             double eps, bool enable_pdl);
 
 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   // Root mean square normalization

csrc/flashinfer_ops.cu (+4, -5)

@@ -66,17 +66,16 @@ void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at
 
 //========== norm ==========
 
-void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl,
-             int64_t cuda_stream);
+void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
 
 void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
-                       bool enable_pdl, int64_t cuda_stream);
+                       bool enable_pdl);
 
 void gemma_rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps,
-                   bool enable_pdl, int64_t cuda_stream);
+                   bool enable_pdl);
 
 void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight,
-                             double eps, bool enable_pdl, int64_t cuda_stream);
+                             double eps, bool enable_pdl);
 
 //========== page ==========
 
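
Note: with the int64_t cuda_stream parameter gone, both binding headers expose signatures that carry only tensors and scalars. As a point of comparison only (not flashinfer's actual registration code, which goes through its own register_custom_op helper shown further down), here is a minimal sketch of how a stream-free schema like this can be registered as a PyTorch custom op; it assumes PyTorch 2.4+ for torch.library.custom_op, and the demo:: namespace and reference math are illustrative stand-ins.

import torch

# Hypothetical demo op: the schema carries only tensors and scalars, mirroring
# the stream-free C++ declarations above; reference math stands in for the kernel.
@torch.library.custom_op("demo::rmsnorm_out", mutates_args=("out",))
def rmsnorm_out(out: torch.Tensor, x: torch.Tensor, w: torch.Tensor,
                eps: float, enable_pdl: bool) -> None:
    rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    out.copy_((x.float() * rms * w.float()).to(x.dtype))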

csrc/norm.cu (+13, -9)

@@ -20,8 +20,8 @@
 
 using namespace flashinfer;
 
-void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl,
-             int64_t cuda_stream) {
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps,
+             bool enable_pdl) {
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
   auto device = input.device();
@@ -34,7 +34,8 @@ void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double e
   CHECK_EQ(output.size(0), batch_size);
   CHECK_EQ(output.size(1), hidden_size);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(device);
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::RMSNorm(
         static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
@@ -47,7 +48,7 @@ void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double e
 }
 
 void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
-                       bool enable_pdl, int64_t cuda_stream) {
+                       bool enable_pdl) {
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
@@ -63,7 +64,8 @@ void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weig
   unsigned int batch_size = input.size(0);
   unsigned int hidden_size = input.size(1);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(device);
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::FusedAddRMSNorm(
         static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
@@ -76,7 +78,7 @@ void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weig
 }
 
 void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps,
-                   bool enable_pdl, int64_t cuda_stream) {
+                   bool enable_pdl) {
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
   auto device = input.device();
@@ -89,7 +91,8 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do
   CHECK_EQ(output.size(0), batch_size);
   CHECK_EQ(output.size(1), hidden_size);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(device);
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::GemmaRMSNorm(
         static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
@@ -102,7 +105,7 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do
 }
 
 void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight,
-                             double eps, bool enable_pdl, int64_t cuda_stream) {
+                             double eps, bool enable_pdl) {
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
@@ -118,7 +121,8 @@ void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor
   unsigned int batch_size = input.size(0);
   unsigned int hidden_size = input.size(1);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(device);
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::GemmaFusedAddRMSNorm(
         static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
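
The practical effect of these changes is that the kernels are now enqueued on whatever CUDA stream PyTorch considers current (with an OptionalCUDAGuard keeping the input's device active), instead of on a raw stream handle threaded through from Python. A minimal usage sketch, assuming a CUDA build of flashinfer and that the public flashinfer.norm.rmsnorm(input, weight, eps) wrapper keeps its documented signature:

import torch
import flashinfer

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
w = torch.randn(4096, dtype=torch.float16, device="cuda")

side_stream = torch.cuda.Stream()
with torch.cuda.stream(side_stream):
    # The kernel is launched on the stream that is current here (side_stream);
    # no stream handle is passed through the binding anymore.
    y = flashinfer.norm.rmsnorm(x, w, eps=1e-6)
side_stream.synchronize()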

csrc/pytorch_extension_utils.h (+2)

@@ -15,6 +15,8 @@
  */
 #pragma once
 #include <Python.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
 #include <torch/library.h>
 
 #ifdef FLASHINFER_ENABLE_BF16

flashinfer/norm.py (+15, -23)

@@ -14,12 +14,13 @@
 limitations under the License.
 """
 
-from typing import Optional
+from functools import cache
+from typing import Any, Optional
 
 import torch
 
 from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
-from .utils import get_cuda_stream, register_custom_op, register_fake_op
+from .utils import register_custom_op, register_fake_op
 
 _norm_module = None
 
@@ -42,6 +43,14 @@ def get_norm_module():
     return _norm_module
 
 
+@cache
+def get_module_attr(attr: str) -> Any:
+    global _norm_module
+    if _norm_module is None:
+        get_norm_module()
+    return getattr(_norm_module, attr).default
+
+
 def rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -86,10 +95,7 @@ def _rmsnorm(
     eps: float,
     enable_pdl: bool,
 ) -> None:
-    with input.device as device:  # device guard
-        get_norm_module().rmsnorm.default(
-            out, input, weight, eps, enable_pdl, get_cuda_stream(device)
-        )
+    get_module_attr("rmsnorm")(out, input, weight, eps, enable_pdl)
 
 
 @register_fake_op("flashinfer::rmsnorm")
@@ -103,9 +109,6 @@ def _rmsnorm_fake(
     pass
 
 
-_fused_add_rmsnorm_kernel = None
-
-
 @register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
 def fused_add_rmsnorm(
     input: torch.Tensor,
@@ -136,12 +139,7 @@ def fused_add_rmsnorm(
         Whether to enable `programmatic dependent launch
         <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
     """
-    global _fused_add_rmsnorm_kernel
-    if _fused_add_rmsnorm_kernel is None:
-        _fused_add_rmsnorm_kernel = get_norm_module().fused_add_rmsnorm.default
-    _fused_add_rmsnorm_kernel(
-        input, residual, weight, eps, enable_pdl, get_cuda_stream(input.device)
-    )
+    get_module_attr("fused_add_rmsnorm")(input, residual, weight, eps, enable_pdl)
 
 
 @register_fake_op("flashinfer::fused_add_rmsnorm")
@@ -199,10 +197,7 @@ def _gemma_rmsnorm(
     eps: float,
     enable_pdl: bool,
 ) -> None:
-    with input.device as device:  # device guard
-        get_norm_module().gemma_rmsnorm.default(
-            out, input, weight, eps, enable_pdl, get_cuda_stream(device)
-        )
+    get_module_attr("gemma_rmsnorm")(out, input, weight, eps, enable_pdl)
 
 
 @register_fake_op("flashinfer::gemma_rmsnorm")
@@ -248,10 +243,7 @@ def gemma_fused_add_rmsnorm(
         Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
     """
-    with input.device as device:
-        get_norm_module().gemma_fused_add_rmsnorm.default(
-            input, residual, weight, eps, enable_pdl, get_cuda_stream(device)
-        )
+    get_module_attr("gemma_fused_add_rmsnorm")(input, residual, weight, eps, enable_pdl)
 
 
 @register_fake_op("flashinfer::gemma_fused_add_rmsnorm")
