# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""InstructPix2Pix module"""
# Modified from https://github.com/ashawkey/stable-dreamfusion/blob/main/nerf/sd.py
import sys
from dataclasses import dataclass
from typing import Union
import torch
from rich.console import Console
from torch import Tensor, nn
from jaxtyping import Float
CONSOLE = Console(width=120)

try:
    from diffusers import (
        DDIMScheduler,
        StableDiffusionInstructPix2PixPipeline,
    )
    from transformers import logging

except ImportError:
    CONSOLE.print("[bold red]Missing Stable Diffusion packages.")
    CONSOLE.print(r"Install using [yellow]pip install nerfstudio\[gen][/yellow]")
    CONSOLE.print(r"or [yellow]pip install -e .\[gen][/yellow] if installing from source.")
    sys.exit(1)

logging.set_verbosity_error()

IMG_DIM = 512
CONST_SCALE = 0.18215  # scaling factor of the Stable Diffusion VAE latent space

DDIM_SOURCE = "CompVis/stable-diffusion-v1-4"
SD_SOURCE = "runwayml/stable-diffusion-v1-5"
CLIP_SOURCE = "openai/clip-vit-large-patch14"
IP2P_SOURCE = "timbrooks/instruct-pix2pix"


@dataclass
class UNet2DConditionOutput:
    sample: torch.FloatTensor


class InstructPix2Pix(nn.Module):
    """InstructPix2Pix implementation
    Args:
        device: device to use
        num_train_timesteps: number of training timesteps
        ip2p_use_full_precision: whether to run the UNet and VAE in full (fp32) precision
    """

    def __init__(
        self,
        device: Union[torch.device, str],
        num_train_timesteps: int = 1000,
        ip2p_use_full_precision: bool = False,
    ) -> None:
        super().__init__()

        # normalize to a torch.device so that `self.device.index` below is well defined
        # even when a string such as "cuda:0" is passed in
        self.device = torch.device(device)
        self.num_train_timesteps = num_train_timesteps
        self.ip2p_use_full_precision = ip2p_use_full_precision

        pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            IP2P_SOURCE, torch_dtype=torch.float16, safety_checker=None
        )
        assert pipe is not None
        # swap in a DDIM scheduler so that editing can start from a partially noised render
        pipe.scheduler = DDIMScheduler.from_pretrained(DDIM_SOURCE, subfolder="scheduler")
        pipe.scheduler.set_timesteps(100)
        pipe = pipe.to(self.device)
        pipe.enable_xformers_memory_efficient_attention()
        self.pipe = pipe

        self.scheduler = pipe.scheduler
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # type: ignore

        pipe.unet.eval()
        pipe.vae.eval()

        # full precision improves quality at the cost of memory; otherwise offload to CPU
        if self.ip2p_use_full_precision:
            pipe.unet.float()
            pipe.vae.float()
        else:
            if self.device.index:
                pipe.enable_model_cpu_offload(self.device.index)
            else:
                pipe.enable_model_cpu_offload(0)

        self.unet = pipe.unet
        self.auto_encoder = pipe.vae

        CONSOLE.print("InstructPix2Pix loaded!")

    def edit_image(
        self,
        text_embeddings: Float[Tensor, "N max_length embed_dim"],
        image: Float[Tensor, "BS 3 H W"],
        image_cond: Float[Tensor, "BS 3 H W"],
        guidance_scale: float = 7.5,
        image_guidance_scale: float = 1.5,
        diffusion_steps: int = 20,
        lower_bound: float = 0.70,
        upper_bound: float = 0.98,
    ) -> torch.Tensor:
        """Edit an image for Instruct-NeRF2NeRF using InstructPix2Pix
        Args:
            text_embeddings: Text embeddings
            image: rendered image to edit
            image_cond: corresponding training image to condition on
            guidance_scale: text-guidance scale
            image_guidance_scale: image-guidance scale
            diffusion_steps: number of diffusion steps
            lower_bound: lower bound for diffusion timesteps to use for image editing
            upper_bound: upper bound for diffusion timesteps to use for image editing
        Returns:
            edited image
        """
        min_step = int(self.num_train_timesteps * lower_bound)
        max_step = int(self.num_train_timesteps * upper_bound)

        # sample a random timestep T in [min_step, max_step], then set up multi-step diffusion
        T = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
        self.scheduler.config.num_train_timesteps = T.item()
        self.scheduler.set_timesteps(diffusion_steps)
        with torch.no_grad():
            # prepare image and image_cond latents
            latents = self.imgs_to_latent(image)
            image_cond_latents = self.prepare_image_latents(image_cond)

        # add noise up to timestep T
        noise = torch.randn_like(latents)
        latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])  # type: ignore

        # sections of code used from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
        for i, t in enumerate(self.scheduler.timesteps):
            # predict the noise residual with unet, NO grad!
            with torch.no_grad():
                # duplicate latents three ways: text+image, image-only, unconditional
                latent_model_input = torch.cat([latents] * 3)
                latent_model_input = torch.cat([latent_model_input, image_cond_latents], dim=1)

                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform classifier-free guidance
            noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
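            # InstructPix2Pix dual classifier-free guidance:
            #   e = e_uncond + s_T * (e_text - e_image) + s_I * (e_image - e_uncond)
            # where s_T (guidance_scale) pulls the edit toward the text instruction and
            # s_I (image_guidance_scale) keeps it faithful to the conditioning image.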
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_image)
                + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
            )

            # get previous sample, continue loop
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # decode latents to get edited image
        with torch.no_grad():
            decoded_img = self.latents_to_img(latents)

        return decoded_img

    def latents_to_img(self, latents: Float[Tensor, "BS 4 H W"]) -> Float[Tensor, "BS 3 H W"]:
        """Convert latents to images
        Args:
            latents: Latents to convert
        Returns:
            Images
        """
        # undo the VAE scaling, decode, then map from [-1, 1] back to [0, 1]
        latents = 1 / CONST_SCALE * latents

        with torch.no_grad():
            imgs = self.auto_encoder.decode(latents).sample

        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def imgs_to_latent(self, imgs: Float[Tensor, "BS 3 H W"]) -> Float[Tensor, "BS 4 H W"]:
        """Convert images to latents
        Args:
            imgs: Images to convert
        Returns:
            Latents
        """
        # map images from [0, 1] to [-1, 1], encode, and sample from the VAE posterior
        imgs = 2 * imgs - 1

        posterior = self.auto_encoder.encode(imgs).latent_dist
        latents = posterior.sample() * CONST_SCALE

        return latents

    def prepare_image_latents(self, imgs: Float[Tensor, "BS 3 H W"]) -> Float[Tensor, "BS 4 H W"]:
        """Convert conditioning image to latents used for classifier-free guidance
        Args:
            imgs: Images to convert
        Returns:
            Latents
        """
        imgs = 2 * imgs - 1

        # use the distribution mode (not a sample) for the conditioning image; these
        # latents are intentionally NOT scaled by CONST_SCALE, matching the diffusers
        # InstructPix2Pix pipeline
        image_latents = self.auto_encoder.encode(imgs).latent_dist.mode()

        # ordering matches the three-way latent input: [text+image, image-only, unconditional]
        uncond_image_latents = torch.zeros_like(image_latents)
        image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)

        return image_latents

    def forward(self):
        """Not implemented since we only want the parameter saving of the nn module, but not forward()"""
        raise NotImplementedError
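

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a CUDA device
    # with xformers available (required by the constructor above) and a diffusers
    # version whose InstructPix2Pix pipeline still exposes `_encode_prompt`, which
    # returns embeddings ordered [text, uncond, uncond] to match the three-way latent
    # input built in `edit_image`; newer diffusers versions may differ.
    ip2p = InstructPix2Pix(torch.device("cuda:0"), ip2p_use_full_precision=True)

    # placeholder tensors standing in for a rendered image and its training view
    image = torch.rand(1, 3, IMG_DIM, IMG_DIM, device=ip2p.device)
    image_cond = torch.rand(1, 3, IMG_DIM, IMG_DIM, device=ip2p.device)

    text_embeddings = ip2p.pipe._encode_prompt(
        "turn him into a clown",
        device=ip2p.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt="",
    )
    edited = ip2p.edit_image(text_embeddings, image, image_cond)
    print(edited.shape)  # expected: torch.Size([1, 3, 512, 512])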