47 changes: 42 additions & 5 deletions tests/pipelines/wan/test_wan_22.py
@@ -51,19 +51,19 @@ class Wan22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_xformers_attention = False
supports_dduf = False

def get_dummy_components(self):
def get_dummy_components(self, dtype=torch.float32):
Member

Why do we need this?

Contributor Author

@kaixuanliu Oct 30, 2025
Please refer to L246-L256 (sorry, I only found the Chinese version of this explanation). Using the torch.Tensor.to method converts all weights, while passing the torch_dtype parameter to from_pretrained preserves the layers listed in _keep_in_fp32_modules. For Wan models, all components of pipe end up in fp16, which is not the case for pipe_loaded. That is why I override the test_save_load_float16 function separately for the Wan models.
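To make the distinction concrete, here is a minimal sketch (not part of the PR) using the tiny T5 checkpoint this test already loads; it assumes T5 lists its feed-forward `wo` projection in `_keep_in_fp32_modules`, and the exact module list differs per model class:

```python
import torch
from transformers import T5EncoderModel

repo = "hf-internal-testing/tiny-random-t5"

# .to() casts every parameter, including modules the class asks to keep in fp32.
enc_cast = T5EncoderModel.from_pretrained(repo).to(dtype=torch.float16)

# Passing torch_dtype through from_pretrained respects _keep_in_fp32_modules,
# so those parameters stay in float32.
enc_loaded = T5EncoderModel.from_pretrained(repo, torch_dtype=torch.float16)

print({n for n, p in enc_cast.named_parameters() if p.dtype == torch.float32})    # expected: empty
print({n for n, p in enc_loaded.named_parameters() if p.dtype == torch.float32})  # expected: the `wo` weights
```

This mismatch is what the overridden test below avoids: the dummy components are built in fp16 the same way from_pretrained would build them, so pipe and pipe_loaded end up with matching dtypes.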

torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
).to(dtype=dtype)

torch.manual_seed(0)
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5", dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

torch.manual_seed(0)
@@ -80,7 +80,7 @@ def get_dummy_components(self):
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
)
).to(dtype=dtype)

torch.manual_seed(0)
transformer_2 = WanTransformer3DModel(
@@ -96,7 +96,7 @@ def get_dummy_components(self):
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
)
).to(dtype=dtype)

components = {
"transformer": transformer,
@@ -155,6 +155,43 @@ def test_inference(self):
def test_attention_slicing_forward_pass(self):
pass

@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
def test_save_load_float16(self, expected_max_diff=1e-2):
Member
I still don't know, then, how the tests are passing on my end.

Contributor Author
I think it is related to the input. When I set all the seeds in get_dummy_components to 1, the max_diff on A100 is np.float16(0.2366), and when I set the seed to 42, the output is all NaN values. After this PR, the max_diff is 0 for every seed.

# Use get_dummy_components with dtype parameter instead of converting components
components = self.get_dummy_components(dtype=torch.float16)
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]

with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)

for name, component in pipe_loaded.components.items():
if hasattr(component, "dtype"):
self.assertTrue(
component.dtype == torch.float16,
f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
)

inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)

def test_save_load_optional_components(self, expected_max_difference=1e-4):
optional_component = "transformer"
