|
| 1 | +# |
| 2 | +# For licensing see accompanying LICENSE file. |
| 3 | +# Copyright (C) 2024 Apple Inc. All Rights Reserved. |
| 4 | +# |
| 5 | +import argparse |
| 6 | +import builtins |
| 7 | +import pathlib |
| 8 | + |
| 9 | +import numpy as np |
| 10 | +import torch |
| 11 | +import torch.utils.data |
| 12 | +import torchvision as tv |
| 13 | + |
| 14 | +import transformer_flow |
| 15 | +import utils |
| 16 | + |
| 17 | + |
def main(args):
    """Sample from a trained transformer-flow checkpoint and report its FID.

    Loads the FID statistics file found under ``args.data``, restores the
    model from ``args.ckpt_file``, draws ``args.num_samples`` samples in
    batches of ``args.batch_size`` (split evenly across ``dist.world_size``
    ranks), optionally applies one gradient step of self-denoising to each
    sample, and feeds the gathered samples to the FID metric.  The final
    batch of samples is saved as an image grid under ``args.logdir``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).
            ``args.denoising_batch_size`` is set here to ``batch_size // 4``.

    Raises:
        FileNotFoundError: If the FID statistics file does not exist.
        ValueError: If self-denoising is enabled and ``args.batch_size`` is
            not a positive multiple of the derived denoising batch size.
    """
    args.denoising_batch_size = args.batch_size // 4
    dist = utils.Distributed()
    # Offset the seed by rank so every process draws different noise.
    utils.set_random_seed(100 + dist.rank)
    num_classes = utils.get_num_classes(args.dataset)

    def print(*values, **kwargs):
        """Print only on the process whose local rank is 0."""
        if dist.local_rank == 0:
            builtins.print(*values, **kwargs)

    # The FID statistics must have been computed beforehand.
    fid_stats_file = args.data / f'{args.dataset}_{args.img_size}_fid_stats.pth'
    if not fid_stats_file.exists():
        # Explicit raise instead of assert so the check survives `python -O`.
        raise FileNotFoundError(f'FID stats file not found: {fid_stats_file}')
    print(f'Loading FID stats from {fid_stats_file}')
    fid = utils.FID(reset_real_features=False, normalize=True).cuda()
    # NOTE: weights_only=False unpickles arbitrary objects; only load stats
    # files that come from a trusted source.
    fid.load_state_dict(torch.load(fid_stats_file, map_location='cpu', weights_only=False))
    dist.barrier()

    model = transformer_flow.Model(
        in_channels=args.channel_size,
        img_size=args.img_size,
        patch_size=args.patch_size,
        channels=args.channels,
        num_blocks=args.blocks,
        layers_per_block=args.layers_per_block,
        nvp=args.nvp,
        num_classes=num_classes,
    ).cuda()
    # Parameters are frozen; only gradients w.r.t. the samples are needed.
    for p in model.parameters():
        p.requires_grad = False

    model_name = f'{args.patch_size}_{args.channels}_{args.blocks}_{args.layers_per_block}_{args.noise_std:.2f}'
    sample_dir: pathlib.Path = args.logdir / f'{args.dataset}_samples_{model_name}'

    if dist.local_rank == 0:
        sample_dir.mkdir(parents=True, exist_ok=True)

    ckpt = torch.load(args.ckpt_file, map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt, strict=True)
    model.eval()

    print('Starting sampling')
    num_batches = int(np.ceil(args.num_samples / args.batch_size))
    last_batch_size = args.num_samples - (num_batches - 1) * args.batch_size
    # Number of samples each rank draws per batch.
    per_rank = args.batch_size // dist.world_size

    def get_noise(b):
        """Draw standard-normal noise for ``b`` samples in patch-sequence layout."""
        return torch.randn(
            b, (args.img_size // args.patch_size) ** 2, args.channel_size * args.patch_size**2, device='cuda'
        )

    # Self-denoising settings do not change between batches: compute them once.
    denoise = args.self_denoising_lr > 0
    if denoise:
        if args.denoising_batch_size == 0 or args.batch_size % args.denoising_batch_size != 0:
            raise ValueError(
                f'batch_size ({args.batch_size}) must be a positive multiple of '
                f'the denoising batch size ({args.denoising_batch_size})'
            )
        num_chunks = args.batch_size // args.denoising_batch_size
        db = args.denoising_batch_size // dist.world_size
        # This should be the theoretical optimal denoising lr
        base_lr = db * args.img_size**2 * args.channel_size * args.noise_std**2
        lr = args.self_denoising_lr * base_lr

    for i in range(num_batches):
        noise = get_noise(per_rank)
        if num_classes:
            y = torch.randint(num_classes, (per_rank,), device='cuda')
        else:
            y = None

        # Re-draw the noise and resample until the gathered batch has no NaNs.
        while True:
            with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                samples = model.reverse(noise, y, args.cfg, attn_temp=args.attn_temp, annealed_guidance=True)
            assert isinstance(samples, torch.Tensor)

            if denoise:
                # Park the batch on the CPU and denoise it chunk by chunk on the GPU.
                samples = samples.cpu()
                denoised_samples = []
                for j in range(num_chunks):
                    chunk = slice(j * db, (j + 1) * db)
                    x = torch.clone(samples[chunk]).detach().cuda()
                    x.requires_grad = True
                    y_ = y[chunk] if y is not None else None
                    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                        z, _, logdets = model(x, y_)
                    loss = model.get_loss(z, logdets)
                    grad = torch.autograd.grad(loss, [x])[0]
                    # One gradient-descent step on the samples themselves.
                    with torch.no_grad():
                        x.add_(grad, alpha=-lr)
                    denoised_samples.append(x.detach().cpu())
                samples = torch.cat(denoised_samples, dim=0).cuda()

            samples = dist.gather_concat(samples.detach())
            if not samples.isnan().any().item():
                break
            # NaNs found: retry with fresh noise (class labels are kept).
            noise = get_noise(per_rank)

        if i == num_batches - 1:
            # Trim the final batch so exactly args.num_samples are scored.
            samples = samples[:last_batch_size]

        # Clip to [-1, 1] and rescale to [0, 1] before updating the metric.
        fid.update(0.5 * (samples.clip(min=-1, max=1) + 1), real=False)
        print(f'{i+1}/{num_batches} batch sample complete')
    fid_score = fid.compute().item()
    fid.reset()

    print(f'{args.ckpt_file} {model_name} cfg {args.cfg:.2f} fid {fid_score:.2f}')
    if dist.local_rank == 0:
        # Only the last batch of samples is written to disk.
        tv.utils.save_image(samples, sample_dir / f'samples_cfg{args.cfg:.2f}.png', normalize=True, nrow=16)
    dist.barrier()
| 118 | + |
| 119 | + |
if __name__ == '__main__':
    # Command-line entry point: declare every option, parse, then run main().
    parser = argparse.ArgumentParser()

    # Filesystem locations.
    parser.add_argument(
        '--data',
        default='data',
        type=pathlib.Path,
        help='Path for training data',
    )
    parser.add_argument(
        '--logdir',
        default='runs',
        type=pathlib.Path,
        help='Path for artifacts',
    )

    # Checkpoint and dataset description.
    parser.add_argument(
        '--ckpt_file',
        default='',
        type=str,
        help='Path for checkpoint for evaluation',
    )
    parser.add_argument(
        '--dataset',
        default='imagenet',
        type=str,
        choices=['imagenet', 'imagenet64', 'afhq'],
        help='Name of dataset',
    )
    parser.add_argument(
        '--img_size',
        default=32,
        type=int,
        help='Image size',
    )
    parser.add_argument(
        '--channel_size',
        default=3,
        type=int,
        help='Image channel size',
    )

    # Model architecture.
    parser.add_argument(
        '--patch_size',
        default=4,
        type=int,
        help='Patch size for the model',
    )
    parser.add_argument(
        '--channels',
        default=512,
        type=int,
        help='Model width',
    )
    parser.add_argument(
        '--blocks',
        default=4,
        type=int,
        help='Number of autoregressive flow blocks',
    )
    parser.add_argument(
        '--layers_per_block',
        default=8,
        type=int,
        help='Depth per flow block',
    )
    parser.add_argument(
        '--noise_std',
        default=0.05,
        type=float,
        help='Input noise standard deviation',
    )
    parser.add_argument(
        '--nvp',
        default=True,
        action=argparse.BooleanOptionalAction,
        help='Whether to use the non volume preserving version',
    )

    # Sampling and denoising.
    parser.add_argument(
        '--cfg',
        default=0,
        type=float,
        help='Guidance weight for sampling, 0 is no guidance. For conditional models consider the range in [1, 3]',
    )
    parser.add_argument(
        '--attn_temp',
        default=1.0,
        type=float,
        help='Attention temperature for unconditional guidance, enabled when not 1 (eg, 0.5, 1.5)',
    )
    parser.add_argument(
        '--batch_size',
        default=1024,
        type=int,
        help='Batch size for drawing samples',
    )
    parser.add_argument(
        '--num_samples',
        default=50000,
        type=int,
        help='Number of total samples to draw',
    )
    parser.add_argument(
        '--self_denoising_lr',
        default=1.0,
        type=float,
        help='Learning rate multiplier for denoising, 1 is the theoretical optimal one',
    )

    args = parser.parse_args()

    main(args)
0 commit comments