|
79 | 79 | "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
80 | 80 | "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
|
81 | 81 | "animatediff_rgb": "controlnet_cond_embedding.weight",
|
82 |
| - "flux": "double_blocks.0.img_attn.norm.key_norm.scale", |
| 82 | + "flux": [ |
| 83 | + "double_blocks.0.img_attn.norm.key_norm.scale", |
| 84 | + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", |
| 85 | + ], |
83 | 86 | }
|
84 | 87 |
|
85 | 88 | DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
|
258 | 261 | "timestep_spacing": "leading",
|
259 | 262 | }
|
260 | 263 |
|
261 |
| -LDM_VAE_KEY = "first_stage_model." |
| 264 | +LDM_VAE_KEYS = ["first_stage_model.", "vae."] |
262 | 265 | LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
263 | 266 | PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
264 | 267 | LDM_UNET_KEY = "model.diffusion_model."
|
|
267 | 270 | "cond_stage_model.transformer.",
|
268 | 271 | "conditioner.embedders.0.transformer.",
|
269 | 272 | ]
|
270 |
| -OPEN_CLIP_PREFIX = "conditioner.embedders.0.model." |
271 | 273 | LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
272 | 274 | SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
|
273 | 275 |
|
@@ -521,8 +523,10 @@ def infer_diffusers_model_type(checkpoint):
|
521 | 523 | else:
|
522 | 524 | model_type = "animatediff_v3"
|
523 | 525 |
|
524 |
| - elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint: |
525 |
| - if "guidance_in.in_layer.bias" in checkpoint: |
| 526 | + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): |
| 527 | + if any( |
| 528 | + g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] |
| 529 | + ): |
526 | 530 | model_type = "flux-dev"
|
527 | 531 | else:
|
528 | 532 | model_type = "flux-schnell"
|
@@ -1181,7 +1185,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
1181 | 1185 | # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
|
1182 | 1186 | vae_state_dict = {}
|
1183 | 1187 | keys = list(checkpoint.keys())
|
1184 |
| - vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else "" |
| 1188 | + vae_key = "" |
| 1189 | + for ldm_vae_key in LDM_VAE_KEYS: |
| 1190 | + if any(k.startswith(ldm_vae_key) for k in keys): |
| 1191 | + vae_key = ldm_vae_key |
| 1192 | + |
1185 | 1193 | for key in keys:
|
1186 | 1194 | if key.startswith(vae_key):
|
1187 | 1195 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
@@ -1894,6 +1902,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
|
1894 | 1902 |
|
1895 | 1903 | def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
1896 | 1904 | converted_state_dict = {}
|
| 1905 | + keys = list(checkpoint.keys()) |
| 1906 | + for k in keys: |
| 1907 | + if "model.diffusion_model." in k: |
| 1908 | + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) |
1897 | 1909 |
|
1898 | 1910 | num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
|
1899 | 1911 | num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
|
|
0 commit comments