Skip to content

[torch.compile] Make HiDream torch.compile ready #11477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ def forward(self, x):
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
count_freq = torch.bincount(flat_expert_indices, minlength=self.num_activated_experts)
tokens_per_expert = count_freq.cumsum(dim=0)

Comment on lines -392 to +394
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just reimplemented it to eliminate the numpy() dependency.

token_idxs = idxs // self.num_activated_experts
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
Expand Down
21 changes: 21 additions & 0 deletions tests/models/transformers/test_models_transformer_hidream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from diffusers import HiDreamImageTransformer2DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
is_torch_compile,
require_torch_2,
require_torch_gpu,
slow,
torch_device,
)

Expand Down Expand Up @@ -94,3 +98,20 @@ def test_set_attn_processor_for_determinism(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HiDreamImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relevant test for this PR.

torch._dynamo.reset()
torch._dynamo.config.capture_dynamic_output_shape_ops = True

init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)