@@ -730,6 +730,8 @@ def exact_inversion(
730
730
callback_steps : int = 1 ,
731
731
decoder_inv_steps : int = 100 ,
732
732
forward_steps : int = 100 ,
733
+ tuning_steps : int = 100 ,
734
+ pnp_adjust : bool = True ,
733
735
):
734
736
"""
735
737
Exact inversion of RectifiedFlowPipeline. Gets input of 1,4,64,64 latents (which is denoised), and returns original latents by performing inversion
@@ -774,7 +776,7 @@ def exact_inversion(
774
776
image = torch .Tensor (latents ).permute (0 , 3 , 1 , 2 )
775
777
image = image .to ('cuda' )
776
778
torch .set_grad_enabled (True )
777
- latents = self .edcorrector (image , decoder_inv_steps = decoder_inv_steps )
779
+ latents = self .edcorrector (image , decoder_inv_steps = decoder_inv_steps , verbose = verbose )
778
780
torch .set_grad_enabled (False )
779
781
780
782
current_latents = latents # Save latents, this is our target
@@ -800,7 +802,7 @@ def exact_inversion(
800
802
latents = latents - dt * v_pred
801
803
802
804
# Our work : perform forward step method
803
- latents = self .forward_step_method (latents , current_latents , t , dt , prompt_embeds = prompt_embeds , do_classifier_free_guidance = do_classifier_free_guidance , guidance_scale = guidance_scale , verbose = verbose )
805
+ latents = self .forward_step_method (latents , current_latents , t , dt , prompt_embeds = prompt_embeds , do_classifier_free_guidance = do_classifier_free_guidance , guidance_scale = guidance_scale , verbose = verbose , pnp_adjust = pnp_adjust )
804
806
else :
805
807
# expand the latents if we are doing classifier free guidance
806
808
latent_model_input = torch .cat ([latents ] * 2 ) if do_classifier_free_guidance else latents
@@ -822,15 +824,28 @@ def exact_inversion(
822
824
# Our work : perform forward step method
823
825
latents = self .forward_step_method (latents , current_latents , t , dt , prompt_embeds = prompt_embeds , do_classifier_free_guidance = do_classifier_free_guidance ,
824
826
guidance_scale = guidance_scale , verbose = verbose ,
825
- steps = forward_steps )
826
-
827
+ steps = forward_steps , pnp_adjust = pnp_adjust )
827
828
828
829
if i == len (timesteps ) - 1 or ((i + 1 ) % self .scheduler .order == 0 ):
829
830
progress_bar .update ()
830
831
if callback is not None and i % callback_steps == 0 :
831
832
step_idx = i // getattr (self .scheduler , "order" , 1 )
832
833
callback (step_idx , t , latents )
833
834
835
+ # Add another procedure, end-to-end correction of noise
836
+ torch .set_grad_enabled (True )
837
+ if input_type == "images" :
838
+ latents = self .one_step_inversion_tuning_sampler (
839
+ prompt = prompt ,
840
+ num_inference_steps = num_inference_steps ,
841
+ guidance_scale = guidance_scale ,
842
+ negative_prompt = negative_prompt ,
843
+ latents = latents ,
844
+ image = image ,
845
+ steps = tuning_steps ,
846
+ )
847
+ torch .set_grad_enabled (False )
848
+
834
849
# Offload all models
835
850
self .maybe_free_model_hooks ()
836
851
@@ -861,6 +876,115 @@ def exact_inversion(
861
876
862
877
return latents , image
863
878
879
+ def one_step_inversion_tuning_sampler (
880
+ self ,
881
+ prompt : Union [str , List [str ]] = None ,
882
+ height : Optional [int ] = None ,
883
+ width : Optional [int ] = None ,
884
+ num_inference_steps : int = 50 ,
885
+ guidance_scale : float = 7.5 ,
886
+ negative_prompt : Optional [Union [str , List [str ]]] = None ,
887
+ num_images_per_prompt : Optional [int ] = 1 ,
888
+ eta : float = 0.0 ,
889
+ generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
890
+ latents : Optional [torch .FloatTensor ] = None ,
891
+ image : Optional [torch .FloatTensor ] = None ,
892
+ steps : int = 100 ,
893
+ prompt_embeds : Optional [torch .FloatTensor ] = None ,
894
+ negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
895
+ output_type : Optional [str ] = "pil" ,
896
+ return_dict : bool = True ,
897
+ callback : Optional [Callable [[int , int , torch .FloatTensor ], None ]] = None ,
898
+ callback_steps : int = 1 ,
899
+ cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
900
+ guidance_rescale : float = 0.0 ,
901
+ verbose : bool = True
902
+ ):
903
+ r"""
904
+ The simplified call function to the pipeline for tuning inversion. Creates a network form.
905
+ Inputs:
906
+ latents - initial noise obtained by inverse process
907
+ image - target image to match
908
+
909
+ Returns:
910
+ latents - fine-tuned latents
911
+ """
912
+
913
+ device = self ._execution_device
914
+ do_classifier_free_guidance = guidance_scale > 1.0
915
+
916
+ text_encoder_lora_scale = (
917
+ cross_attention_kwargs .get ("scale" , None ) if cross_attention_kwargs is not None else None
918
+ )
919
+ prompt_embeds , negative_prompt_embeds = self .encode_prompt (
920
+ prompt ,
921
+ device ,
922
+ num_images_per_prompt ,
923
+ do_classifier_free_guidance ,
924
+ negative_prompt ,
925
+ prompt_embeds = prompt_embeds ,
926
+ negative_prompt_embeds = negative_prompt_embeds ,
927
+ lora_scale = text_encoder_lora_scale ,
928
+ )
929
+
930
+ if do_classifier_free_guidance :
931
+ prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ])
932
+
933
+ timesteps = [(1. - i / num_inference_steps ) * 1000. for i in range (num_inference_steps )]
934
+
935
+ extra_step_kwargs = self .prepare_extra_step_kwargs (generator , eta )
936
+ dt = 1.0 / num_inference_steps
937
+ t = timesteps [0 ]
938
+
939
+ # Performing gradient descent, to tune the latents
940
+ image_answer = image .clone ()
941
+ do_denormalize = [True ] * image .shape [0 ]
942
+ input = copy .deepcopy (latents )
943
+ unet = copy .deepcopy (self .unet )
944
+ vae = copy .deepcopy (self .vae )
945
+ input .requires_grad_ (True )
946
+
947
+ loss_function = torch .nn .MSELoss (reduction = 'mean' )
948
+
949
+ optimizer = torch .optim .Adam ([input ], lr = 0.01 )
950
+ #lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=100)
951
+
952
+ for i in range (steps ):
953
+ latent_model_input = torch .cat ([input ] * 2 ) if do_classifier_free_guidance else input
954
+ vec_t = torch .ones ((latent_model_input .shape [0 ],), device = latents .device ) * t
955
+ v_pred = unet (latent_model_input , vec_t , encoder_hidden_states = prompt_embeds ).sample
956
+ v_pred = v_pred .detach ()
957
+
958
+ # perform guidance
959
+ if do_classifier_free_guidance :
960
+ v_pred_neg , v_pred_text = v_pred .chunk (2 )
961
+ v_pred = v_pred_neg + guidance_scale * (v_pred_text - v_pred_neg )
962
+
963
+ temp = input + dt * v_pred
964
+ # Stop here, check by below
965
+ # visual = self.image_processor.postprocess(self.vae.decode(input/self.vae.config.scaling_factor, return_dict=False)[0].detach().cpu())
966
+
967
+ image_recon = vae .decode (temp / self .vae .config .scaling_factor , return_dict = False )[0 ]
968
+
969
+ image_recon = self .image_processor .postprocess (image_recon , output_type = "pt" , do_denormalize = do_denormalize )
970
+
971
+ loss = loss_function (image_recon , image_answer )
972
+
973
+ optimizer .zero_grad ()
974
+ loss .backward ()
975
+ optimizer .step ()
976
+ # lr_scheduler.step()
977
+
978
+ if verbose :
979
+ print (f"tuning, { i } , { loss .item ()} " )
980
+
981
+ input .detach ()
982
+
983
+ # Offload all models
984
+ self .maybe_free_model_hooks ()
985
+
986
+ return input
987
+
864
988
@torch .inference_mode ()
865
989
def forward_step_method (
866
990
self ,
@@ -874,7 +998,8 @@ def forward_step_method(
874
998
warmup_time = 0 ,
875
999
steps = 100 ,
876
1000
original_step_size = 0.1 , step_size = 0.5 ,
877
- factor = 0.5 , patience = 15 , th = 1e-3
1001
+ factor = 0.5 , patience = 15 , th = 1e-3 ,
1002
+ pnp_adjust = False ,
878
1003
):
879
1004
"""
880
1005
The forward step method assumes that current_latents are at right place(even on multistep), then map latents correctly to current latents
@@ -908,13 +1033,16 @@ def forward_step_method(
908
1033
step_size = step_scheduler .step (loss )
909
1034
910
1035
# progress_bar.update()
1036
+ if pnp_adjust :
1037
+ add_noise = randn_tensor (latents_s .shape , device = latents_s .device , dtype = latents_s .dtype )
1038
+ latents_s = latents_s + 0.01 * add_noise
911
1039
912
1040
if verbose :
913
1041
print (i , ((latents_t - current_latents ).norm ()/ current_latents .norm ()).item (), step_size )
914
1042
915
1043
return latents_s
916
1044
917
- def edcorrector (self , x , decoder_inv_steps = 100 ):
1045
+ def edcorrector (self , x , decoder_inv_steps = 100 , verbose = False ):
918
1046
"""
919
1047
edcorrector calculates latents z of the image x by solving optimization problem ||E(x)-z||,
920
1048
not by directly encoding with VAE encoder. "Decoder inversion"
@@ -947,7 +1075,8 @@ def edcorrector(self, x, decoder_inv_steps=100):
947
1075
optimizer .step ()
948
1076
lr_scheduler .step ()
949
1077
950
- print (f"{ i } , { loss .item ()} " )
1078
+ if verbose :
1079
+ print (f"{ i } , { loss .item ()} " )
951
1080
952
1081
return z
953
1082
0 commit comments