Skip to content

Commit b12c7f8

Browse files
DN6 yiyixuxu
authored and committed
[Single File] Support loading Comfy UI Flux checkpoints (#9243)
update
1 parent 06f3671 commit b12c7f8

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/diffusers/loaders/single_file_utils.py

+18 -6
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,10 @@
7979
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
8080
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
8181
"animatediff_rgb": "controlnet_cond_embedding.weight",
82-
"flux": "double_blocks.0.img_attn.norm.key_norm.scale",
82+
"flux": [
83+
"double_blocks.0.img_attn.norm.key_norm.scale",
84+
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
85+
],
8386
}
8487

8588
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -258,7 +261,7 @@
258261
"timestep_spacing": "leading",
259262
}
260263

261-
LDM_VAE_KEY = "first_stage_model."
264+
LDM_VAE_KEYS = ["first_stage_model.", "vae."]
262265
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
263266
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
264267
LDM_UNET_KEY = "model.diffusion_model."
@@ -267,7 +270,6 @@
267270
"cond_stage_model.transformer.",
268271
"conditioner.embedders.0.transformer.",
269272
]
270-
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
271273
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
272274
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
273275

@@ -521,8 +523,10 @@ def infer_diffusers_model_type(checkpoint):
521523
else:
522524
model_type = "animatediff_v3"
523525

524-
elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
525-
if "guidance_in.in_layer.bias" in checkpoint:
526+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
527+
if any(
528+
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
529+
):
526530
model_type = "flux-dev"
527531
else:
528532
model_type = "flux-schnell"
@@ -1181,7 +1185,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
11811185
# remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
11821186
vae_state_dict = {}
11831187
keys = list(checkpoint.keys())
1184-
vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else ""
1188+
vae_key = ""
1189+
for ldm_vae_key in LDM_VAE_KEYS:
1190+
if any(k.startswith(ldm_vae_key) for k in keys):
1191+
vae_key = ldm_vae_key
1192+
11851193
for key in keys:
11861194
if key.startswith(vae_key):
11871195
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
@@ -1894,6 +1902,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
18941902

18951903
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
18961904
converted_state_dict = {}
1905+
keys = list(checkpoint.keys())
1906+
for k in keys:
1907+
if "model.diffusion_model." in k:
1908+
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
18971909

18981910
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
18991911
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401

0 commit comments

Comments
 (0)