Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ checkpoint/
demo/
colab_demo/
sample.sh
.vscode
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
32 changes: 19 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<br>

<p align="center">
<img src="https://github.com/ermongroup/SDEdit/blob/main/images/sde_animation.gif" width="320"/>
<img src="images/sde_animation.gif" width="320"/>
</p>

[**Project**](https://sde-image-editing.github.io/) | [**Paper**](https://arxiv.org/abs/2108.01073) | [**Colab**](https://colab.research.google.com/drive/1KkLS53PndXKQpPlS1iK-k1nRQYmlb4aO?usp=sharing)
Expand All @@ -16,28 +16,28 @@ Stanford and CMU


<p align="center">
<img src="https://github.com/ermongroup/SDEdit/blob/main/images/teaser.jpg" />
<img src="images/teaser.jpg" />
</p>

Recently, SDEdit has also been applied to text-guided image editing with large-scale text-to-image models. Notable examples include <a href="https://en.wikipedia.org/wiki/Stable_Diffusion">Stable Diffusion</a>'s img2img function (see <a href="https://github.com/CompVis/stable-diffusion#image-modification-with-stable-diffusion">here</a>), <a href="https://arxiv.org/abs/2112.10741">GLIDE</a>, and <a href="https://arxiv.org/abs/2210.03142">distilled-SD</a>. The below example comes from <a href="https://arxiv.org/abs/2210.03142">distilled-SD</a>.

<p align="center">
<img src="https://github.com/ermongroup/SDEdit/blob/main/images/text_guided_img2img.png" />
<img src="images/text_guided_img2img.png" />
</p>


## Overview
The key intuition of SDEdit is to "hijack" the reverse stochastic process of SDE-based generative models, as illustrated in the figure below. Given an input image for editing, such as a stroke painting or an image with color strokes, we can add a suitable amount of noise to make its artifacts undetectable, while still preserving the overall structure of the image. We then initialize the reverse SDE with this noisy input, and simulate the reverse process to obtain a denoised image of high quality. The final output is realistic while resembling the overall image structure of the input.

<p align="center">
<img src="https://github.com/ermongroup/SDEdit/blob/main/images/sde_stroke_generation.jpg" />
<img src="images/sde_stroke_generation.jpg" />
</p>

## Getting Started
The code will automatically download pretrained SDE (VP) PyTorch models on
[CelebA-HQ](https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt),
[LSUN bedroom](https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt),
and [LSUN church outdoor](https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt).
[CelebA-HQ](https://huggingface.co/XUXR/SDEdit/resolve/main/celeba_hq.ckpt),
[LSUN bedroom](https://huggingface.co/XUXR/SDEdit/blob/main/celeba_hq.ckpt),
and [LSUN church outdoor](https://huggingface.co/XUXR/SDEdit/blob/main/ema_lsun_church.ckpt).

### Data format
We save the image and the corresponding mask in an array format ``[image, mask]``, where
Expand All @@ -57,31 +57,37 @@ SDEdit can synthesize multiple diverse outputs for each input on LSUN bedroom, L
To generate results on LSUN datasets, please run

```
python main.py --exp ./runs/ --config bedroom.yml --sample -i images --npy_name lsun_bedroom1 --sample_step 3 --t 500 --ni
python main.py --exp ./runs/ --config celeba.yml --img <path_to_img.jpg> --sample -i images --sample_step 3 --t 500 --ni
```
```
python main.py --exp ./runs/ --config church.yml --sample -i images --npy_name lsun_church --sample_step 3 --t 500 --ni
python main.py --exp ./runs/ --config church.yml --sample -i images --img <path_to_img.jpg> --sample_step 3 --t 500 --ni
```

Use all images in a directory, please run

```
python main.py --exp ./runs/ --config church.yml --sample -i images --sample_step 1 --t 300 --ni --init_dir <path to dir of images>
```
<p align="center">
<img src="https://github.com/ermongroup/SDEdit/blob/main/images/stroke_based_generation.jpg" width="800">
<img src="images/stroke_based_generation.jpg" width="800">
</p>

## Stroke-based image editing
Given an input image with user strokes, we want to manipulate a natural input image based on the user's edit.
SDEdit can generate image edits that are both realistic and faithful (to the user edit), while avoid introducing undesired changes.
<p align="center">
<img src="https://github.com/ermongroup/SDEdit/blob/main/images/stroke_edit.jpg" width="800">
<img src="images/stroke_edit.jpg" width="800">
</p>

To perform stroke-based image editing, run

```
python main.py --exp ./runs/ --config church.yml --sample -i images --npy_name lsun_edit --sample_step 3 --t 500 --ni
python main.py --exp ./runs/ --config church.yml --sample -i images --img <path to image> --sample_step 3 --t 500 --ni
```

## Additional results
<p align="center">
<img src="https://github.com/ermongroup/SDEdit/blob/main/images/stroke_generation_extra.jpg" width="800">
<img src="images/stroke_generation_extra.jpg" width="800">
</p>

## References
Expand Down
2 changes: 1 addition & 1 deletion configs/bedroom.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ diffusion:
num_diffusion_timesteps: 1000

sampling:
batch_size: 8
batch_size: 2
last_only: True
2 changes: 1 addition & 1 deletion configs/celeba.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ diffusion:
num_diffusion_timesteps: 1000

sampling:
batch_size: 8
batch_size: 2
last_only: True
2 changes: 1 addition & 1 deletion configs/church.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ diffusion:
num_diffusion_timesteps: 1000

sampling:
batch_size: 8
batch_size: 3
last_only: True
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ def parse_args_and_config():
parser.add_argument('--sample', action='store_true', help='Whether to produce samples from the model')
parser.add_argument('-i', '--image_folder', type=str, default='images', help="The folder name of samples")
parser.add_argument('--ni', action='store_true', help="No interaction. Suitable for Slurm Job launcher")
parser.add_argument('--npy_name', type=str, required=True)
parser.add_argument('--sample_step', type=int, default=3, help='Total sampling steps')
parser.add_argument('--t', type=int, default=400, help='Sampling noise scale')
parser.add_argument('--img', type=str, default="image.jpg", help='Image path')
parser.add_argument('--init_dir', type=str, default=None, help='use all images in the directory for initialization')
args = parser.parse_args()

# parse config file
Expand Down
2 changes: 1 addition & 1 deletion models/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __init__(self, config):
padding=1)

def forward(self, x, t):
assert x.shape[2] == x.shape[3] == self.resolution
assert x.shape[2] == x.shape[3] == self.resolution, "x.shape: {}, resolution: {}".format(x.shape, self.resolution)

# timestep embedding
temb = get_timestep_embedding(t, self.ch)
Expand Down
14 changes: 7 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
matplotlib==3.3.4
numpy==1.19.5
PyYAML==5.4.1
tensorboard==2.4.1
torch==1.5.0+cu101
torchvision==0.6.0+cu101
tqdm==4.59.0
matplotlib
numpy
PyYAML
tensorboard
torch
torchvision
tqdm
109 changes: 66 additions & 43 deletions runners/image_editing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,31 @@
import os
import numpy as np
from tqdm import tqdm

from PIL import Image
import torch
import torchvision.utils as tvu

from models.diffusion import Model
from functions.process_data import *


def load_image(image_path, dev):
image = Image.open(image_path).convert('RGB')
image = image.resize((256, 256)) # Resize the image to 256x256
image = np.array(image).astype(np.float32) / 255.0
image_tensor = torch.from_numpy(image).permute(2, 0, 1).to(dev) # Convert to PyTorch tensor and rearrange dimensions
return image_tensor

def create_full_mask(image_tensor, dev):
mask = torch.zeros_like(image_tensor, device=dev)
return mask

def image_and_mask(image_path, dev):
image_tensor = load_image(image_path, dev)
mask_tensor = create_full_mask(image_tensor, dev)
return image_tensor, mask_tensor


def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps):
betas = np.linspace(beta_start, beta_end,
num_diffusion_timesteps, dtype=np.float64)
Expand Down Expand Up @@ -79,15 +96,20 @@ def __init__(self, args, config, device=None):
elif self.model_var_type == 'fixedsmall':
self.logvar = np.log(np.maximum(posterior_variance, 1e-20))

idir = self.args.init_dir
assert os.path.exists(self.args.img) or os.path.exists(idir), "Image path or directory does not exist"
self.img_list = [args.img] if idir is None else [os.path.join(idir, img) for img in os.listdir(idir)]


def image_editing_sample(self):
print("Loading model")
if self.config.data.dataset == "LSUN":
if self.config.data.category == "bedroom":
url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"
url = "https://huggingface.co/XUXR/SDEdit/resolve/main/lsun_bedroom.ckpt"
elif self.config.data.category == "church_outdoor":
url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"
url = "https://huggingface.co/XUXR/SDEdit/resolve/main/ema_lsun_church.ckpt"
elif self.config.data.dataset == "CelebA_HQ":
url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"
url = "https://huggingface.co/XUXR/SDEdit/resolve/main/celeba_hq.ckpt"
else:
raise ValueError

Expand All @@ -99,46 +121,47 @@ def image_editing_sample(self):
print("Model loaded")
ckpt_id = 0

download_process_data(path="colab_demo")
# download_process_data(path="colab_demo")
n = self.config.sampling.batch_size
model.eval()
print("Start sampling")
with torch.no_grad():
name = self.args.npy_name
[mask, img] = torch.load("colab_demo/{}.pth".format(name))

mask = mask.to(self.config.device)
img = img.to(self.config.device)
img = img.unsqueeze(dim=0)
img = img.repeat(n, 1, 1, 1)
x0 = img

tvu.save_image(x0, os.path.join(self.args.image_folder, f'original_input.png'))
x0 = (x0 - 0.5) * 2.

for it in range(self.args.sample_step):
e = torch.randn_like(x0)
total_noise_levels = self.args.t
a = (1 - self.betas).cumprod(dim=0)
x = x0 * a[total_noise_levels - 1].sqrt() + e * (1.0 - a[total_noise_levels - 1]).sqrt()
tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder, f'init_{ckpt_id}.png'))

