@@ -20,6 +20,7 @@
     AutoencoderConfig,
     CLIPTextModelConfig,
     SD3_2b,
+    SD3_8b,
     VAEDecoderConfig,
     VAEEncoderConfig,
 )
5051 "argmaxinc/mlx-FLUX.1-dev" : "flux1-dev.safetensors" ,
5152 "vae" : "ae.safetensors" ,
5253 },
54+ "argmaxinc/mlx-stable-diffusion-3.5-large" : {
55+ "argmaxinc/mlx-stable-diffusion-3.5-large" : "sd3.5_large.safetensors" ,
56+ "vae" : "sd3.5_large.safetensors" ,
57+ },
5358}
5459_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
5560_MODELS = {
8388 "vae_encoder" : "encoder." ,
8489 "vae_decoder" : "decoder." ,
8590 },
91+ "argmaxinc/mlx-stable-diffusion-3.5-large" : {
92+ "vae_encoder" : "first_stage_model.encoder." ,
93+ "vae_decoder" : "first_stage_model.decoder." ,
94+ },
95+ }
96+
97+ _CONFIG = {
98+ "argmaxinc/mlx-stable-diffusion-3-medium" : SD3_2b ,
99+ "argmaxinc/mlx-FLUX.1-schnell" : FLUX_SCHNELL ,
100+ "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized" : FLUX_SCHNELL ,
101+ "argmaxinc/mlx-FLUX.1-dev" : FLUX_SCHNELL ,
102+ "argmaxinc/mlx-stable-diffusion-3.5-large" : SD3_8b ,
86103}
87104
88105_FLOAT16 = mx .bfloat16
89106
90107DEPTH = {
91108 "argmaxinc/mlx-stable-diffusion-3-medium" : 24 ,
92- "sd3-8b-unreleased " : 38 ,
109+ "argmaxinc/mlx-stable-diffusion-3.5-large " : 38 ,
93110}
94111MAX_LATENT_RESOLUTION = {
95112 "argmaxinc/mlx-stable-diffusion-3-medium" : 96 ,
96- "sd3-8b-unreleased " : 192 ,
113+ "argmaxinc/mlx-stable-diffusion-3.5-large " : 192 ,
97114}
98115
99116LOCAl_SD3_CKPT = None
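
For reference (not part of the patch): a minimal sketch of how the new lookup tables select the SD3.5-large model. The import path diffusionkit.mlx.model_io and the 8x VAE downsampling factor are assumptions here, not confirmed by the diff.

# Illustrative only -- import path and comments are assumptions, not part of the patch.
from diffusionkit.mlx.model_io import _CONFIG, DEPTH, MAX_LATENT_RESOLUTION

key = "argmaxinc/mlx-stable-diffusion-3.5-large"
config = _CONFIG[key]                    # SD3_8b architecture config
num_blocks = DEPTH[key]                  # 38 MM-DiT blocks (vs. 24 for the medium model)
max_latent = MAX_LATENT_RESOLUTION[key]  # 192 latent pixels per side, i.e. ~1536x1536 output at 8x downsampling
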
@@ -321,6 +338,14 @@ def mmdit_state_dict_adjustments(state_dict, prefix=""):
         for k, v in state_dict.items()
     }
 
+    # Remap qk_norm
+    state_dict = {
+        k.replace(".attn.ln_q.", ".qk_norm.q_norm."): v for k, v in state_dict.items()
+    }
+    state_dict = {
+        k.replace(".attn.ln_k.", ".qk_norm.k_norm."): v for k, v in state_dict.items()
+    }
+
     # Split qkv proj and rename:
     # *transformer_block.attn.qkv.{weigth/bias} -> transformer_block.attn.{q/k/v}_proj.{weigth/bias}
     # *transformer_block.attn.proj.{weigth/bias} -> transformer_block.attn.o_proj.{weight/bias}
@@ -347,6 +372,9 @@ def mmdit_state_dict_adjustments(state_dict, prefix=""):
     # Filter out VAE Decoder related tensors
     state_dict = {k: v for k, v in state_dict.items() if "decoder." not in k}
 
+    # Filter out VAE Encoder related tensors
+    state_dict = {k: v for k, v in state_dict.items() if "encoder." not in k}
+
     # Filter out k_proj.bias related tensors
     state_dict = {k: v for k, v in state_dict.items() if "k_proj.bias" not in k}
 
@@ -676,7 +704,7 @@ def load_mmdit(
676704 """Load the MM-DiT model from the checkpoint file."""
677705 """only_modulation_dict: Only returns the modulation dictionary"""
678706 dtype = _FLOAT16 if float16 else mx .float32
679- config = SD3_2b
707+ config = _CONFIG [ key ]
680708 config .low_memory_mode = low_memory_mode
681709 model = MMDiT (config )
682710
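
For reference (not part of the patch): a self-contained sketch of what the new qk_norm remap in mmdit_state_dict_adjustments does to checkpoint keys. The sample key names below are illustrative placeholders, not taken from an actual checkpoint.

# Standalone demo of the added qk_norm key remap; tensor values are stand-ins.
state_dict = {
    "joint_blocks.0.x_block.attn.ln_q.weight": "q_norm_tensor",
    "joint_blocks.0.x_block.attn.ln_k.weight": "k_norm_tensor",
}
state_dict = {
    k.replace(".attn.ln_q.", ".qk_norm.q_norm."): v for k, v in state_dict.items()
}
state_dict = {
    k.replace(".attn.ln_k.", ".qk_norm.k_norm."): v for k, v in state_dict.items()
}
print(sorted(state_dict))
# -> ['joint_blocks.0.x_block.qk_norm.k_norm.weight',
#     'joint_blocks.0.x_block.qk_norm.q_norm.weight']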