 using namespace flashinfer;

-void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl,
-             int64_t cuda_stream) {
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps,
+             bool enable_pdl) {
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
   auto device = input.device();
@@ -34,7 +34,8 @@ void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double e
   CHECK_EQ(output.size(0), batch_size);
   CHECK_EQ(output.size(1), hidden_size);

-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(device);
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::RMSNorm(
         static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
@@ -47,7 +48,7 @@ void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double e
 }

 void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
-                       bool enable_pdl, int64_t cuda_stream) {
+                       bool enable_pdl) {
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
@@ -63,7 +64,8 @@ void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weig
   unsigned int batch_size = input.size(0);
   unsigned int hidden_size = input.size(1);

-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(device);
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::FusedAddRMSNorm(
         static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
@@ -76,7 +78,7 @@ void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weig
 }

 void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps,
-                   bool enable_pdl, int64_t cuda_stream) {
+                   bool enable_pdl) {
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
   auto device = input.device();
@@ -89,7 +91,8 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do
   CHECK_EQ(output.size(0), batch_size);
   CHECK_EQ(output.size(1), hidden_size);

-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(device);
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::GemmaRMSNorm(
         static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
@@ -102,7 +105,7 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do
 }

 void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight,
-                             double eps, bool enable_pdl, int64_t cuda_stream) {
+                             double eps, bool enable_pdl) {
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
@@ -118,7 +121,8 @@ void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor
   unsigned int batch_size = input.size(0);
   unsigned int hidden_size = input.size(1);

-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(device);
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::GemmaFusedAddRMSNorm(
         static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
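For reference, the stream-handling pattern this diff adopts, shown in isolation. This is a minimal sketch, not flashinfer code: example_op is a hypothetical binding, and it assumes a PyTorch C++ extension with the c10 CUDA headers available. The idea is that c10::cuda::OptionalCUDAGuard switches the current device to the tensor's device for the guard's lifetime, and c10::cuda::getCurrentCUDAStream() returns the stream PyTorch is currently using on that device (the one a Python-side torch.cuda.stream(...) context selects), so the binding no longer needs callers to thread a raw int64_t cuda_stream handle through the API.

#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

// Hypothetical op, for illustration only.
void example_op(at::Tensor& input) {
  auto device = input.device();
  // Ensure subsequent CUDA API calls and kernel launches target the
  // device the input tensor lives on.
  const c10::cuda::OptionalCUDAGuard device_guard(device);
  // c10::cuda::CUDAStream converts implicitly to cudaStream_t.
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  // ... launch kernels on `stream`, as the functions above do ...
  (void)stream;
}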