4 files changed: +6, -10 lines


src/maxdiffusion/models/flux/util.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 
 import jax
 from jax.typing import DTypeLike
-import torch # need for torch 2 jax
 from chex import Array
 from flax.traverse_util import flatten_dict, unflatten_dict
 from huggingface_hub import hf_hub_download

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from maxdiffusion import max_logging
 from huggingface_hub import hf_hub_download
 from safetensors import safe_open
-from flax.traverse_util import flatten_dict, unflatten_dict
+from flax.traverse_util import unflatten_dict
 from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
 
 
@@ -12,7 +12,7 @@ def _tuple_str_to_int(in_tuple):
   for item in in_tuple:
     try:
       out_list.append(int(item))
-    except:
+    except ValueError:
       out_list.append(item)
   return tuple(out_list)
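For reference, a minimal self-contained sketch of the helper after the narrowed handler, written from the hunk above; the key tuple in the sample call is illustrative only, not taken from the repo. A failed int() parse now falls through as ValueError, while unrelated exceptions are no longer swallowed by a bare except.

def _tuple_str_to_int(in_tuple):
  # Convert numeric path components such as "0" to ints; keep other keys as strings.
  out_list = []
  for item in in_tuple:
    try:
      out_list.append(int(item))
    except ValueError:  # only a failed int() parse is caught; other errors propagate
      out_list.append(item)
  return tuple(out_list)

# Illustrative call with a hypothetical flattened-parameter key:
# _tuple_str_to_int(("blocks", "0", "attn1", "kernel")) -> ("blocks", 0, "attn1", "kernel")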

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 3 additions & 6 deletions
@@ -23,7 +23,6 @@
 from flax import nnx
 import numpy as np
 import unittest
-import pytest
 from absl.testing import absltest
 from skimage.metrics import structural_similarity as ssim
 from ..models.wan.autoencoder_kl_wan import (
@@ -172,7 +171,7 @@ def test_wanrms_norm(self):
     dummy_input = jnp.ones(input_shape)
     output = wanrms_norm(dummy_input)
     output_np = np.array(output)
-    assert np.allclose(output_np, torch_output_np) == True
+    assert np.allclose(output_np, torch_output_np) is True
 
     # --- Test Case 2: images == False ---
     model = TorchWanRMS_norm(dim, images=False)
@@ -186,7 +185,7 @@ def test_wanrms_norm(self):
     dummy_input = jnp.ones(input_shape)
     output = wanrms_norm(dummy_input)
     output_np = np.array(output)
-    assert np.allclose(output_np, torch_output_np) == True
+    assert np.allclose(output_np, torch_output_np) is True
 
   def test_zero_padded_conv(self):
 
@@ -235,8 +234,6 @@ def test_wan_resample(self):
     w = 720
     mode = "downsample2d"
     input_shape = (batch, dim, t, h, w)
-    expected_output_shape = (1, dim, 1, 240, 360)
-    # output dim should be (1, 96, 1, 480, 720)
     dummy_input = torch.ones(input_shape)
     torch_wan_resample = TorchWanResample(dim=dim, mode=mode)
     torch_output = torch_wan_resample(dummy_input)
@@ -426,7 +423,7 @@ def vae_encode(video, wan_vae, vae_cache, key):
     rngs = nnx.Rngs(key)
     wan_vae = AutoencoderKLWan.from_config(pretrained_model_name_or_path, subfolder="vae", rngs=rngs)
     vae_cache = AutoencoderKLWanCache(wan_vae)
-    video_path, fps = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4", 8
+    video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
     video = load_video(video_path)
 
     vae_scale_factor_spatial = 2 ** len(wan_vae.temperal_downsample)
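A small self-contained sketch of the comparison pattern these tests rely on; the arrays below are illustrative placeholders for the JAX output and the PyTorch reference, not values from the test. np.allclose is documented to return a built-in bool, so the is True identity check holds, and a bare assert of the call would be an equivalent shorter form.

import numpy as np

# Illustrative stand-ins for the JAX output and the PyTorch reference.
output_np = np.ones((2, 3), dtype=np.float32)
torch_output_np = np.ones((2, 3), dtype=np.float32) + 1e-7  # within default tolerances

# np.allclose returns a plain Python bool, so the identity comparison passes;
# assert np.allclose(output_np, torch_output_np) alone would be equivalent.
assert np.allclose(output_np, torch_output_np) is True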

src/maxdiffusion/utils/loading_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Union
 
 import PIL.Image
 import PIL.ImageOps
