Merged
Changes from 6 commits
79 changes: 50 additions & 29 deletions tests/pipelines/wan/test_wan_22.py
@@ -22,6 +22,7 @@
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanPipeline, WanTransformer3DModel

from ...testing_utils import (
require_accelerator,
enable_full_determinism,
torch_device,
)
@@ -51,51 +52,33 @@ class Wan22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_xformers_attention = False
supports_dduf = False

def get_dummy_components(self):
def get_dummy_components(self, dtype=torch.float32):
Member:
Why do we need this?

Contributor Author (@kaixuanliu, Oct 30, 2025):
Please refer to L246-L256 (sorry, I could only find the Chinese version of this explanation). The `torch.Tensor.to` method converts all weights, whereas passing `torch_dtype` to `from_pretrained` preserves the layers listed in `_keep_in_fp32_modules`. For Wan models this means every component of `pipe` ends up in fp16, while that is not the case for `pipe_loaded`. That is why I override the `test_save_load_float16` function separately for the Wan models.
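For context, a minimal sketch (not part of the diff) of the behaviour described above, using the tiny checkpoint referenced in this PR; the exact fp32 leftovers depend on what the model class declares in `_keep_in_fp32_modules`:

```python
import torch
from diffusers import WanTransformer3DModel

repo = "Kaixuanliu/tiny-random-wan-transformer"  # tiny test checkpoint used in this PR

# torch_dtype at load time respects _keep_in_fp32_modules,
# so any listed submodules stay in float32.
model_load = WanTransformer3DModel.from_pretrained(repo, torch_dtype=torch.float16)

# .to() casts every parameter unconditionally, including the fp32-only ones.
model_cast = WanTransformer3DModel.from_pretrained(repo).to(dtype=torch.float16)

print({p.dtype for p in model_load.parameters()})  # may contain float32 entries
print({p.dtype for p in model_cast.parameters()})  # {torch.float16} only
```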

torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
).to(dtype=dtype)

torch.manual_seed(0)
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5", torch_dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

torch.manual_seed(0)
transformer = WanTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=16,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
# Use from_pretrained with a tiny model to ensure proper dtype handling
# This ensures _keep_in_fp32_modules and _skip_layerwise_casting_patterns are respected
transformer = WanTransformer3DModel.from_pretrained(
"Kaixuanliu/tiny-random-wan-transformer",
Contributor Author:
Please replace my model repo with one under an official namespace. We have to use `from_pretrained` here so that all the submodules' dtypes are loaded correctly (a quick way to verify this is sketched right below this call).

torch_dtype=dtype
)
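As an aside (not part of the diff), a quick check of which submodules kept float32 after loading with `torch_dtype`, again assuming the tiny checkpoint above:

```python
import torch
from diffusers import WanTransformer3DModel

model = WanTransformer3DModel.from_pretrained(
    "Kaixuanliu/tiny-random-wan-transformer", torch_dtype=torch.float16
)
# Any parameter that is not fp16 was protected by _keep_in_fp32_modules.
for name, param in model.named_parameters():
    if param.dtype != torch.float16:
        print(name, param.dtype)
```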

torch.manual_seed(0)
transformer_2 = WanTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=16,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
transformer_2 = WanTransformer3DModel.from_pretrained(
"Kaixuanliu/tiny-random-wan-transformer",
Contributor Author:
Same as above.

torch_dtype=dtype
)

components = {
@@ -155,6 +138,44 @@ def test_inference(self):
def test_attention_slicing_forward_pass(self):
pass

@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2):
Member:
I still don't understand how the tests are passing on my end, then.

Contributor Author:
I think it is related to the inputs. When I set all the seeds in `get_dummy_components` to 1, the max_diff on an A100 is `np.float16(0.2366)`, and with seed 42 the output is all NaN values. After this PR, the max_diff is 0 for every seed.

# Use get_dummy_components with dtype parameter instead of converting components
components = self.get_dummy_components(dtype=torch.float16)
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]

with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)

for name, component in pipe_loaded.components.items():
if hasattr(component, "dtype"):
self.assertTrue(
component.dtype == torch.float16,
f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
)

inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)

def test_save_load_optional_components(self, expected_max_difference=1e-4):
optional_component = "transformer"
