adjust unit tests for test_save_load_float16 (#12500)
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
@@ -22,6 +22,7 @@
 from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanPipeline, WanTransformer3DModel

 from ...testing_utils import (
+    require_accelerator,
     enable_full_determinism,
     torch_device,
 )
@@ -51,51 +52,33 @@ class Wan22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     test_xformers_attention = False
     supports_dduf = False

-    def get_dummy_components(self):
+    def get_dummy_components(self, dtype=torch.float32):
         torch.manual_seed(0)
         vae = AutoencoderKLWan(
             base_dim=3,
             z_dim=16,
             dim_mult=[1, 1, 1, 1],
             num_res_blocks=1,
             temperal_downsample=[False, True, True],
-        )
+        ).to(dtype=dtype)

         torch.manual_seed(0)
         scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
-        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5", torch_dtype=dtype)
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

         torch.manual_seed(0)
-        transformer = WanTransformer3DModel(
-            patch_size=(1, 2, 2),
-            num_attention_heads=2,
-            attention_head_dim=12,
-            in_channels=16,
-            out_channels=16,
-            text_dim=32,
-            freq_dim=256,
-            ffn_dim=32,
-            num_layers=2,
-            cross_attn_norm=True,
-            qk_norm="rms_norm_across_heads",
-            rope_max_seq_len=32,
+        # Use from_pretrained with a tiny model to ensure proper dtype handling
+        # This ensures _keep_in_fp32_modules and _skip_layerwise_casting_patterns are respected
+        transformer = WanTransformer3DModel.from_pretrained(
+            "Kaixuanliu/tiny-random-wan-transformer",
+            torch_dtype=dtype
         )

         torch.manual_seed(0)
-        transformer_2 = WanTransformer3DModel(
-            patch_size=(1, 2, 2),
-            num_attention_heads=2,
-            attention_head_dim=12,
-            in_channels=16,
-            out_channels=16,
-            text_dim=32,
-            freq_dim=256,
-            ffn_dim=32,
-            num_layers=2,
-            cross_attn_norm=True,
-            qk_norm="rms_norm_across_heads",
-            rope_max_seq_len=32,
+        transformer_2 = WanTransformer3DModel.from_pretrained(
+            "Kaixuanliu/tiny-random-wan-transformer",
+            torch_dtype=dtype
         )

         components = {
@@ -155,6 +138,44 @@ def test_inference(self):
     def test_attention_slicing_forward_pass(self):
         pass

+    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+    @require_accelerator
+    def test_save_load_float16(self, expected_max_diff=1e-2):
+        # Use get_dummy_components with dtype parameter instead of converting components
+        components = self.get_dummy_components(dtype=torch.float16)
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
+            for component in pipe_loaded.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for name, component in pipe_loaded.components.items():
+            if hasattr(component, "dtype"):
+                self.assertTrue(
+                    component.dtype == torch.float16,
+                    f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
+                )
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_loaded = pipe_loaded(**inputs)[0]
+        max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+        self.assertLess(
+            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
+        )
+
     def test_save_load_optional_components(self, expected_max_difference=1e-4):
         optional_component = "transformer"
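For local verification, a minimal runner sketch for just the overridden test follows. The module path is an assumption (the diff does not name the changed file), so adjust the import to wherever Wan22PipelineFastTests lives in your checkout.

```python
# Hypothetical runner, not part of the PR: execute only the new fp16
# round-trip test. Requires a CUDA or XPU device, per the skipIf decorator.
import unittest

# Assumed module path; adjust to the actual location of Wan22PipelineFastTests.
from tests.pipelines.wan.test_wan_22 import Wan22PipelineFastTests

suite = unittest.TestSuite()
suite.addTest(Wan22PipelineFastTests("test_save_load_float16"))
unittest.TextTestRunner(verbosity=2).run(suite)
```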
Why do we need this?
Please refer to L246-L256 (sorry, I could only find a Chinese-language version of this explanation). Using the `torch.Tensor.to` method converts all weights, while using the `torch_dtype` parameter with `from_pretrained` preserves the layers listed in `_keep_in_fp32_modules`. For the Wan models this means every component of `pipe` ends up in fp16, while that is not the case for `pipe_loaded`. That is why I override the `test_save_load_float16` function separately for the Wan models.
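To make that difference concrete, here is a minimal sketch (not part of the PR) contrasting the two loading paths. It assumes the tiny checkpoint referenced in the diff is reachable and that `WanTransformer3DModel` declares `_keep_in_fp32_modules`, as described above.

```python
# Minimal sketch, not from the PR: contrast a blanket .to() cast with
# from_pretrained(torch_dtype=...). The checkpoint id is the tiny test model
# referenced in the diff; its availability is assumed.
import torch
from diffusers import WanTransformer3DModel

repo_id = "Kaixuanliu/tiny-random-wan-transformer"

# Path A: torch_dtype goes through from_pretrained, so modules named in
# _keep_in_fp32_modules are left in float32.
model_a = WanTransformer3DModel.from_pretrained(repo_id, torch_dtype=torch.float16)

# Path B: load in float32, then cast everything; the fp32-protected modules
# are converted along with the rest.
model_b = WanTransformer3DModel.from_pretrained(repo_id).to(dtype=torch.float16)

print({p.dtype for p in model_a.parameters()})  # may contain torch.float32 and torch.float16
print({p.dtype for p in model_b.parameters()})  # only torch.float16
```

Per the explanation above, the base `test_save_load_float16` builds `pipe` the second way and `pipe_loaded` the first, so their dtypes (and outputs) can diverge; the override builds both sides through `from_pretrained` with an explicit dtype instead.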