
Commit 2ae9ec1 — Add titok

1 parent d0754ac commit 2ae9ec1

File tree: 13 files changed, +857 −58 lines

Diff for: README.md (+1 −1)

@@ -19,7 +19,7 @@ Implement visual tokenizers with PyTorch.
 - [ ] Index Backpropagation Quantization (IBQ)
 - [ ] Grouped Spherical Quantization (GSQ)
 
-**ImageNet 256x256 Re-implementation**:
+**ImageNet 256x256 Reproduction**:
 
 - [x] VQGAN (Taming-Transformers)
 - [x] VQGAN (LlamaGen)

Diff for: configs/imagenet256/vqgan-titok.yaml (new file, +95 lines)

seed: 8888

data:
  name: imagenet
  root: ~/data/ImageNet/ILSVRC2012/Images
  img_size: 256
  crop: random

dataloader:
  num_workers: 4
  pin_memory: true
  prefetch_factor: 2

encoder:
  target: models.autoencoder.titok_net.Encoder
  params:
    in_channels: 3
    image_size: 256
    patch_size: 16
    embed_dim: 768  # base
    n_heads: 12     # base
    n_layers: 12    # base
    n_tokens: 64

decoder:
  target: models.autoencoder.titok_net.Decoder
  params:
    out_channels: 3
    image_size: 256
    patch_size: 16
    embed_dim: 768  # base
    n_heads: 12     # base
    n_layers: 12    # base
    n_tokens: 64

quantizer:
  target: models.quantizer.VectorQuantizer
  params:
    codebook_num: 4096
    codebook_dim: 12
    l2_norm: True

disc:
  target: models.discriminator.TitokGANDiscriminator

train:
  n_steps: 1000000
  batch_size: 256
  micro_batch_size: ~

  type_rec: l2
  coef_rec: 1.0

  coef_lpips: 1.0

  type_perc: convnext_s
  coef_perc: 0.1

  coef_commit: 0.25
  coef_vq: 1.0

  coef_adv: 0.1
  start_adv: 200000
  coef_lecam_reg: 0.001

  ema:
    decay: 0.9999
    ema_warmup_type: crowsonkb

  clip_grad_norm: 1.0

  print_freq: 500
  sample_freq: 10000
  save_freq: 50000

  optim:
    target: torch.optim.AdamW
    params:
      lr: 0.0001
      betas: [0.9, 0.999]
      weight_decay: 0.0001

  optim_d:
    target: torch.optim.AdamW
    params:
      lr: 0.0001
      betas: [0.9, 0.999]
      weight_decay: 0.0001

  sched:
    target: utils.scheduler.CosineMinimumWarmupLR
    params:
      warmup_steps: 10000
      training_steps: 1000000
      min_lr: 0.00001
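The config follows a `target`/`params` convention: each component names a class by its import path and passes keyword arguments to its constructor. Below is a minimal sketch of how such a file could be consumed, assuming it is run from the repository root, using OmegaConf for loading and a hypothetical `instantiate_from_config` helper (the repo's actual builder may differ):

```python
import importlib

from omegaconf import OmegaConf


def instantiate_from_config(conf):
    """Import `conf.target` and construct it with `conf.params` (if any)."""
    module_name, cls_name = conf['target'].rsplit('.', 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    params = OmegaConf.to_container(conf.get('params', OmegaConf.create({})), resolve=True)
    return cls(**params)


cfg = OmegaConf.load('configs/imagenet256/vqgan-titok.yaml')
encoder = instantiate_from_config(cfg.encoder)      # ViT-Base encoder producing 64 latent tokens
decoder = instantiate_from_config(cfg.decoder)
quantizer = instantiate_from_config(cfg.quantizer)  # 4096-entry codebook, 12-dim codes, L2-normalized
```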

Diff for: docs/benchmark-imagenet256.md (+18 −3)

@@ -8,9 +8,9 @@ This benchmark aims to reproduce the results reported in the papers as closely a
 
-## Quantitative results
+## VQGAN (Taming Transformers)
 
-Using hyperparameters from ["Taming Transformers"](http://arxiv.org/abs/2012.09841) paper (see [config](../configs/imagenet256/vqgan-taming.yaml)):
+[[paper]](http://arxiv.org/abs/2012.09841) [[config]](../configs/imagenet256/vqgan-taming.yaml)
 
 | Downsample ratio | Codebook dim. | Codebook size | Codebook usage↑ | PSNR↑ | SSIM↑ | LPIPS↓ | rFID↓ |
 |:----------------:|:-------------:|:-------------:|:---------------:|:-----:|:-----:|:------:|:-----:|

@@ -22,11 +22,26 @@
 
-Using hyperparameters from ["LlamaGen"](http://arxiv.org/abs/2406.06525) paper (see [config](../configs/imagenet256/vqgan-llamagen.yaml)):
+## VQGAN (LlamaGen)
+
+[[paper]](http://arxiv.org/abs/2406.06525) [[config]](../configs/imagenet256/vqgan-llamagen.yaml)
 
 | Downsample ratio | Codebook dim. | Codebook size | Codebook usage↑ | PSNR↑ | SSIM↑ | LPIPS↓ | rFID↓ |
 |:----------------:|:-------------:|:-------------:|:---------------:|:-----:|:-----:|:------:|:-----:|
 | 16 | 8 | 16384 | 100% | 20.7201 | 0.5509 | 0.1385 | 2.1073 |
 
 - 🌱 The PSNR is close to the result reported in the paper (20.79).
 - 🌱 The rFID is even slightly better than the result reported in the paper (2.19).
+
+## TiTok
+
+[[paper]](https://arxiv.org/abs/2406.07550) [[project page]](https://yucornetto.github.io/projects/titok.html) [[config]](../configs/imagenet256/vqgan-titok.yaml)
+
+| \# tokens | Codebook dim. | Codebook size | Codebook usage↑ | PSNR↑ | SSIM↑ | LPIPS↓ | rFID↓ |
+|:---------:|:-------------:|:-------------:|:---------------:|:-----:|:-----:|:------:|:-----:|
+| 64 | 12 | 4096 | 100% | 17.8995 | 0.4022 | 0.2681 | 4.6691 |
+
+- ⚠️ The model is trained with a single-stage strategy, unlike the two-stage training used in the paper.
+- ⚠️ The results are not good: reconstructed images contain repeated patterns and artifacts; further investigation is needed.
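For context on how the table columns above are usually produced: PSNR, SSIM, and LPIPS are averaged over reconstructions of the validation set, while rFID compares feature statistics of reconstructions against the reference images (typically with a separate FID tool). A rough sketch using torchmetrics — an assumption for illustration, the repository's own evaluation script may compute these differently:

```python
import torch
from torchmetrics.image import (
    LearnedPerceptualImagePatchSimilarity,
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)

psnr = PeakSignalNoiseRatio(data_range=1.0)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg', normalize=True)  # normalize=True: inputs in [0, 1]


@torch.no_grad()
def update_metrics(recon: torch.Tensor, real: torch.Tensor) -> None:
    """recon, real: (B, 3, H, W) tensors in [0, 1]."""
    psnr.update(recon, real)
    ssim.update(recon, real)
    lpips.update(recon, real)


# After calling update_metrics over the whole validation set:
# print(psnr.compute(), ssim.compute(), lpips.compute())
```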

Diff for: losses/__init__.py (+1)

@@ -1,2 +1,3 @@
 from .adversarial import AdversarialLoss
 from .lpips import LPIPS as LPIPSLoss
+from .perceptual_loss import PerceptualLoss

Diff for: losses/adversarial.py (+25 −1)

@@ -15,14 +15,35 @@ class AdversarialLoss(nn.Module):
     Objective of the discriminator: min E[max(0, 1-D(x))] + E[max(0, 1+D(G(z)))]
     Objective of the generator: min -E[D(G(z))]
 
+    Supports LeCam regularization on the discriminator.
+
     """
-    def __init__(self, discriminator: nn.Module, loss_type: str):
+    def __init__(
+            self,
+            discriminator: nn.Module,
+            loss_type: str,
+            coef_lecam_reg: float = 0.0,
+            lecam_reg_ema_decay: float = 0.999,
+    ):
         super().__init__()
         assert loss_type in ['ns', 'hinge']
 
         self.discriminator = discriminator
         self.loss_type = loss_type
 
+        self.coef_lecam_reg = coef_lecam_reg
+        if self.coef_lecam_reg > 0.0:
+            self.lecam_reg_ema_decay = lecam_reg_ema_decay
+            self.register_buffer('ema_real_logits_mean', torch.zeros(1))
+            self.register_buffer('ema_fake_logits_mean', torch.zeros(1))
+
+    def lecam_reg(self, real_logits_mean: Tensor, fake_logits_mean: Tensor):
+        lecam_loss = (torch.mean(torch.pow(F.relu(real_logits_mean - self.ema_fake_logits_mean), 2)) +
+                      torch.mean(torch.pow(F.relu(self.ema_real_logits_mean - fake_logits_mean), 2)))
+        self.ema_real_logits_mean = self.ema_real_logits_mean * self.lecam_reg_ema_decay + real_logits_mean.detach() * (1 - self.lecam_reg_ema_decay)  # noqa
+        self.ema_fake_logits_mean = self.ema_fake_logits_mean * self.lecam_reg_ema_decay + fake_logits_mean.detach() * (1 - self.lecam_reg_ema_decay)  # noqa
+        return lecam_loss
+
     def forward_G(self, fake_data: Tensor, *args, **kwargs):
         fake_logits = self.discriminator(fake_data, *args, **kwargs)
         if self.loss_type == 'ns':

@@ -45,6 +66,9 @@ def forward_D(self, fake_data: Tensor, real_data: Tensor, *args, **kwargs):
         else:
             raise ValueError(f'Unknown loss type: {self.loss_type}')
 
+        if self.coef_lecam_reg > 0.0:
+            loss = loss + self.coef_lecam_reg * self.lecam_reg(real_logits.mean(), fake_logits.mean())
+
         return loss
 
     def forward(self, mode: str, fake_data: Tensor, real_data: Tensor = None, *args, **kwargs):
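To make the new LeCam path concrete, here is a minimal sketch of how the regularized AdversarialLoss could be driven in a training step. The discriminator class and the 0.001 coefficient mirror the config added in this commit; the hinge loss type and the surrounding step functions are assumptions for illustration, not the repo's actual trainer.

```python
import torch

from losses import AdversarialLoss
from models.discriminator import TitokGANDiscriminator

# loss_type='hinge' is an assumption; the class also supports 'ns'.
adv_loss = AdversarialLoss(
    discriminator=TitokGANDiscriminator(),
    loss_type='hinge',
    coef_lecam_reg=0.001,  # matches train.coef_lecam_reg in vqgan-titok.yaml
)


def discriminator_step(real_images: torch.Tensor, fake_images: torch.Tensor) -> torch.Tensor:
    # Hinge loss on real/fake logits; forward_D also adds the LeCam regularizer
    # and updates the EMA logit statistics.
    return adv_loss.forward_D(fake_images.detach(), real_images)


def generator_step(fake_images: torch.Tensor, coef_adv: float = 0.1) -> torch.Tensor:
    # Generator adversarial term, weighted by coef_adv from the config.
    return coef_adv * adv_loss.forward_G(fake_images)
```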

Diff for: losses/perceptual_loss.py (new file, +49 lines)

"""Perceptual loss.

References:
  - https://github.com/bytedance/1d-tokenizer/blob/main/modeling/modules/perceptual_loss.py
  - https://github.com/markweberdev/maskbit/blob/main/modeling/modules/perceptual_loss.py
"""

import torch
import torch.nn.functional as F
from torchvision import models


class PerceptualLoss(torch.nn.Module):
    def __init__(self, model_name: str = 'resnet50'):
        super().__init__()
        if model_name == 'resnet50':
            self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).eval()
        elif model_name == 'convnext_s':
            self.model = models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).eval()
        else:
            raise ValueError(f'Unsupported model name: {model_name}')

        self.register_buffer('imagenet_mean', torch.Tensor([0.485, 0.456, 0.406])[None, :, None, None])
        self.register_buffer('imagenet_std', torch.Tensor([0.229, 0.224, 0.225])[None, :, None, None])

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image1: torch.Tensor, image2: torch.Tensor):
        """Computes the perceptual loss.

        Args:
            image1: A tensor of shape (B, C, H, W) in range [0, 1].
            image2: A tensor of shape (B, C, H, W) in range [0, 1].

        Returns:
            A scalar tensor, the perceptual loss.
        """
        image1 = F.interpolate(image1, size=224, mode='bilinear', align_corners=False, antialias=True)
        image2 = F.interpolate(image2, size=224, mode='bilinear', align_corners=False, antialias=True)

        image1 = (image1 - self.imagenet_mean) / self.imagenet_std
        image2 = (image2 - self.imagenet_mean) / self.imagenet_std

        pred1 = self.model(image1)
        pred2 = self.model(image2)

        loss = F.mse_loss(pred1, pred2, reduction='mean')
        return loss
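A short usage sketch for the new loss, matching the `type_perc: convnext_s` and `coef_perc: 0.1` settings in the TiTok config above (the random tensors are placeholders for real reconstructions and targets):

```python
import torch

from losses import PerceptualLoss

perc_loss_fn = PerceptualLoss(model_name='convnext_s')

recon = torch.rand(4, 3, 256, 256)   # stand-in for decoder reconstructions, values in [0, 1]
target = torch.rand(4, 3, 256, 256)  # stand-in for ground-truth images, values in [0, 1]

# MSE between the frozen ConvNeXt-S logits of the two batches, weighted by coef_perc.
loss = 0.1 * perc_loss_fn(recon, target)
print(loss.item())
```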
