Skip to content

Commit 65118fa

Browse files
committed
trying
1 parent 5e675cb commit 65118fa

16 files changed

+149
-148
lines changed
2.49 KB
Binary file not shown.
5.75 KB
Binary file not shown.

code/main_simple_inversion.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## This code is for generating basic image samples, adjusted from local_gradio.py
22

3-
from pipeline_rf_adjusting import RectifiedInversableFlowPipeline
3+
from pipeline_rf import RectifiedInversableFlowPipeline
44

55
import torch
66
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
@@ -50,8 +50,8 @@ def set_and_generate_image_then_reverse(seed, prompt, inversion_prompt, randomiz
5050
use_random_initial_noise=False,
5151
decoder_inv_steps=10,
5252
forward_steps=100,
53-
tuning_steps=50,
54-
pnp_adjust=False,
53+
tuning_steps=10,
54+
pnp_adjust=True,
5555
)
5656

5757
print(f"TOT of inversion {torch.mean((recon_latents-original_latents)**2)/torch.norm(original_latents)**2}")
@@ -97,7 +97,8 @@ def set_and_generate_image_then_reverse(seed, prompt, inversion_prompt, randomiz
9797
# print(f'Shapiro-Wilk test for Reconstructed Latents: W={stat_recon:.8f}, p-value={p_value_recon:.8f}')
9898

9999
# Section : Check with plot distribution
100-
plot_distribution(original_noise_cpu, recon_noise_cpu, latents_cpu, version="cosinesim", plot_dist=plot_dist)
100+
if plot_dist:
101+
plot_distribution(original_noise_cpu, recon_noise_cpu, latents_cpu, version="fourier")
101102

102103
# Return values with normalization
103104
latents_plot = latents_cpu[0:3]
@@ -117,7 +118,7 @@ def main():
117118
plot_dist=True,
118119
)
119120

120-
plot_and_save_image(image, recon_image, latents, recon_latents, show=True)
121+
plot_and_save_image(image, recon_image, latents, recon_latents, show=False)
121122

122123
print(f"generation time : {time}")
123124

code/pipeline_rf.py