with tqdm(total=total_noise_levels, desc="Iteration {}".format(it)) as progress_bar:
for i in reversed(range(total_noise_levels)):
t = (torch.ones(n) * i).to(self.device)
x_ = image_editing_denoising_step_flexible_mask(x, t=t, model=model,
logvar=self.logvar,
betas=self.betas)
x = x0 * a[i].sqrt() + e * (1.0 - a[i]).sqrt()
x[:, (mask != 1.)] = x_[:, (mask != 1.)]
# added intermediate step vis
if (i - 99) % 100 == 0:
tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
f'noise_t_{i}_{it}.png'))
progress_bar.update(1)

x0[:, (mask != 1.)] = x[:, (mask != 1.)]
torch.save(x, os.path.join(self.args.image_folder,
f'samples_{it}.pth'))
tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
f'samples_{it}.png'))
for img_path in self.img_list:
save_folder = os.path.join(self.args.image_folder, os.path.basename(img_path).split('.')[0])
if not os.path.exists(save_folder):
os.makedirs(save_folder)
img, mask = image_and_mask(img_path, self.device)
ckpt_id += 1
img = img.unsqueeze(dim=0)
img = img.repeat(n, 1, 1, 1)
x0 = img

tvu.save_image(x0, os.path.join(save_folder, f'original_input.png'))
x0 = (x0 - 0.5) * 2.

