Skip to content

Commit d158717

Browse files
authored
bugfix: casting int array to int32 for rope input arguments (#697)
Casts the integer index tensors to int32 to avoid potential bugs when a user passes a LongTensor to the RoPE APIs. Also removes some files that are not used in the current codebase, and fixes a bug in AOT mode where `apply_rope_pos_ids_cos_sin_cache` was not registered in pybind.
1 parent 398cd2b commit d158717

File tree

4 files changed

+9
-263
lines changed

4 files changed

+9
-263
lines changed

csrc/dispatch_type_code.h

-192
This file was deleted.

csrc/dispatch_utils.h

-71
This file was deleted.

csrc/flashinfer_ops.cu

+2
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
254254
m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids");
255255
m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids,
256256
"Apply Llama 3.1 style RoPE with positional ids");
257+
m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache,
258+
"Apply RoPE with positional ids and cosine/sine cache");
257259

258260
// sampling
259261
m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities");

flashinfer/rope.py

+7
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def _apply_rope(
5656
rope_theta: float,
5757
) -> None:
5858
with q.device as device:
59+
indptr = indptr.int()
60+
offsets = offsets.int()
5961
get_rope_module().apply_rope(
6062
q,
6163
k,
@@ -104,6 +106,8 @@ def _apply_llama31_rope(
104106
old_context_len: float,
105107
) -> None:
106108
with q.device as device:
109+
indptr = indptr.int()
110+
offsets = offsets.int()
107111
get_rope_module().apply_llama31_rope(
108112
q,
109113
k,
@@ -154,6 +158,7 @@ def _apply_rope_pos_ids(
154158
rope_theta: float,
155159
) -> None:
156160
with q.device as device:
161+
pos_ids = pos_ids.int()
157162
get_rope_module().apply_rope_pos_ids(
158163
q,
159164
k,
@@ -197,6 +202,7 @@ def _apply_rope_pos_ids_cos_sin_cache(
197202
interleave: bool,
198203
) -> None:
199204
with q.device as device:
205+
pos_ids = pos_ids.int()
200206
get_rope_module().apply_rope_pos_ids_cos_sin_cache(
201207
q,
202208
k,
@@ -242,6 +248,7 @@ def _apply_llama31_rope_pos_ids(
242248
old_context_len: float,
243249
) -> None:
244250
with q.device as device:
251+
pos_ids = pos_ids.int()
245252
get_rope_module().apply_llama31_rope_pos_ids(
246253
q,
247254
k,

0 commit comments

Comments
 (0)