Skip to content

thunderfx: applying thunderfx should return an nn.Module-like object #2757

@kshitij12345

Description

@kshitij12345

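Returning an nn.Module-like object would make it easy to apply thunderfx to only a few submodules, for example by swapping them in place inside a parent module, as can already be done with `torch.compile`.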

Repro

import thunder
import torch
from typing import Callable

model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 10),
)

# torch.compile wraps the model in an nn.Module subclass, so the result can be
# used anywhere a module is expected (e.g. assigned as a child module).
cm = torch.compile(model)
print(type(cm))  # <class 'torch._dynamo.eval_frame.OptimizedModule'>
print(isinstance(cm, torch.nn.Module))
print(hasattr(cm, "forward"))

from thunder.dynamo import thunderfx

# thunderfx should behave the same way. Per this issue, its result is not an
# nn.Module and accessing `tm.forward` errors. These checks show that without
# crashing the script (expected to print False/False at the time of the issue;
# confirm against the current thunder version).
tm = thunderfx(model)
print(type(tm))
print(isinstance(tm, torch.nn.Module))
print(hasattr(tm, "forward"))

Repro to replace submodules

import thunder
import torch
from typing import Callable

# The logic is based on https://github.com/pytorch/ao/blob/b34c1037/torchao/quantization/quant_api.py#L230
def _replace_with_custom_fn_if_matches_filter_with_name(
    model: torch.nn.Module,
    replacement_fn: Callable[[torch.nn.Module, str], torch.nn.Module],
    filter_fn: Callable[[torch.nn.Module, str], bool],
    cur_fqn: str = "",
) -> torch.nn.Module:
    """
    Recursively replace the submodules of `model` that match `filter_fn`.

    Each module visited is tested with `filter_fn(module, fqn)`. A matching
    module is replaced by `replacement_fn(module, fqn)`, and its children are
    not visited. A non-matching module has each of its children processed
    recursively. Any child that comes back as a different object is
    re-assigned on the parent with `setattr`.

    Args:
        model (torch.nn.Module): The module to process. Matching descendants
            are replaced in place on their parents.
        replacement_fn (Callable[[torch.nn.Module, str], torch.nn.Module]):
            Called as `replacement_fn(module, fqn)` to build the replacement
            for a matching module.
        filter_fn (Callable[[torch.nn.Module, str], bool]): Called as
            `filter_fn(module, fqn)` to decide whether a module is replaced.
        cur_fqn (str): The fully qualified name prefix of `model`, including a
            trailing "." (empty for the root). Used internally by the
            recursion; callers normally leave it at the default.

    Returns:
        torch.nn.Module: `replacement_fn(model, ...)` if `model` itself matches
        the filter; otherwise `model`, with its matching descendants replaced.
        Callers must use the return value to pick up a replacement of the
        root module, because that case cannot be applied in place.
    """
    # cur_fqn carries a trailing "." so children can append their own name;
    # strip it to get this module's fully qualified name ("" for the root).
    fqn = cur_fqn[:-1]
    if filter_fn(model, fqn):
        return replacement_fn(model, fqn)

    # Snapshot the children because setattr below writes back into the
    # module's child registry while we are iterating over it.
    for child_name, child in list(model.named_children()):
        replaced = _replace_with_custom_fn_if_matches_filter_with_name(
            child,
            replacement_fn,
            filter_fn,
            f"{cur_fqn}{child_name}.",
        )
        if replaced is not child:
            setattr(model, child_name, replaced)
    return model

# thunderfx must be imported here: this repro's own imports only bring in
# thunder, torch and Callable. Without it, the last scenario fails with a
# NameError instead of reproducing the actual issue.
from thunder.dynamo import ThunderCompiler, thunderfx


def _make_model():
    """Return a fresh Linear -> ReLU -> Linear model for one scenario."""
    return torch.nn.Sequential(
        torch.nn.Linear(10, 20),
        torch.nn.ReLU(),
        torch.nn.Linear(20, 10),
    )


def _is_relu(module, name):
    """Filter that selects only ReLU submodules (the name is ignored)."""
    return isinstance(module, torch.nn.ReLU)


# works: torch.compile returns an OptimizedModule, which is an nn.Module and
# can therefore be assigned as a child of the Sequential.
model = _make_model()
_replace_with_custom_fn_if_matches_filter_with_name(
    model,
    replacement_fn=lambda module, name: torch.compile(module),
    filter_fn=_is_relu,
)

# works: same as above, with thunder as the dynamo backend.
model = _make_model()
_replace_with_custom_fn_if_matches_filter_with_name(
    model,
    replacement_fn=lambda module, name: torch.compile(module, backend=ThunderCompiler()),
    filter_fn=_is_relu,
)

# doesn't work: thunderfx's result is not an nn.Module, so nn.Module.__setattr__
# is expected to reject assigning it as a child module (TypeError). Confirm the
# exact error against the current thunder/torch versions.
model = _make_model()
_replace_with_custom_fn_if_matches_filter_with_name(
    model,
    replacement_fn=lambda module, name: thunderfx(module),
    filter_fn=_is_relu,
)

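Workaround:
`torch.compile(module, backend=ThunderCompiler())` works, because `torch.compile` returns an nn.Module (`OptimizedModule`) that can be assigned as a child module.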

cc @Borda

Metadata

Metadata

Assignees

No one assigned

    Labels

    thunderfx (for things that could be applicable to the dynamo+thunder frontend), ux

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions