-
Notifications
You must be signed in to change notification settings - Fork 110
Open
Labels
thunderfx (for things that could be applicable to the dynamo+thunder frontend), ux
Description
This would make it easy to apply `thunderfx` to only a few submodules of a model.
Repro
import thunder
import torch
from typing import Callable
model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 10),
)

# Compiling an nn.Module with torch.compile yields an OptimizedModule
# (printed type shown in the trailing comment).
cm = torch.compile(model)
print(type(cm))  # <class 'torch._dynamo.eval_frame.OptimizedModule'>
# print(cm.forward)

from thunder.dynamo import thunderfx

tm = thunderfx(model)
print(type(tm))
# `forward` is not available on the object returned by thunderfx:
# print(tm.forward)  # Errors

# --- Repro to replace submodules ---
import thunder
import torch
from typing import Callable
# The logic is based on https://github.com/pytorch/ao/blob/b34c1037/torchao/quantization/quant_api.py#L230
def _replace_with_custom_fn_if_matches_filter_with_name(
    model,
    replacement_fn: Callable[[torch.nn.Module, str], torch.nn.Module],
    filter_fn: Callable[[torch.nn.Module, str], bool],
    cur_fqn="",
) -> torch.nn.Module:
    """
    Recursively replace the submodules of `model` that match `filter_fn`.

    The module tree is walked top-down. Each module is first tested with
    ``filter_fn(module, fqn)``, where ``fqn`` is its fully qualified name
    (e.g. ``"1"`` or ``"1.0"``; ``""`` for the root). If the test passes, the
    module is replaced by ``replacement_fn(module, fqn)`` and its children are
    not visited. Otherwise the function recurses into the children and
    re-attaches every child that was replaced, using ``setattr``.

    Args:
        model (torch.nn.Module): The module to process. Matching descendants
            are replaced in place.
        replacement_fn (Callable[[torch.nn.Module, str], torch.nn.Module]):
            Called as ``replacement_fn(module, fqn)`` on each matching module.
            It returns the object to put in that module's place.
        filter_fn (Callable[[torch.nn.Module, str], bool]): Called as
            ``filter_fn(module, fqn)``. Returns True if the module should be
            replaced.
        cur_fqn (str): Fully qualified name of `model` followed by a trailing
            ``"."``, or ``""`` for the root. Set by the recursion; external
            callers normally leave the default.

    Returns:
        torch.nn.Module: The object that should take `model`'s place. This is
        the result of `replacement_fn` if `model` itself matched; otherwise it
        is `model`, with its matching descendants replaced in place. Callers
        must use this value when the root itself may match.
    """
    # Drop the trailing "." that the recursion appends ("" stays "").
    fqn = cur_fqn[:-1]
    if filter_fn(model, fqn):
        return replacement_fn(model, fqn)

    # Snapshot the children so that setattr below never mutates the container
    # being iterated.
    for name, child in list(model.named_children()):
        new_child = _replace_with_custom_fn_if_matches_filter_with_name(
            child,
            replacement_fn,
            filter_fn,
            f"{cur_fqn}{name}.",
        )
        if new_child is not child:
            setattr(model, name, new_child)
    return model
# works: replace the ReLU submodule with torch.compile(module) (default backend)
model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 10),
)
_replace_with_custom_fn_if_matches_filter_with_name(
    model,
    replacement_fn=lambda module, name: torch.compile(module),
    filter_fn=lambda module, name: isinstance(module, torch.nn.ReLU),
)
# works: replace the ReLU submodule with torch.compile using the
# ThunderCompiler backend
model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 10),
)
from thunder.dynamo import ThunderCompiler

_replace_with_custom_fn_if_matches_filter_with_name(
    model,
    replacement_fn=lambda module, name: torch.compile(
        module, backend=ThunderCompiler()
    ),
    filter_fn=lambda module, name: isinstance(module, torch.nn.ReLU),
)
# doesn't work: replace the ReLU submodule with thunderfx(module)
# NOTE: this repro's imports (thunder, torch, Callable, ThunderCompiler) do not
# include thunderfx, so import it here; otherwise the snippet stops with a
# NameError before it reaches the behavior being reported.
from thunder.dynamo import thunderfx

model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 10),
)
_replace_with_custom_fn_if_matches_filter_with_name(
    model,
    replacement_fn=lambda module, name: thunderfx(module),
    filter_fn=lambda module, name: isinstance(module, torch.nn.ReLU),
)
# Alternative:
Using `torch.compile(module, backend=ThunderCompiler())` as the replacement works.
cc @Borda
Metadata
Assignees
Labels
thunderfx (for things that could be applicable to the dynamo+thunder frontend), ux