
Commit 4d02e4b

first commit

12 files changed: +1247 -0

CODE_OF_CONDUCT.md (+71)
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the open source team at [[email protected]](mailto:[email protected]). All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)

CONTRIBUTING.md (+11)
# Contribution Guide

Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.

While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.

## Before you get started

By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).

We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).

LICENSE (+39)
Copyright (C) 2024 Apple Inc. All Rights Reserved.

IMPORTANT: This Apple software is supplied to you by Apple Inc. ("Apple") in consideration of your agreement to the following terms, and your use, installation, modification or redistribution of this Apple software constitutes acceptance of these terms. If you do not agree with these terms, please do not use, install, modify or redistribute this Apple software.

In consideration of your agreement to abide by the following terms, and subject to these terms, Apple grants you a personal, non-exclusive license, under Apple's copyrights in this original Apple software (the "Apple Software"), to use, reproduce, modify and redistribute the Apple Software, with or without modifications, in source and/or binary forms; provided that if you redistribute the Apple Software in its entirety and without modifications, you must retain this notice and the following text and disclaimers in all such redistributions of the Apple Software. Neither the name, trademarks, service marks or logos of Apple Inc. may be used to endorse or promote products derived from the Apple Software without specific prior written permission from Apple. Except as expressly stated in this notice, no other rights or licenses, express or implied, are granted by Apple herein, including but not limited to any patent rights that may be infringed by your derivative works or by other works in which the Apple Software may be incorporated.

The Apple Software is provided by Apple on an "AS IS" basis. APPLE MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.

IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

README.md (+114)
# Normalizing Flows are Capable Generative Models

This repo contains code that accompanies the research paper, [Normalizing Flows are Capable Generative Models](http://arxiv.org/abs/2412.06329).

![Teaser image](guided_samples.jpeg)

# Setup

```bash
pip install -r requirements.txt
```

# Preparing datasets

Download the datasets you want to experiment with:
- [Imagenet](https://www.image-net.org/download.php)
- [Imagenet64](https://arxiv.org/abs/1601.06759)
- [AFHQ](https://www.kaggle.com/datasets/dimensi0n/afhq-512)

Save only the training files, in `data/<dataset>/<category>/<filename>`; the code does not use the validation/test files.
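For reference, an AFHQ layout might look like this (the category folders are the class names; the filenames below are hypothetical):

```
data/
└── afhq/
    ├── cat/
    │   ├── img_0001.png
    │   └── ...
    ├── dog/
    │   └── ...
    └── wild/
        └── ...
```
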
Compute and save stats for the true data distribution:
```bash
# Files are saved in ./data
torchrun --standalone --nproc_per_node=8 prepare_fid_stats.py --dataset=imagenet64 --img_size=64  # Unconditional
torchrun --standalone --nproc_per_node=8 prepare_fid_stats.py --dataset=imagenet --img_size=64    # Conditional
torchrun --standalone --nproc_per_node=8 prepare_fid_stats.py --dataset=imagenet --img_size=128   # Conditional
torchrun --standalone --nproc_per_node=8 prepare_fid_stats.py --dataset=afhq --img_size=256       # Conditional
```

Note: To run on a single GPU, replace `torchrun` with `python`, like this:
```bash
python prepare_fid_stats.py --dataset=imagenet --img_size=64  # Conditional
```

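For reference, `evaluate_fid.py` (shown later in this commit) consumes the resulting file as the state dict of the repo's `utils.FID` metric, with the real-image statistics kept frozen. A minimal sketch of that consuming side; the file name follows the `<dataset>_<img_size>_fid_stats.pth` pattern used in `evaluate_fid.py`:

```python
import torch
import utils  # repo module; utils.FID wraps an FID metric

# reset_real_features=False keeps the precomputed real-image statistics
# intact when fid.reset() is later called between evaluations.
fid = utils.FID(reset_real_features=False, normalize=True).cuda()
fid.load_state_dict(torch.load('data/imagenet64_64_fid_stats.pth', map_location='cpu', weights_only=False))
```
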
# Training

Reproducing results from the paper:
```bash
# Unconditional ImageNet64 (8 GPUs)
torchrun --standalone --nproc_per_node=8 train.py --dataset=imagenet64 --img_size=64 --channel_size=3 \
  --patch_size=2 --channels=768 --blocks=8 --layers_per_block=8 \
  --noise_std=0.05 --batch_size=256 --epochs=200 --lr=1e-4 --nvp \
  --sample_freq=5 --logdir=runs/imagenet64-uncond

# Conditional ImageNet64 (8 GPUs)
torchrun --standalone --nproc_per_node=8 train.py --dataset=imagenet --img_size=64 --channel_size=3 \
  --patch_size=2 --channels=768 --blocks=8 --layers_per_block=8 \
  --noise_std=0.05 --batch_size=256 --epochs=200 --lr=1e-4 --nvp --cfg=0 --drop_label=0.1 \
  --sample_freq=5 --logdir=runs/imagenet64-cond

# Conditional ImageNet128 (needs to run on 4 nodes, 32 GPUs in total)
torchrun --standalone --nproc_per_node=8 train.py --dataset=imagenet --img_size=128 --channel_size=3 \
  --patch_size=4 --channels=1024 --blocks=8 --layers_per_block=8 \
  --noise_std=0.15 --batch_size=768 --epochs=320 --lr=1e-4 --nvp --cfg=0 --drop_label=0.1 \
  --sample_freq=20 --logdir=runs/imagenet128-cond

# AFHQ (8 GPUs)
torchrun --standalone --nproc_per_node=8 train.py --dataset=afhq --img_size=256 --channel_size=3 \
  --patch_size=8 --channels=768 --blocks=8 --layers_per_block=8 \
  --noise_std=0.07 --batch_size=256 --epochs=4000 --lr=1e-4 --nvp --cfg=0 --drop_label=0.1 \
  --sample_freq=200 --logdir=runs/afhq256
```

For single-GPU training:
```bash
python train.py --dataset=imagenet64 --img_size=64 --channel_size=3 \
  --patch_size=2 --channels=768 --blocks=8 --layers_per_block=8 \
  --noise_std=0.05 --batch_size=32 --epochs=200 --lr=1e-4 --nvp \
  --sample_freq=5 --logdir=runs/imagenet64-uncond
# etc...
```

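The `--noise_std` flag sets the standard deviation of the Gaussian noise added to training inputs; the same value later drives the self-denoising step in `evaluate_fid.py`. As an illustration only (this is a sketch of the augmentation, not the repo's `train.py`, and `add_input_noise` is a hypothetical helper):

```python
import torch

def add_input_noise(x: torch.Tensor, noise_std: float) -> torch.Tensor:
    """Illustrative sketch of --noise_std: perturb images (scaled to [-1, 1])
    with isotropic Gaussian noise before feeding them to the flow."""
    return x + noise_std * torch.randn_like(x)
```
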
# Sampling

Use the notebook to generate samples from a model checkpoint. Inside the notebook is an option to [download a pretrained checkpoint](https://ml-site.cdn-apple.com/models/tarflow/afhq256/afhq_model_8_768_8_8_0.07.pth) trained on AFHQ.
```bash
jupyter notebook sample.ipynb
```

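If you prefer a plain script to the notebook, here is a minimal sketch using the same `transformer_flow.Model` API that `evaluate_fid.py` relies on. The hyperparameters mirror the AFHQ training command above; `num_classes=3` (AFHQ's cat/dog/wild categories), the guidance weight of 2.0, and the assumption that `model.reverse`'s optional arguments default sensibly are ours, not the repo's:

```python
import torch
import torchvision as tv

import transformer_flow

img_size, patch_size, channel_size = 256, 8, 3
model = transformer_flow.Model(
    in_channels=channel_size, img_size=img_size, patch_size=patch_size,
    channels=768, num_blocks=8, layers_per_block=8, nvp=True,
    num_classes=3,  # assumption: AFHQ's three categories (cat/dog/wild)
).cuda()
model.load_state_dict(torch.load('afhq_model_8_768_8_8_0.07.pth', map_location='cpu', weights_only=True))
model.eval()

n = 4  # number of samples to draw
# One latent token per patch, each of dimension channel_size * patch_size**2.
noise = torch.randn(n, (img_size // patch_size) ** 2, channel_size * patch_size**2, device='cuda')
y = torch.randint(3, (n,), device='cuda')  # random class labels
with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    samples = model.reverse(noise, y, 2.0)  # third argument is the guidance weight (--cfg)
tv.utils.save_image(samples.float().clip(-1, 1), 'samples.png', normalize=True, nrow=2)
```
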
# Evaluating FID

Multi-GPU (8 GPUs):
```bash
# Conditional ImageNet64, samples saved in runs/imagenet64-cond/eval
torchrun --standalone --nproc_per_node=8 evaluate_fid.py --dataset=imagenet --img_size=64 --channel_size=3 \
  --patch_size=2 --channels=768 --blocks=8 --layers_per_block=8 \
  --noise_std=0.05 --cfg=2.3 --nvp --batch_size=1024 \
  --ckpt_file=runs/imagenet64-cond/imagenet_model_2_768_8_8_0.05.pth \
  --logdir=runs/imagenet64-cond/eval
```

For single-GPU evaluation:
```bash
# Conditional ImageNet64, samples saved in runs/imagenet64-cond/eval
python evaluate_fid.py --dataset=imagenet --img_size=64 --channel_size=3 \
  --patch_size=2 --channels=768 --blocks=8 --layers_per_block=8 \
  --noise_std=0.05 --cfg=2.3 --nvp --batch_size=32 \
  --ckpt_file=runs/imagenet64-cond/imagenet_model_2_768_8_8_0.05.pth \
  --logdir=runs/imagenet64-cond/eval
```

# BibTeX
```bibtex
@article{zhai2024tarflow,
  title={Normalizing Flows are Capable Generative Models},
  author={Shuangfei Zhai and Ruixiang Zhang and Preetum Nakkiran and David Berthelot and Jiatao Gu and Huangjie Zheng and Tianrong Chen and Miguel Angel Bautista and Navdeep Jaitly and Josh Susskind},
  year={2024},
  eprint={2412.06329},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2412.06329}
}
```

evaluate_fid.py (+144)
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import argparse
import builtins
import pathlib

import numpy as np
import torch
import torch.utils.data
import torchvision as tv

import transformer_flow
import utils


def main(args):
    # Denoising runs with input gradients enabled, so use smaller sub-batches.
    args.denoising_batch_size = args.batch_size // 4
    dist = utils.Distributed()
    utils.set_random_seed(100 + dist.rank)
    num_classes = utils.get_num_classes(args.dataset)

    def print(*args, **kwargs):  # rank-0 print
        if dist.local_rank == 0:
            builtins.print(*args, **kwargs)

    # Check that the FID stats have been precomputed by prepare_fid_stats.py.
    fid_stats_file = args.data / f'{args.dataset}_{args.img_size}_fid_stats.pth'
    assert fid_stats_file.exists()
    print(f'Loading FID stats from {fid_stats_file}')
    fid = utils.FID(reset_real_features=False, normalize=True).cuda()
    fid.load_state_dict(torch.load(fid_stats_file, map_location='cpu', weights_only=False))
    dist.barrier()

    model = transformer_flow.Model(
        in_channels=args.channel_size,
        img_size=args.img_size,
        patch_size=args.patch_size,
        channels=args.channels,
        num_blocks=args.blocks,
        layers_per_block=args.layers_per_block,
        nvp=args.nvp,
        num_classes=num_classes,
    ).cuda()
    for p in model.parameters():
        p.requires_grad = False

    model_name = f'{args.patch_size}_{args.channels}_{args.blocks}_{args.layers_per_block}_{args.noise_std:.2f}'
    sample_dir: pathlib.Path = args.logdir / f'{args.dataset}_samples_{model_name}'

    if dist.local_rank == 0:
        sample_dir.mkdir(parents=True, exist_ok=True)

    ckpt = torch.load(args.ckpt_file, map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt, strict=True)
    model.eval()

    print('Starting sampling')
    num_batches = int(np.ceil(args.num_samples / args.batch_size))
    last_batch_size = args.num_samples - (num_batches - 1) * args.batch_size

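    # Latent noise matches the flow's token layout: one token per patch, each of
    # dimension channel_size * patch_size**2 (the flow is dimension-preserving).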
    def get_noise(b):
        return torch.randn(
            b, (args.img_size // args.patch_size) ** 2, args.channel_size * args.patch_size**2, device='cuda'
        )

    for i in range(num_batches):
        noise = get_noise(args.batch_size // dist.world_size)
        if num_classes:
            y = torch.randint(num_classes, (args.batch_size // dist.world_size,), device='cuda')
        else:
            y = None
        while True:
            with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                samples = model.reverse(noise, y, args.cfg, attn_temp=args.attn_temp, annealed_guidance=True)
            assert isinstance(samples, torch.Tensor)

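            # Self-denoising: one gradient step on the model's own negative
            # log-likelihood, x <- x - lr * d(loss)/dx. Scaling the step by
            # batch * dims * noise_std**2 makes it approximate score-based
            # (Tweedie) denoising, x + noise_std**2 * grad log p(x), which
            # undoes the Gaussian noise added to inputs during training.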
            if args.self_denoising_lr > 0:
                samples = samples.cpu()
                assert args.batch_size % args.denoising_batch_size == 0
                db = args.denoising_batch_size // dist.world_size
                # This should be the theoretically optimal denoising lr.
                base_lr = db * args.img_size**2 * args.channel_size * args.noise_std**2
                lr = args.self_denoising_lr * base_lr
                denoised_samples = []
                for j in range(args.batch_size // args.denoising_batch_size):
                    x = torch.clone(samples[j * db : (j + 1) * db]).detach().cuda()
                    x.requires_grad = True
                    y_ = y[j * db : (j + 1) * db] if y is not None else None
                    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                        z, _, logdets = model(x, y_)
                        loss = model.get_loss(z, logdets)
                    grad = torch.autograd.grad(loss, [x])[0]
                    x.data.add_(grad, alpha=-lr)
                    denoised_samples.append(x.detach().cpu())
                samples = torch.cat(denoised_samples, dim=0).cuda()

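            # Gather samples from all ranks; if the batch contains any NaNs,
            # draw fresh noise and retry.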
            samples = dist.gather_concat(samples.detach())
            if not samples.isnan().any().item():
                break
            else:
                noise = get_noise(args.batch_size // dist.world_size)

        if i == num_batches - 1:
            samples = samples[:last_batch_size]

        # Map samples from [-1, 1] to the [0, 1] range the FID metric expects.
        fid.update(0.5 * (samples.clip(min=-1, max=1) + 1), real=False)
        print(f'{i+1}/{num_batches} batch sample complete')
    fid_score = fid.compute().item()
    fid.reset()

    print(f'{args.ckpt_file} {model_name} cfg {args.cfg:.2f} fid {fid_score:.2f}')
    if dist.local_rank == 0:
        tv.utils.save_image(samples, sample_dir / f'samples_cfg{args.cfg:.2f}.png', normalize=True, nrow=16)
    dist.barrier()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default='data', type=pathlib.Path, help='Path for training data')
    parser.add_argument('--logdir', default='runs', type=pathlib.Path, help='Path for artifacts')

    parser.add_argument('--ckpt_file', default='', type=str, help='Path to checkpoint for evaluation')
    parser.add_argument('--dataset', default='imagenet', type=str, choices=['imagenet', 'imagenet64', 'afhq'], help='Name of dataset')
    parser.add_argument('--img_size', default=32, type=int, help='Image size')
    parser.add_argument('--channel_size', default=3, type=int, help='Image channel size')

    parser.add_argument('--patch_size', default=4, type=int, help='Patch size for the model')
    parser.add_argument('--channels', default=512, type=int, help='Model width')
    parser.add_argument('--blocks', default=4, type=int, help='Number of autoregressive flow blocks')
    parser.add_argument('--layers_per_block', default=8, type=int, help='Depth per flow block')
    parser.add_argument('--noise_std', default=0.05, type=float, help='Input noise standard deviation')
    parser.add_argument('--nvp', default=True, action=argparse.BooleanOptionalAction, help='Whether to use the non-volume-preserving version')
    parser.add_argument('--cfg', default=0, type=float, help='Guidance weight for sampling; 0 is no guidance. For conditional models consider the range [1, 3]')
    parser.add_argument('--attn_temp', default=1.0, type=float, help='Attention temperature for unconditional guidance, enabled when not 1 (e.g., 0.5 or 1.5)')
    parser.add_argument('--batch_size', default=1024, type=int, help='Batch size for drawing samples')
    parser.add_argument('--num_samples', default=50000, type=int, help='Total number of samples to draw')
    parser.add_argument('--self_denoising_lr', default=1.0, type=float, help='Learning rate multiplier for denoising; 1 is the theoretical optimum')

    args = parser.parse_args()

    main(args)

guided_samples.jpeg (binary, 1.26 MB)
