trying
khlee0192 committed Feb 6, 2024
1 parent 5e675cb commit 65118fa
Showing 16 changed files with 149 additions and 148 deletions.
Binary file modified code/__pycache__/pipeline_rf.cpython-310.pyc
Binary file added code/__pycache__/utils.cpython-310.pyc
11 changes: 6 additions & 5 deletions code/main_simple_inversion.py
@@ -1,6 +1,6 @@
## This code is for generating basic image sample, adjusted from local_gradio.py

-from pipeline_rf_adjusting import RectifiedInversableFlowPipeline
+from pipeline_rf import RectifiedInversableFlowPipeline

import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
@@ -50,8 +50,8 @@ def set_and_generate_image_then_reverse(seed, prompt, inversion_prompt, randomiz
use_random_initial_noise=False,
decoder_inv_steps=10,
forward_steps=100,
-tuning_steps=50,
-pnp_adjust=False,
+tuning_steps=10,
+pnp_adjust=True,
)

print(f"TOT of inversion {torch.mean((recon_latents-original_latents)**2)/torch.norm(original_latents)**2}")
@@ -97,7 +97,8 @@ def set_and_generate_image_then_reverse(seed, prompt, inversion_prompt, randomiz
# print(f'Shapiro-Wilk test for Reconstructed Latents: W={stat_recon:.8f}, p-value={p_value_recon:.8f}')

# Section : Check with plot distribution
-plot_distribution(original_noise_cpu, recon_noise_cpu, latents_cpu, version="cosinesim", plot_dist=plot_dist)
+if plot_dist:
+    plot_distribution(original_noise_cpu, recon_noise_cpu, latents_cpu, version="fourier")

# Return values with normalization
latents_plot = latents_cpu[0:3]
@@ -117,7 +118,7 @@ def main():
plot_dist=True,
)

-plot_and_save_image(image, recon_image, latents, recon_latents, show=True)
+plot_and_save_image(image, recon_image, latents, recon_latents, show=False)

print(f"generation time : {time}")

143 changes: 136 additions & 7 deletions code/pipeline_rf.py
@@ -730,6 +730,8 @@ def exact_inversion(
callback_steps: int = 1,
decoder_inv_steps: int = 100,
forward_steps: int = 100,
+tuning_steps: int = 100,
+pnp_adjust: bool = True,
):
"""
Exact inversion of RectifiedFlowPipeline. Gets input of 1,4,64,64 latents (which is denoised), and returns original latents by performing inversion
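
The two new keyword arguments expose the end-to-end noise correction (tuning_steps) and the plug-and-play perturbation (pnp_adjust) introduced further down in this file. A minimal usage sketch, not part of the commit: `pipe` is assumed to be a loaded RectifiedInversableFlowPipeline, `image` a decoded sample to invert, the prompt and values are placeholders, and arguments not visible in this diff are omitted.

    # Hypothetical call, for illustration only; argument names follow this diff.
    recon_latents, recon_image = pipe.exact_inversion(
        prompt="a photo of a cat",   # placeholder prompt
        latents=image,               # image-space input (the input_type == "images" path)
        num_inference_steps=1,       # one-step rectified flow
        decoder_inv_steps=10,        # decoder-inversion iterations (edcorrector)
        forward_steps=100,           # fixed-point iterations in forward_step_method
        tuning_steps=10,             # new: end-to-end latent tuning iterations
        pnp_adjust=True,             # new: small Gaussian jitter per forward-step iteration
    )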
@@ -774,7 +776,7 @@ def exact_inversion(
image = torch.Tensor(latents).permute(0, 3, 1, 2)
image = image.to('cuda')
torch.set_grad_enabled(True)
-latents = self.edcorrector(image, decoder_inv_steps=decoder_inv_steps)
+latents = self.edcorrector(image, decoder_inv_steps=decoder_inv_steps, verbose=verbose)
torch.set_grad_enabled(False)

current_latents = latents # Save latents, this is our target
@@ -800,7 +802,7 @@
latents = latents - dt * v_pred

# Our work : perform forward step method
-latents = self.forward_step_method(latents, current_latents, t, dt, prompt_embeds=prompt_embeds, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, verbose=verbose)
+latents = self.forward_step_method(latents, current_latents, t, dt, prompt_embeds=prompt_embeds, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, verbose=verbose, pnp_adjust=pnp_adjust)
else:
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -822,15 +824,28 @@
# Our work : perform forward step method
latents = self.forward_step_method(latents, current_latents, t, dt, prompt_embeds=prompt_embeds, do_classifier_free_guidance=do_classifier_free_guidance,
guidance_scale=guidance_scale, verbose=verbose,
-steps=forward_steps)
+steps=forward_steps, pnp_adjust=pnp_adjust)

if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

+# Add another procedure, end-to-end correction of noise
+torch.set_grad_enabled(True)
+if input_type == "images":
+    latents = self.one_step_inversion_tuning_sampler(
+        prompt=prompt,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        negative_prompt=negative_prompt,
+        latents=latents,
+        image=image,
+        steps=tuning_steps,
+    )
+torch.set_grad_enabled(False)

# Offload all models
self.maybe_free_model_hooks()

@@ -861,6 +876,115 @@ def exact_inversion(

return latents, image

+def one_step_inversion_tuning_sampler(
+    self,
+    prompt: Union[str, List[str]] = None,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    num_inference_steps: int = 50,
+    guidance_scale: float = 7.5,
+    negative_prompt: Optional[Union[str, List[str]]] = None,
+    num_images_per_prompt: Optional[int] = 1,
+    eta: float = 0.0,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.FloatTensor] = None,
+    image: Optional[torch.FloatTensor] = None,
+    steps: int = 100,
+    prompt_embeds: Optional[torch.FloatTensor] = None,
+    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    output_type: Optional[str] = "pil",
+    return_dict: bool = True,
+    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+    callback_steps: int = 1,
+    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    guidance_rescale: float = 0.0,
+    verbose: bool = True
+):
+    r"""
+    The simplified call function to the pipeline for tuning inversion. Creates a network form.
+    Inputs:
+        latents - initial noise obtained by inverse process
+        image - target image to match
+    Returns:
+        latents - fine-tuned latents
+    """
+
+    device = self._execution_device
+    do_classifier_free_guidance = guidance_scale > 1.0
+
+    text_encoder_lora_scale = (
+        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+    )
+    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt,
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=negative_prompt_embeds,
+        lora_scale=text_encoder_lora_scale,
+    )
+
+    if do_classifier_free_guidance:
+        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+    timesteps = [(1. - i/num_inference_steps) * 1000. for i in range(num_inference_steps)]
+
+    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+    dt = 1.0 / num_inference_steps
+    t = timesteps[0]
+
+    # Performing gradient descent, to tune the latents
+    image_answer = image.clone()
+    do_denormalize = [True] * image.shape[0]
+    input = copy.deepcopy(latents)
+    unet = copy.deepcopy(self.unet)
+    vae = copy.deepcopy(self.vae)
+    input.requires_grad_(True)
+
+    loss_function = torch.nn.MSELoss(reduction='mean')
+
+    optimizer = torch.optim.Adam([input], lr=0.01)
+    # lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=100)
+
+    for i in range(steps):
+        latent_model_input = torch.cat([input] * 2) if do_classifier_free_guidance else input
+        vec_t = torch.ones((latent_model_input.shape[0],), device=latents.device) * t
+        v_pred = unet(latent_model_input, vec_t, encoder_hidden_states=prompt_embeds).sample
+        v_pred = v_pred.detach()
+
+        # perform guidance
+        if do_classifier_free_guidance:
+            v_pred_neg, v_pred_text = v_pred.chunk(2)
+            v_pred = v_pred_neg + guidance_scale * (v_pred_text - v_pred_neg)
+
+        temp = input + dt * v_pred
+        # Stop here, check by below
+        # visual = self.image_processor.postprocess(self.vae.decode(input/self.vae.config.scaling_factor, return_dict=False)[0].detach().cpu())
+
+        image_recon = vae.decode(temp / self.vae.config.scaling_factor, return_dict=False)[0]
+
+        image_recon = self.image_processor.postprocess(image_recon, output_type="pt", do_denormalize=do_denormalize)
+
+        loss = loss_function(image_recon, image_answer)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        # lr_scheduler.step()
+
+        if verbose:
+            print(f"tuning, {i}, {loss.item()}")
+
+    input.detach()
+
+    # Offload all models
+    self.maybe_free_model_hooks()
+
+    return input
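
At its core, this sampler treats the inverted latent as the only trainable parameter and runs Adam on the pixel-space MSE of a single Euler step; note that v_pred is detached, so gradients reach the latent only through the identity term of input + dt * v_pred. A self-contained sketch of that loop (not part of the commit), with the guided UNet and the VAE decode abstracted into stand-in callables velocity_fn and decode_fn:

    import torch

    def tune_latent(z0, target_image, velocity_fn, decode_fn, steps=10, dt=1.0, lr=0.01):
        """Sketch of one-step inversion tuning (assumed reading of this method).

        z0           -- inverted latent, e.g. shape (1, 4, 64, 64)
        target_image -- image the one-step generation should reproduce, in [0, 1]
        velocity_fn  -- stand-in for the classifier-free-guided UNet velocity
        decode_fn    -- stand-in for VAE decode plus postprocessing to pixel space
        """
        z = z0.detach().clone().requires_grad_(True)
        optimizer = torch.optim.Adam([z], lr=lr)
        mse = torch.nn.MSELoss()
        for _ in range(steps):
            v = velocity_fn(z).detach()    # velocity treated as a constant, as above
            x_hat = decode_fn(z + dt * v)  # one Euler step, then decode to pixels
            loss = mse(x_hat, target_image)
            optimizer.zero_grad()
            loss.backward()                # gradient flows through z only
            optimizer.step()
        return z.detach()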

@torch.inference_mode()
def forward_step_method(
self,
@@ -874,7 +998,8 @@ def forward_step_method(
warmup_time = 0,
steps=100,
original_step_size=0.1, step_size=0.5,
-factor=0.5, patience=15, th=1e-3
+factor=0.5, patience=15, th=1e-3,
+pnp_adjust=False,
):
"""
The forward step method assumes that current_latents are at the right place (even on multistep), then maps latents correctly to current_latents
@@ -908,13 +1033,16 @@ def forward_step_method(
step_size = step_scheduler.step(loss)

# progress_bar.update()
+if pnp_adjust:
+    add_noise = randn_tensor(latents_s.shape, device=latents_s.device, dtype=latents_s.dtype)
+    latents_s = latents_s + 0.01 * add_noise

if verbose:
print(i, ((latents_t - current_latents).norm()/current_latents.norm()).item(), step_size)

return latents_s

-def edcorrector(self, x, decoder_inv_steps=100):
+def edcorrector(self, x, decoder_inv_steps=100, verbose=False):
"""
edcorrector calculates latents z of the image x by solving optimization problem ||E(x)-z||,
not by directly encoding with VAE encoder. "Decoder inversion"
@@ -947,7 +1075,8 @@ def edcorrector(self, x, decoder_inv_steps=100):
optimizer.step()
lr_scheduler.step()

-print(f"{i}, {loss.item()}")
+if verbose:
+    print(f"{i}, {loss.item()}")

return z

File renamed without changes.
File renamed without changes.