Commit ccd3be9

hotfix: revert torch.library register (#709)
We observe performance degradation for small operations in flashinfer v0.2 because of the overhead of `torch.library.custom_op` introduced in #554. This PR disables the torch custom operator registrations for now; we can add them back later with a lightweight registration scheme along the lines of https://github.com/vllm-project/vllm/blob/36e76700453924c8d421db99af70a88a1df835cd/vllm/utils.py#L1660-L1674 cc @zhyncs @abcdabcd987 @youkaichao
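For reference, the lightweight registration pattern linked above defines ops directly on a `torch.library.Library` fragment instead of going through `torch.library.custom_op`. Below is a minimal sketch of that approach, not part of this commit: the name `direct_register_custom_op` and the `"flashinfer"` namespace are illustrative, it assumes PyTorch >= 2.5 for `torch.library.infer_schema`, and `Library._register_fake` is a private API.

```python
# Sketch (not part of this commit): lightweight custom-op registration in the
# style of the linked vLLM utility. Assumes PyTorch >= 2.5 for
# torch.library.infer_schema; the namespace and function name are illustrative.
from typing import Callable, Optional, Sequence

import torch
from torch.library import Library

# A FRAGMENT library lets ops be added incrementally under one namespace.
flashinfer_lib = Library("flashinfer", "FRAGMENT")  # hypothetical namespace


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: Sequence[str] = (),
    fake_impl: Optional[Callable] = None,
) -> None:
    # Infer the schema once at registration time instead of paying the
    # per-call overhead that torch.library.custom_op adds for small ops.
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    flashinfer_lib.define(op_name + schema)
    flashinfer_lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        # Private API; vLLM uses it to attach the fake (meta) implementation.
        flashinfer_lib._register_fake(op_name, fake_impl)
```

The registration cost is paid once per op at import time, and calls then go through the dispatcher table directly rather than the heavier `custom_op` wrapper.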
1 parent 4ba91c0 commit ccd3be9

File tree: 1 file changed (+13, -8)

flashinfer/utils.py

@@ -236,19 +236,24 @@ def register_custom_op(
     device_types: Optional[Union[str, Sequence[str]]] = None,
     schema: Optional[str] = None,
 ) -> Callable:
-    return torch.library.custom_op(
-        name,
-        fn,
-        mutates_args=mutates_args,
-        device_types=device_types,
-        schema=schema,
-    )
+    # NOTE(Zihao): torch.library.custom_op has significant overhead as mentioned in the following link
+    # https://github.com/vllm-project/vllm/blob/36e76700453924c8d421db99af70a88a1df835cd/vllm/utils.py#L1660-L1674
+
+    # return torch.library.custom_op(
+    #     name,
+    #     fn,
+    #     mutates_args=mutates_args,
+    #     device_types=device_types,
+    #     schema=schema,
+    # )
+    return lambda x: x

 def register_fake_op(
     name: str,
     fn: Optional[Callable] = None,
 ) -> Callable:
-    return torch.library.register_fake(name, fn)
+    # return torch.library.register_fake(name, fn)
+    return lambda x: x


 def get_cuda_stream(device: torch.device) -> int:
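After this patch, `register_custom_op` and `register_fake_op` return an identity decorator, so decorated functions remain plain Python callables and are invoked directly, with no dispatcher overhead. A usage sketch of a call site after the change; the op name and function bodies are illustrative, not taken from the flashinfer codebase:

```python
import torch

from flashinfer.utils import register_custom_op, register_fake_op


# With this commit, the decorator returns the function unchanged, so
# `fused_add` stays a plain Python function.
@register_custom_op("flashinfer::fused_add", mutates_args=())
def fused_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b


# The fake (meta) implementation is simply ignored while registration is off.
@register_fake_op("flashinfer::fused_add")
def _fused_add_fake(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(a)


out = fused_add(torch.ones(4), torch.ones(4))  # direct Python call
```

Because call sites keep the same decorator syntax, re-enabling a lighter registration path later only requires changing these two helpers, not the callers.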