Lines changed: 136 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,8 @@ def exact_inversion(
730730
callback_steps: int = 1,
731731
decoder_inv_steps: int = 100,
732732
forward_steps: int = 100,
733+
tuning_steps: int = 100,
734+
pnp_adjust: bool = True,
733735
):
734736
"""
735737
Exact inversion of RectifiedFlowPipeline. Gets input of 1,4,64,64 latents (which is denoised), and returns original latents by performing inversion
@@ -774,7 +776,7 @@ def exact_inversion(
774776
image = torch.Tensor(latents).permute(0, 3, 1, 2)
775777
image = image.to('cuda')
776778
torch.set_grad_enabled(True)
777-
latents = self.edcorrector(image, decoder_inv_steps=decoder_inv_steps)
779+
latents = self.edcorrector(image, decoder_inv_steps=decoder_inv_steps, verbose=verbose)
778780
torch.set_grad_enabled(False)
779781

780782
current_latents = latents # Save latents, this is our target
@@ -800,7 +802,7 @@ def exact_inversion(
800802
latents = latents - dt * v_pred
801803

802804
# Our work : perform forward step method
803-
latents = self.forward_step_method(latents, current_latents, t, dt, prompt_embeds=prompt_embeds, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, verbose=verbose)
805+
latents = self.forward_step_method(latents, current_latents, t, dt, prompt_embeds=prompt_embeds, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, verbose=verbose, pnp_adjust=pnp_adjust)
804806
else:
805807
# expand the latents if we are doing classifier free guidance
806808
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -822,15 +824,28 @@ def exact_inversion(
822824
# Our work : perform forward step method
823825
latents = self.forward_step_method(latents, current_latents, t, dt, prompt_embeds=prompt_embeds, do_classifier_free_guidance=do_classifier_free_guidance,
824826
guidance_scale=guidance_scale, verbose=verbose,
825-
steps=forward_steps)
826-
827+
steps=forward_steps, pnp_adjust=pnp_adjust)
827828

828829
if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
829830
progress_bar.update()
830831
if callback is not None and i % callback_steps == 0:
831832
step_idx = i // getattr(self.scheduler, "order", 1)
832833
callback(step_idx, t, latents)
833834

835+
# Add another procedure, end-to-end correction of noise
836+
torch.set_grad_enabled(True)
837+
if input_type == "images":
838+
latents = self.one_step_inversion_tuning_sampler(
839+
prompt=prompt,
840+
num_inference_steps=num_inference_steps,
841+
guidance_scale=guidance_scale,
842+
negative_prompt=negative_prompt,
843+
latents=latents,
844+
image=image,
845+
steps=tuning_steps,
846+
)
847+
torch.set_grad_enabled(False)
848+
834849
# Offload all models
835850
self.maybe_free_model_hooks()
836851

@@ -861,6 +876,115 @@ def exact_inversion(
861876

862877
return latents, image
863878

879+
def one_step_inversion_tuning_sampler(
880+
self,
881+
prompt: Union[str, List[str]] = None,
882+
height: Optional[int] = None,
883+
width: Optional[int] = None,
884+
num_inference_steps: int = 50,
885+
guidance_scale: float = 7.5,
886+
negative_prompt: Optional[Union[str, List[str]]] = None,
887+
num_images_per_prompt: Optional[int] = 1,
888+
eta: float = 0.0,
889+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
890+
latents: Optional[torch.FloatTensor] = None,
891+
image: Optional[torch.FloatTensor] = None,
892+
steps: int = 100,
893+
prompt_embeds: Optional[torch.FloatTensor] = None,
894+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
895+
output_type: Optional[str] = "pil",
896+
return_dict: bool = True,
897+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
898+
callback_steps: int = 1,
899+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
900+
guidance_rescale: float = 0.0,
901+
verbose: bool = True
902+
):
903+
r"""
904+
The simplified call function to the pipeline for tuning inversion. Creates a network form.
905+
Inputs:
906+
latents - initial noise obtained by inverse process
907+
image - target image to match
908+
909+
Returns:
910+
latents - fine-tuned latents
911+
"""
912+
913+
device = self._execution_device
914+
do_classifier_free_guidance = guidance_scale > 1.0
915+
916+
text_encoder_lora_scale = (
917+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
918+
)
919+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
920+
prompt,
921+
device,
922+
num_images_per_prompt,
923+
do_classifier_free_guidance,
924+
negative_prompt,
925+
prompt_embeds=prompt_embeds,
926+
negative_prompt_embeds=negative_prompt_embeds,
927+
lora_scale=text_encoder_lora_scale,
928+
)
929+
930+
if do_classifier_free_guidance:
931+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
932+
933+
timesteps = [(1. - i/num_inference_steps) * 1000. for i in range(num_inference_steps)]
934+
935+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
936+
dt = 1.0 / num_inference_steps
937+
t = timesteps[0]
938+
939+
# Performing gradient descent, to tune the latents
940+
image_answer = image.clone()
941+
do_denormalize = [True] * image.shape[0]
942+
input = copy.deepcopy(latents)
943+
unet = copy.deepcopy(self.unet)
944+
vae = copy.deepcopy(self.vae)
945+
input.requires_grad_(True)
946+
947+
loss_function = torch.nn.MSELoss(reduction='mean')
948+
949+
optimizer = torch.optim.Adam([input], lr=0.01)
950+
#lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=100)
951+
952+
for i in range(steps):
953+
latent_model_input = torch.cat([input] * 2) if do_classifier_free_guidance else input
954+
vec_t = torch.ones((latent_model_input.shape[0],), device=latents.device) * t
955+
v_pred = unet(latent_model_input, vec_t, encoder_hidden_states=prompt_embeds).sample
956+
v_pred = v_pred.detach()
957+
958+
# perform guidance
959+
if do_classifier_free_guidance:
960+
v_pred_neg, v_pred_text = v_pred.chunk(2)
961+
v_pred = v_pred_neg + guidance_scale * (v_pred_text - v_pred_neg)
962+
963+
temp = input + dt * v_pred
964+
# Stop here, check by below
965+
# visual = self.image_processor.postprocess(self.vae.decode(input/self.vae.config.scaling_factor, return_dict=False)[0].detach().cpu())
966+
967+
image_recon = vae.decode(temp / self.vae.config.scaling_factor, return_dict=False)[0]
968+
969+
image_recon = self.image_processor.postprocess(image_recon, output_type="pt", do_denormalize=do_denormalize)
970+
971+
loss = loss_function(image_recon, image_answer)
972+
973+
optimizer.zero_grad()
974+
loss.backward()
975+
optimizer.step()
976+
# lr_scheduler.step()
977+
978+
if verbose:
979+
print(f"tuning, {i}, {loss.item()}")
980+
981+
input.detach()
982+
983+
# Offload all models
984+
self.maybe_free_model_hooks()
985+
986+
return input
987+
864988
@torch.inference_mode()
865989
def forward_step_method(
866990
self,
@@ -874,7 +998,8 @@ def forward_step_method(
874998
warmup_time = 0,
875999
steps=100,
8761000
original_step_size=0.1, step_size=0.5,
877-
factor=0.5, patience=15, th=1e-3
1001+
factor=0.5, patience=15, th=1e-3,
1002+
pnp_adjust=False,
8781003
):
8791004
"""
8801005
The forward step method assumes that current_latents are at right place(even on multistep), then map latents correctly to current latents
@@ -908,13 +1033,16 @@ def forward_step_method(
9081033
step_size = step_scheduler.step(loss)
9091034

9101035
# progress_bar.update()
1036+
if pnp_adjust:
1037+
add_noise = randn_tensor(latents_s.shape, device=latents_s.device, dtype=latents_s.dtype)
1038+
latents_s = latents_s + 0.01 * add_noise
9111039

9121040
if verbose:
9131041
print(i, ((latents_t - current_latents).norm()/current_latents.norm()).item(), step_size)
9141042

9151043
return latents_s
9161044

917-
def edcorrector(self, x, decoder_inv_steps=100):
1045+
def edcorrector(self, x, decoder_inv_steps=100, verbose=False):
9181046
"""
9191047
edcorrector calculates latents z of the image x by solving optimization problem ||E(x)-z||,
9201048
not by directly encoding with VAE encoder. "Decoder inversion"
@@ -947,7 +1075,8 @@ def edcorrector(self, x, decoder_inv_steps=100):
9471075
optimizer.step()
9481076
lr_scheduler.step()
9491077

950-
print(f"{i}, {loss.item()}")
1078+
if verbose:
1079+
print(f"{i}, {loss.item()}")
9511080

9521081
return z
9531082

File renamed without changes.

0 commit comments

Comments (0)