Commit e793111

Merge pull request #1 from visual-gen/fix

Fix

2 parents 4ddb566 + dd0c182

File tree: 4 files changed (+27 -9 lines)

Diff for: semanticist/engine/diffusion_trainer.py

+7 -6

@@ -34,6 +34,7 @@ def __init__(
         warmup_epochs=100,
         warmup_steps=None,
         warmup_lr_init=0,
+        decay_steps=None,
         batch_size=32,
         eval_bs=32,
         test_bs=64,
@@ -130,7 +131,7 @@ def __init__(
         if self.accelerator.is_main_process:
             print(f"Effective batch size is {effective_bs}")

-        self.g_optim = create_optimizer(self.model, weight_decay=0.05, learning_rate=lr, accelerator=self.accelerator)
+        self.g_optim = create_optimizer(self.model, weight_decay=0.05, learning_rate=lr,) # accelerator=self.accelerator)

         if warmup_epochs is not None:
             warmup_steps = warmup_epochs * len(self.train_dl)
@@ -142,6 +143,7 @@ def __init__(
                 lr_min,
                 warmup_steps,
                 warmup_lr_init,
+                decay_steps,
                 cosine_lr
             )
             self.accelerator.register_for_checkpointing(self.g_sched)
@@ -232,7 +234,7 @@ def _load_checkpoint(self, ckpt_path=None):
             print(f"Loaded checkpoint from {ckpt_path}")

     def train(self, config=None):
-        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
         if self.accelerator.is_main_process:
             print(f"number of learnable parameters: {n_parameters//1e6}M")
         if config is not None:
@@ -293,7 +295,6 @@ def train(self, config=None):
                 self.accelerator.backward(loss)
                 if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                     self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
-                self.accelerator.unwrap_model(self.model).cancel_gradients_encoder(epoch)
                 self.g_optim.step()
                 if self.g_sched is not None:
                     self.g_sched.step_update(self.steps)
@@ -355,7 +356,7 @@ def evaluate(self):
                 img = batch

             with self.accelerator.autocast():
-                rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=1.0)
+                rec = self.model(img, sample=True, inference_with_n_slots=self.test_num_slots, cfg=1.0)
             imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
             imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
             imgs_and_recs = imgs_and_recs.detach().cpu().float()
@@ -373,7 +374,7 @@ def evaluate(self):

         if self.cfg != 1.0:
             with self.accelerator.autocast():
-                rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=self.cfg)
+                rec = self.model(img, sample=True, inference_with_n_slots=self.test_num_slots, cfg=self.cfg)

             imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
             imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
@@ -417,7 +418,7 @@ def process_batch(cfg_value, save_dir, header):
             targets = targets.to(self.device, non_blocking=True)

             with self.accelerator.autocast():
-                recs = self.model(imgs, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)
+                recs = self.model(imgs, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)

             psnr_val = psnr(recs, imgs, data_range=1.0)
             ssim_val = ssim(recs, imgs, data_range=1.0)
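
A note on the decay_steps argument threaded through this trainer (and gpt_trainer below): it is passed positionally between warmup_lr_init and cosine_lr when the scheduler is built. The repo's scheduler class is not shown in this diff, so the following is only a sketch of the step-based warmup-plus-cosine schedule such a parameter typically controls; the function name and signature are illustrative, not the project's API.

import math

def lr_at_step(step, base_lr, lr_min, warmup_steps, warmup_lr_init, decay_steps):
    """Linear warmup to base_lr, then cosine decay to lr_min over decay_steps."""
    if step < warmup_steps:
        # Linear warmup from warmup_lr_init up to base_lr.
        return warmup_lr_init + (base_lr - warmup_lr_init) * step / max(warmup_steps, 1)
    # Cosine decay: progress runs 0 -> 1 between the end of warmup and decay_steps.
    progress = min((step - warmup_steps) / max(decay_steps - warmup_steps, 1), 1.0)
    return lr_min + 0.5 * (base_lr - lr_min) * (1.0 + math.cos(math.pi * progress))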

Diff for: semanticist/engine/gpt_trainer.py

+2 -0

@@ -39,6 +39,7 @@ def __init__(
         warmup_epochs=100,
         warmup_steps=None,
         warmup_lr_init=0,
+        decay_steps=None,
         batch_size=32,
         cache_bs=8,
         test_bs=100,
@@ -137,6 +138,7 @@ def __init__(
                 lr_min,
                 warmup_steps,
                 warmup_lr_init,
+                decay_steps,
                 cosine_lr
             )
             self.accelerator.register_for_checkpointing(self.g_sched)

Diff for: semanticist/stage1/diffuse_slot.py

+17 -2

@@ -28,6 +28,7 @@ def __init__(
         torch.nn.init.normal_(self.null_cond, std=.02)
         self.autoenc_cond_embedder = nn.Linear(autoenc_dim, self.hidden_size)
         self.y_embedder = nn.Identity()
+        self.cond_drop_prob = 0.1

         self.use_repa = use_repa
         self._repa_hook = None
@@ -39,7 +40,21 @@ def embed_cond(self, autoenc_cond, drop_mask=None):
         # autoenc_cond: (N, K, D)
         # drop_ids: (N)
         # self.null_cond: (1, K, D)
-        autoenc_cond_drop = torch.where(drop_mask[:, :, None], autoenc_cond, self.null_cond)
+        batch_size = autoenc_cond.shape[0]
+        if drop_mask is None:
+            # randomly drop all conditions, for classifier-free guidance
+            if self.training:
+                drop_ids = (
+                    torch.rand(batch_size, 1, 1, device=autoenc_cond.device)
+                    < self.cond_drop_prob
+                )
+                autoenc_cond_drop = torch.where(drop_ids, self.null_cond, autoenc_cond)
+            else:
+                autoenc_cond_drop = autoenc_cond
+        else:
+            # randomly drop some conditions according to the drop_mask (N, K)
+            # True means keep
+            autoenc_cond_drop = torch.where(drop_mask[:, :, None], autoenc_cond, self.null_cond)
         return self.autoenc_cond_embedder(autoenc_cond_drop)

     def forward(self, x, t, autoenc_cond, drop_mask=None):
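
The rewritten embed_cond now distinguishes three cases: training with no mask (drop a whole sample's condition with probability cond_drop_prob, for classifier-free guidance), eval with no mask (pass conditions through), and an explicit per-slot drop_mask (the old behavior, where True keeps a slot). A standalone replica of that logic for clarity; the function name and the zero null tensor below are illustrative only.

import torch

def drop_conditions(autoenc_cond, null_cond, drop_mask=None,
                    training=True, cond_drop_prob=0.1):
    """autoenc_cond: (N, K, D) slot conditions; null_cond: (1, K, D) null embedding."""
    if drop_mask is None:
        if training:
            # Per-sample coin flip: replace the entire condition with the
            # null embedding with probability cond_drop_prob (CFG training).
            drop_ids = torch.rand(autoenc_cond.shape[0], 1, 1,
                                  device=autoenc_cond.device) < cond_drop_prob
            return torch.where(drop_ids, null_cond, autoenc_cond)
        return autoenc_cond  # eval: keep all conditions
    # Per-slot mask (N, K): True keeps a slot, False swaps in the null slot.
    return torch.where(drop_mask[:, :, None], autoenc_cond, null_cond)

cond, null = torch.randn(4, 32, 16), torch.zeros(1, 32, 16)
assert drop_conditions(cond, null, training=True).shape == cond.shape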
@@ -75,7 +90,7 @@ def forward_with_cfg(self, x, t, autoenc_cond, drop_mask, y=None, cfg_scale=1.0)
         # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
         half = x[: len(x) // 2]
         combined = torch.cat([half, half], dim=0)
-        model_out = self.forward(combined, t, autoenc_cond, drop_mask, y)
+        model_out = self.forward(combined, t, autoenc_cond, drop_mask)
         eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
         cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
         half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
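
For reference, the guidance combine this hunk touches follows the GLIDE recipe: the batch holds the same latents twice (conditioned in the first half, null-conditioned in the second), and the two noise predictions are blended. A minimal standalone restatement of that step; apply_cfg is not a function in this repo.

import torch

def apply_cfg(eps, cfg_scale):
    """eps: (2B, C, ...) noise predictions, conditioned rows first, null rows last."""
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    # Guided prediction: push the unconditional estimate toward the
    # conditional one, scaled by cfg_scale.
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    # Both halves share the same latents, so the guided eps is duplicated back.
    return torch.cat([half_eps, half_eps], dim=0)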

Diff for: train_net.py

+1 -1

@@ -1,7 +1,7 @@
 import os.path as osp
 import argparse
 from omegaconf import OmegaConf
-from semanticist.engine.util import instantiate_from_config
+from semanticist.engine.trainer_utils import instantiate_from_config
 from semanticist.utils.device_utils import configure_compute_backend

 def train():
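
For readers unfamiliar with the helper whose import path changes here, instantiate_from_config conventionally resolves a dotted "target" path from an OmegaConf dict and constructs the class with its "params". The actual body in semanticist.engine.trainer_utils is not shown in this diff; this is only the usual shape of the pattern.

import importlib

def instantiate_from_config(config):
    """config: {"target": "pkg.module.ClassName", "params": {...}}"""
    # Split the dotted path into module and class, import, and construct.
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", {}))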
