Skip to content
21 changes: 11 additions & 10 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ def create_fd(
# TODO Review splitting very large fusions or removing the max length restriction completely
# See "Very large nvFuser fusions hit max_length"
fd = FusionDefinition(max_length=9999)

# Device may be set in one of the "factory" methods like full, iota, or uniform
fd._selected_device = None

with fd:
# NOTE Adding constants is disabled for the moment in favor of defining them inline
# 0) Adds constants
Expand Down Expand Up @@ -428,7 +432,7 @@ def to_runtime_descriptors(args) -> tuple:


# TODO Consider making this just a function, because it's faster to call a function than a callable class
@dataclass
@dataclass(slots=True)
class FusionDefinitionWrapper:
"""
A callable object wrapping a nvFuser fusion definition.
Expand Down Expand Up @@ -456,16 +460,13 @@ def __call__(self, *args):
if self.store_inputs:
self.last_inputs = args

kwargs = {}
if self.save_fake_inputs:
kwargs["save_repro_inputs"] = True
# Set device if set in one of the "factory" methods like full, iota, or uniform
if hasattr(fd, "_selected_device"):
kwargs["device"] = fd._selected_device

with annotate_for_profile(self.name):
return fd.execute(
args, _enable_options=self.enable_options, _disable_options=self.disable_options, **kwargs
args,
device=fd._selected_device,
save_repro_inputs=self.save_fake_inputs,
_enable_options=self.enable_options,
_disable_options=self.disable_options,
)

def __repr__(self):
Expand Down Expand Up @@ -1012,7 +1013,7 @@ def _select_device(fd: FusionDefinition, device: None | Device):
return

utils.check(
not hasattr(fd, "_selected_device") or fd._selected_device == device.index,
fd._selected_device is None or fd._selected_device == device.index,
lambda: f"Found multiple requested devices: {fd._selected_device} and {device.index}",
)

Expand Down
Loading