@@ -527,6 +527,8 @@ def diffusion_load_model(
527527 model_dtype : str = None ,
528528 ** kwargs ,
529529):
530+ from functools import partial
531+
530532 from auto_round .utils .common import LazyImport
531533 from auto_round .utils .device import get_device_and_parallelism
532534
@@ -543,12 +545,68 @@ def diffusion_load_model(
543545 torch_dtype = torch .bfloat16
544546
545547 pipelines = LazyImport ("diffusers.pipelines" )
548+ if isinstance (pretrained_model_name_or_path , str ):
549+ if torch_dtype == "auto" :
550+ torch_dtype = {}
551+ model_index = os .path .join (pretrained_model_name_or_path , "model_index.json" )
552+ with open (model_index , "r" , encoding = "utf-8" ) as file :
553+ config = json .load (file )
554+ for k , v in config .items ():
555+ component_folder = os .path .join (pretrained_model_name_or_path , k )
556+ if isinstance (v , list ) and os .path .exists (os .path .join (component_folder , "config.json" )):
557+ component_folder = os .path .join (pretrained_model_name_or_path , k )
558+ with open (os .path .join (component_folder , "config.json" ), "r" , encoding = "utf-8" ) as file :
559+ component_config = json .load (file )
560+ torch_dtype [k ] = component_config .get ("torch_dtype" , "auto" )
561+
562+ pipe = pipelines .auto_pipeline .AutoPipelineForText2Image .from_pretrained (
563+ pretrained_model_name_or_path , torch_dtype = torch_dtype
564+ )
565+ pipe_config = pipe .load_config (pretrained_model_name_or_path )
566+
567+ elif isinstance (pretrained_model_name_or_path , pipelines .pipeline_utils .DiffusionPipeline ):
568+ pipe = pretrained_model_name_or_path
569+ pipe_config = pipe .load_config (pipe .config ["_name_or_path" ])
570+
571+ else :
572+ raise ValueError (
573+ f"Only support str or DiffusionPipeline class for model, but get { type (pretrained_model_name_or_path )} "
574+ )
575+
576+ # add missing key
577+ for k , v in pipe_config .items ():
578+ if k not in pipe .config :
579+ pipe .config [k ] = v
546580
547- pipe = pipelines .auto_pipeline .AutoPipelineForText2Image .from_pretrained (
548- pretrained_model_name_or_path , torch_dtype = torch_dtype
549- )
550581 pipe = _to_model_dtype (pipe , model_dtype )
551582 model = pipe .transformer
583+
def config_save_pretrained(config, file_name, save_directory):
    """Serialize *config* to ``save_directory/file_name`` as pretty-printed JSON.

    Bound with ``functools.partial`` onto ``model.config`` / ``pipe.config`` so
    that downstream code calling ``.save_pretrained(dir)`` on a config object
    writes the expected file. When writing the transformer's ``config.json``,
    the model's ``quantization_config`` (if any) is injected into the payload.

    Raises:
        AssertionError: if *save_directory* points at an existing file.
    """
    if os.path.isfile(save_directory):
        raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
    os.makedirs(save_directory, exist_ok=True)

    # Work on a copy; never mutate the (possibly frozen) config object itself.
    payload = dict(config)
    # Only the transformer's config.json carries quantization metadata.
    if file_name == "config.json" and hasattr(model.config, "quantization_config"):
        payload["quantization_config"] = model.config.quantization_config

    target = os.path.join(save_directory, file_name)
    with open(target, "w", encoding="utf-8") as fp:
        fp.write(json.dumps(payload, indent=2, sort_keys=True) + "\n")
596+
597+ # meta model uses model.config.save_pretrained for config saving
598+ setattr (model .config , "save_pretrained" , partial (config_save_pretrained , model .config , "config.json" ))
599+ setattr (pipe .config , "save_pretrained" , partial (config_save_pretrained , pipe .config , "model_index.json" ))
600+
def model_save_pretrained(model, save_directory, **kwargs):
    """Save the (non-meta) model, then rewrite ``config.json`` with quantization info.

    Delegates the weight/config saving to the parent class's ``save_pretrained``
    (the instance-level override installed via ``setattr`` does not shadow the
    class method reached through ``super``), then overwrites ``config.json`` so
    it additionally carries ``quantization_config``.

    Args:
        model: the transformer model being saved; ``model.config`` must be a
            mapping-like config object (presumably a diffusers config — TODO confirm).
        save_directory: target directory, forwarded to ``save_pretrained``.
        **kwargs: forwarded verbatim to the parent ``save_pretrained``.
    """
    super(model.__class__, model).save_pretrained(save_directory, **kwargs)
    # Build the serialized payload from a copy instead of mutating model.config
    # in place: frozen config mappings reject item assignment, and mutating the
    # live config would leak the extra key beyond this save call. This also
    # matches how config_save_pretrained injects the key into a dict copy.
    config_dict = dict(model.config)
    if hasattr(model.config, "quantization_config"):
        config_dict["quantization_config"] = model.config.quantization_config
    with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as writer:
        writer.write(json.dumps(config_dict, indent=2, sort_keys=True) + "\n")
607+
608+ # non-meta model uses model.save_pretrained for model and config saving
609+ setattr (model , "save_pretrained" , partial (model_save_pretrained , model ))
552610 return pipe , model .to (device )
553611
554612
0 commit comments