Skip to content

Commit a0ad778

Browse files
authored
Support diffusion model saving (#1519)
1 parent e14e7db commit a0ad778

File tree

4 files changed

+144
-13
lines changed

4 files changed

+144
-13
lines changed

auto_round/compressors/diffusion/compressor.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
from collections import defaultdict
1617
from copy import deepcopy
1718
from typing import Union
@@ -22,6 +23,7 @@
2223
from auto_round.compressors.base import BaseCompressor
2324
from auto_round.compressors.diffusion.dataset import get_diffusion_dataloader
2425
from auto_round.compressors.utils import block_forward
26+
from auto_round.formats import OutputFormat
2527
from auto_round.logger import logger
2628
from auto_round.schemes import QuantizationScheme
2729
from auto_round.utils import (
@@ -112,13 +114,7 @@ def __init__(
112114
device_map = 0
113115
self._set_device(device_map)
114116

115-
if isinstance(model, str):
116-
pipe, model = diffusion_load_model(model, platform=platform, device=self.device, model_dtype=model_dtype)
117-
elif isinstance(model, pipeline_utils.DiffusionPipeline):
118-
pipe = model
119-
model = pipe.transformer
120-
else:
121-
raise ValueError(f"Only support str or DiffusionPipeline class for model, but get {type(model)}")
117+
pipe, model = diffusion_load_model(model, platform=platform, device=self.device, model_dtype=model_dtype)
122118

123119
self.model = model
124120
self.pipe = pipe
@@ -373,6 +369,33 @@ def calib(self, nsamples, bs):
373369

374370
# torch.cuda.empty_cache()
375371

372+
def _get_save_folder_name(self, format: OutputFormat) -> str:
373+
"""Generates the save folder name based on the provided format string.
374+
375+
If there are multiple formats to handle, the function creates a subfolder
376+
named after the format string with special characters replaced. If there's
377+
only one format, it returns the original output directory directly.
378+
379+
Args:
380+
format_str (str): The format identifier (e.g., 'gguf:q2_k_s').
381+
382+
Returns:
383+
str: The path to the folder where results should be saved.
384+
"""
385+
# Replace special characters to make the folder name filesystem-safe
386+
sanitized_format = format.get_backend_name().replace(":", "-").replace("_", "-")
387+
388+
# Use a subfolder only if there are multiple formats
389+
if len(self.formats) > 1:
390+
return (
391+
os.path.join(self.orig_output_dir, sanitized_format, "transformer")
392+
if self.is_immediate_saving
393+
else os.path.join(self.orig_output_dir, sanitized_format, "transformer")
394+
)
395+
396+
# if use is_immediate_saving, we need to save model in self.orig_output_dir/transformer folder
397+
return os.path.join(self.orig_output_dir, "transformer") if self.is_immediate_saving else self.orig_output_dir
398+
376399
def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs):
    """Save the quantized diffusion pipeline to ``output_dir`` in ``format``.

    The quantized transformer component (identified by matching
    ``config._name_or_path`` against ``self.model``) is saved through the
    base-class ``save_quantized`` so quantization metadata is written;
    every other pipeline component that supports ``save_pretrained`` is
    saved alongside it, and the pipeline-level config (``model_index.json``)
    is written at the root, so the directory is a complete, loadable
    diffusion pipeline.

    Args:
        output_dir (str | None): Target directory.  When ``None``, the base
            implementation is invoked directly with no per-component layout.
        format (str): Output format identifier (default ``"auto_round"``).
        inplace (bool): Whether to compress the model in place.
        **kwargs: Forwarded to the base-class ``save_quantized``.

    Returns:
        object: The compressed model object (``None`` if the quantized
        transformer was not found among the pipeline components).
    """
    if output_dir is None:
        # Nothing to lay out on disk; defer entirely to the base class.
        return super().save_quantized(output_dir, format=format, inplace=inplace, **kwargs)

    compressed_model = None
    for name in self.pipe.components:  # iterating the mapping yields component names
        val = getattr(self.pipe, name)
        # Avoid doubling the component name when output_dir already ends with it.
        sub_module_path = (
            os.path.join(output_dir, name) if os.path.basename(os.path.normpath(output_dir)) != name else output_dir
        )
        if (
            hasattr(val, "config")
            and hasattr(val.config, "_name_or_path")
            and val.config._name_or_path == self.model.config._name_or_path
        ):
            # This component is the quantized transformer: save it through the
            # base compressor.  Immediate saving already targets the final
            # directory, so no per-component subfolder is appended then.
            compressed_model = super().save_quantized(
                output_dir=sub_module_path if not self.is_immediate_saving else output_dir,
                format=format,
                inplace=inplace,
                **kwargs,
            )
        elif val is not None and hasattr(val, "save_pretrained"):
            val.save_pretrained(sub_module_path)
    # Persist the pipeline-level config (model_index.json) at the root.
    self.pipe.config.save_pretrained(output_dir)
    return compressed_model

auto_round/eval/evaluation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,15 @@ def evaluate_diffusion_model(args, autoround=None, model=None, pipe=None):
100100

101101
import torch
102102

103-
from auto_round.utils import detect_device, get_model_dtype, logger
103+
from auto_round.utils import detect_device, get_model_dtype, logger, unsupported_meta_device
104104

105105
# Prepare inference pipeline
106106
if pipe is None:
107+
if model is not None and unsupported_meta_device(model):
108+
logger.error(
109+
"Quantized model is meta and diffusers doesn't support loading auto-round quantized model now. Exit."
110+
)
111+
exit(0)
107112
pipe = autoround.pipe
108113
pipe.to(model.dtype)
109114
pipe.transformer = model

auto_round/utils/model.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,8 @@ def diffusion_load_model(
527527
model_dtype: str = None,
528528
**kwargs,
529529
):
530+
from functools import partial
531+
530532
from auto_round.utils.common import LazyImport
531533
from auto_round.utils.device import get_device_and_parallelism
532534

@@ -543,12 +545,68 @@ def diffusion_load_model(
543545
torch_dtype = torch.bfloat16
544546

545547
pipelines = LazyImport("diffusers.pipelines")
548+
if isinstance(pretrained_model_name_or_path, str):
549+
if torch_dtype == "auto":
550+
torch_dtype = {}
551+
model_index = os.path.join(pretrained_model_name_or_path, "model_index.json")
552+
with open(model_index, "r", encoding="utf-8") as file:
553+
config = json.load(file)
554+
for k, v in config.items():
555+
component_folder = os.path.join(pretrained_model_name_or_path, k)
556+
if isinstance(v, list) and os.path.exists(os.path.join(component_folder, "config.json")):
557+
component_folder = os.path.join(pretrained_model_name_or_path, k)
558+
with open(os.path.join(component_folder, "config.json"), "r", encoding="utf-8") as file:
559+
component_config = json.load(file)
560+
torch_dtype[k] = component_config.get("torch_dtype", "auto")
561+
562+
pipe = pipelines.auto_pipeline.AutoPipelineForText2Image.from_pretrained(
563+
pretrained_model_name_or_path, torch_dtype=torch_dtype
564+
)
565+
pipe_config = pipe.load_config(pretrained_model_name_or_path)
566+
567+
elif isinstance(pretrained_model_name_or_path, pipelines.pipeline_utils.DiffusionPipeline):
568+
pipe = pretrained_model_name_or_path
569+
pipe_config = pipe.load_config(pipe.config["_name_or_path"])
570+
571+
else:
572+
raise ValueError(
573+
f"Only support str or DiffusionPipeline class for model, but get {type(pretrained_model_name_or_path)}"
574+
)
575+
576+
# add missing key
577+
for k, v in pipe_config.items():
578+
if k not in pipe.config:
579+
pipe.config[k] = v
546580

547-
pipe = pipelines.auto_pipeline.AutoPipelineForText2Image.from_pretrained(
548-
pretrained_model_name_or_path, torch_dtype=torch_dtype
549-
)
550581
pipe = _to_model_dtype(pipe, model_dtype)
551582
model = pipe.transformer
583+
584+
def config_save_pretrained(config, file_name, save_directory):
585+
if os.path.isfile(save_directory):
586+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
587+
os.makedirs(save_directory, exist_ok=True)
588+
output_config_file = os.path.join(save_directory, file_name)
589+
590+
config_dict = dict(config)
591+
if file_name == "config.json" and hasattr(model.config, "quantization_config"):
592+
config_dict["quantization_config"] = model.config.quantization_config
593+
594+
with open(output_config_file, "w", encoding="utf-8") as writer:
595+
writer.write(json.dumps(config_dict, indent=2, sort_keys=True) + "\n")
596+
597+
# meta model uses model.config.save_pretrained for config saving
598+
setattr(model.config, "save_pretrained", partial(config_save_pretrained, model.config, "config.json"))
599+
setattr(pipe.config, "save_pretrained", partial(config_save_pretrained, pipe.config, "model_index.json"))
600+
601+
def model_save_pretrained(model, save_directory, **kwargs):
602+
super(model.__class__, model).save_pretrained(save_directory, **kwargs)
603+
if hasattr(model.config, "quantization_config"):
604+
model.config["quantization_config"] = model.config.quantization_config
605+
with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as writer:
606+
writer.write(json.dumps(dict(model.config), indent=2, sort_keys=True) + "\n")
607+
608+
# non-meta model uses model.save_pretrained for model and config saving
609+
setattr(model, "save_pretrained", partial(model_save_pretrained, model))
552610
return pipe, model.to(device)
553611

554612

test/test_cpu/models/test_diffusion.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import os
12
import shutil
23

34
import pytest
5+
import torch
46
from packaging import version
57

68
from auto_round import AutoRound
@@ -16,11 +18,32 @@ def setup_flux():
1618
from diffusers import AutoPipelineForText2Image
1719

1820
model_name = flux_name_or_path
19-
pipe = AutoPipelineForText2Image.from_pretrained(model_name)
21+
# use bf16 to reduce the saved model size
22+
pipe = AutoPipelineForText2Image.from_pretrained(model_name, torch_dtype=torch.bfloat16)
2023
output_dir = "./tmp/test_quantized_flux"
2124
return pipe, output_dir
2225

2326

27+
@pytest.mark.skipif(
    transformers_version >= version.parse("5.0.0"),
    reason="cannot import name 'MT5Tokenizer' from 'transformers', https://github.com/huggingface/diffusers/issues/13035",
)
def test_flux_saving(setup_flux):
    """Quantize a Flux pipeline with plain RTN and verify the saved layout."""
    pipe, output_dir = setup_flux
    # iters=0 with disable_opt_rtn selects the fastest (plain RTN) path,
    # which is enough to exercise the diffusion-model saving flow.
    autoround = AutoRound(
        pipe,
        tokenizer=None,
        scheme="W4A16",
        iters=0,
        num_inference_steps=2,
        disable_opt_rtn=True,
    )
    autoround.quantize_and_save(output_dir)
    # The pipeline index must sit at the root, and the quantized transformer
    # (with its quantization config) in the `transformer` subfolder.
    assert os.path.exists(os.path.join(output_dir, "model_index.json"))
    assert os.path.exists(os.path.join(output_dir, "transformer", "quantization_config.json"))
    # Clean up the quantized artifacts produced by this test.
    shutil.rmtree(output_dir, ignore_errors=True)
2447
@pytest.mark.skipif(
2548
transformers_version >= version.parse("5.0.0"),
2649
reason="cannot import name 'MT5Tokenizer' from 'transformers', https://github.com/huggingface/diffusers/issues/13035",

0 commit comments

Comments
 (0)