Skip to content

Commit f3cf5ff

Browse files
fix dtype bugs (#72)
1 parent b0bf5ee commit f3cf5ff

File tree

9 files changed

+43
-11
lines changed

9 files changed

+43
-11
lines changed

diffsynth_engine/models/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import torch.nn as nn
33
from contextlib import contextmanager
44

5-
65
# modified from transformers.modeling_utils
76
TORCH_INIT_FUNCTIONS = {
87
"uniform_": nn.init.uniform_,

diffsynth_engine/models/vae/vae.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ def __init__(
167167
self.conv_norm_out = nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6, device=device, dtype=dtype)
168168
self.conv_act = nn.SiLU()
169169
self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1, device=device, dtype=dtype)
170+
self.device = device
171+
self.dtype = dtype
170172

171173
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
172174
original_dtype = sample.dtype
@@ -277,6 +279,8 @@ def __init__(
277279
self.conv_norm_out = nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6, device=device, dtype=dtype)
278280
self.conv_act = nn.SiLU()
279281
self.conv_out = nn.Conv2d(512, 2 * latent_channels, kernel_size=3, padding=1, device=device, dtype=dtype)
282+
self.device = device
283+
self.dtype = dtype
280284

281285
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
282286
original_dtype = sample.dtype

diffsynth_engine/pipelines/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
99
from diffsynth_engine.utils.gguf import load_gguf_checkpoint
1010
from diffsynth_engine.utils import logging
11+
from diffsynth_engine.utils.platform import empty_cache
1112

1213
logger = logging.get_logger(__name__)
1314

@@ -144,15 +145,17 @@ def generate_noise(shape, seed=None, device="cpu", dtype=torch.float16):
144145
return noise
145146

146147
def encode_image(self, image: torch.Tensor) -> torch.Tensor:
148+
image = image.to(self.device, self.vae_encoder.dtype)
147149
latents = self.vae_encoder(
148150
image, tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
149151
)
150152
return latents
151153

152154
def decode_image(self, latent: torch.Tensor) -> torch.Tensor:
153155
vae_dtype = self.vae_decoder.conv_in.weight.dtype
156+
latent = latent.to(self.device, vae_dtype)
154157
image = self.vae_decoder(
155-
latent.to(vae_dtype), tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
158+
latent, tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
156159
)
157160
return image
158161

@@ -233,7 +236,7 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
233236
return
234237
if self.offload_mode == "sequential_cpu_offload":
235238
# refresh the cuda cache
236-
torch.cuda.empty_cache()
239+
empty_cache()
237240
return
238241

239242
# offload unnecessary models to cpu
@@ -248,4 +251,4 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
248251
if model is not None and (p := next(model.parameters(), None)) is not None and p.device != self.device:
249252
model.to(self.device)
250253
# refresh the cuda cache
251-
torch.cuda.empty_cache()
254+
empty_cache()

diffsynth_engine/pipelines/flux_image.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from diffsynth_engine.utils import logging
2626
from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
2727
from diffsynth_engine.utils.download import fetch_model
28+
from diffsynth_engine.utils.platform import empty_cache
2829

2930
logger = logging.get_logger(__name__)
3031

@@ -546,6 +547,7 @@ def predict_noise(
546547
current_step=current_step,
547548
total_step=total_step,
548549
)
550+
self.load_models_to_device(["dit"])
549551
noise_pred = self.dit(
550552
hidden_states=latents,
551553
timestep=timestep,
@@ -570,15 +572,14 @@ def prepare_latents(
570572
):
571573
# Prepare scheduler
572574
if input_image is not None:
575+
self.load_models_to_device(["vae_encoder"])
573576
total_steps = num_inference_steps
574577
sigmas, timesteps = self.noise_scheduler.schedule(
575578
total_steps, mu=mu, sigma_min=1 / total_steps, sigma_max=1.0
576579
)
577580
t_start = max(total_steps - int(num_inference_steps * denoising_strength), 1)
578581
sigma_start, sigmas = sigmas[t_start - 1], sigmas[t_start - 1 :]
579582
timesteps = timesteps[t_start - 1 :]
580-
581-
self.load_models_to_device(["vae_encoder"])
582583
noise = latents
583584
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.dtype)
584585
latents = self.encode_image(image)
@@ -593,6 +594,7 @@ def prepare_latents(
593594
return init_latents, latents, sigmas, timesteps
594595

595596
def prepare_masked_latent(self, image: Image.Image, mask: Image.Image | None, height: int, width: int):
597+
self.load_models_to_device(["vae_encoder"])
596598
if mask is None:
597599
image = image.resize((width, height))
598600
image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
@@ -637,6 +639,8 @@ def predict_multicontrolnet(
637639
total_step: int,
638640
):
639641
double_block_output_results, single_block_output_results = None, None
642+
if len(controlnet_params) > 0:
643+
self.load_models_to_device([])
640644
for param in controlnet_params:
641645
current_scale = param.scale
642646
if not (
@@ -645,6 +649,9 @@ def predict_multicontrolnet(
645649
# if current_step is not in the control range
646650
# skip this controlnet
647651
continue
652+
if self.offload_mode == "sequential_cpu_offload" or self.offload_mode == "cpu_offload":
653+
empty_cache()
654+
param.model.to(self.device)
648655
double_block_output, single_block_output = param.model(
649656
latents,
650657
param.image,
@@ -656,6 +663,9 @@ def predict_multicontrolnet(
656663
image_ids,
657664
text_ids,
658665
)
666+
if self.offload_mode == "sequential_cpu_offload" or self.offload_mode == "cpu_offload":
667+
empty_cache()
668+
param.model.to("cpu")
659669
double_block_output_results = accumulate(double_block_output_results, double_block_output)
660670
single_block_output_results = accumulate(single_block_output_results, single_block_output)
661671
return double_block_output_results, single_block_output_results
@@ -741,7 +751,7 @@ def __call__(
741751
)
742752

743753
# Denoise
744-
self.load_models_to_device(["dit"])
754+
self.load_models_to_device([])
745755
for i, timestep in enumerate(tqdm(timesteps)):
746756
timestep = timestep.unsqueeze(0).to(dtype=self.dtype)
747757
noise_pred = self.predict_noise_with_cfg(

diffsynth_engine/tools/flux_inpainting_tool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ def __init__(
1313
dtype: torch.dtype = torch.bfloat16,
1414
offload_mode: Optional[str] = None,
1515
):
16-
self.pipe = FluxImagePipeline.from_pretrained(flux_model_path, device=device, offload_mode=offload_mode)
16+
self.pipe = FluxImagePipeline.from_pretrained(
17+
flux_model_path, device=device, offload_mode=offload_mode, dtype=dtype
18+
)
1719
self.pipe.load_loras(lora_list)
1820
self.controlnet = FluxControlNet.from_pretrained(
1921
fetch_model(

diffsynth_engine/tools/flux_outpainting_tool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ def __init__(
1313
dtype: torch.dtype = torch.bfloat16,
1414
offload_mode: Optional[str] = None,
1515
):
16-
self.pipe = FluxImagePipeline.from_pretrained(flux_model_path, device=device, offload_mode=offload_mode)
16+
self.pipe = FluxImagePipeline.from_pretrained(
17+
flux_model_path, device=device, offload_mode=offload_mode, dtype=dtype
18+
)
1719
self.pipe.load_loras(lora_list)
1820
self.controlnet = FluxControlNet.from_pretrained(
1921
fetch_model(

diffsynth_engine/tools/flux_reference_tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(
1919
offload_mode: Optional[str] = None,
2020
):
2121
self.pipe: FluxImagePipeline = FluxImagePipeline.from_pretrained(
22-
flux_model_path, load_text_encoder=load_text_encoder, device=device, offload_mode=offload_mode
22+
flux_model_path, load_text_encoder=load_text_encoder, device=device, offload_mode=offload_mode, dtype=dtype
2323
)
2424
self.pipe.load_loras(lora_list)
2525
redux_model_path = fetch_model("muse/flux1-redux-dev", path="flux1-redux-dev.safetensors", revision="v1")

diffsynth_engine/tools/flux_replace_tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(
2020
offload_mode: Optional[str] = None,
2121
):
2222
self.pipe: FluxImagePipeline = FluxImagePipeline.from_pretrained(
23-
flux_model_path, load_text_encoder=load_text_encoder, device=device, offload_mode=offload_mode
23+
flux_model_path, load_text_encoder=load_text_encoder, device=device, offload_mode=offload_mode, dtype=dtype
2424
)
2525
self.pipe.load_loras(lora_list)
2626
redux_model_path = fetch_model("muse/flux1-redux-dev", path="flux1-redux-dev.safetensors", revision="v1")

diffsynth_engine/utils/platform.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
import gc
3+
4+
# 存放跨平台的工具类
5+
6+
7+
def empty_cache():
8+
if torch.cuda.is_available():
9+
torch.cuda.empty_cache()
10+
if torch.mps.is_available():
11+
torch.mps.empty_cache()
12+
gc.collect()

0 commit comments

Comments
 (0)