for it in range(self.args.sample_step):
e = torch.randn_like(x0)
total_noise_levels = self.args.t
a = (1 - self.betas).cumprod(dim=0)
x = x0 * a[total_noise_levels - 1].sqrt() + e * (1.0 - a[total_noise_levels - 1]).sqrt()
tvu.save_image((x + 1) * 0.5, os.path.join(save_folder, f'init_{ckpt_id}.png'))

with tqdm(total=total_noise_levels, desc="Iteration {}".format(it), ) as progress_bar:
for i in reversed(range(total_noise_levels)):
t = (torch.ones(n) * i).to(self.device)
x_ = image_editing_denoising_step_flexible_mask(x, t=t, model=model,
logvar=self.logvar,
betas=self.betas)
x = x0 * a[i].sqrt() + e * (1.0 - a[i]).sqrt()
x[:, (mask != 1.)] = x_[:, (mask != 1.)]
# added intermediate step vis
if i % 100 == 0:
tvu.save_image((x + 1) * 0.5, os.path.join(save_folder,
f'noise_t_{i}_{it}.png'))
progress_bar.update(1)

x0[:, (mask != 1.)] = x[:, (mask != 1.)]
torch.save(x, os.path.join(save_folder,
f'samples_{it}.pth'))
# tvu.save_image((x + 1) * 0.5, os.path.join(save_folder,
# f'samples_{it}.png'))
17 changes: 15 additions & 2 deletions colab_utils/utils.py → utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

sys.path.append("../")

Expand All @@ -18,11 +19,23 @@
device = "cuda"


URL_MAP = {
"cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1",
"ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1",
"lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1",
"ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1",
"lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1",
"ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1",
"lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1",
"ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1",
}

# It's recommended to download the checkpoint files manually
def get_checkpoint(dataset, category):
if category == "bedroom":
url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"
url = "https://huggingface.co/gwang-kim/DiffusionCLIP-LSUN_Bedroom/resolve/main/bedroom.ckpt"
elif category == "church_outdoor":
url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"
url = r"C:\Users\Administrator\Downloads\ema_lsun_church.ckpt"
elif dataset == "CelebA_HQ":
url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"
else:
Expand Down