@@ -34,6 +34,7 @@ def __init__(
         warmup_epochs=100,
         warmup_steps=None,
         warmup_lr_init=0,
+        decay_steps=None,
         batch_size=32,
         eval_bs=32,
         test_bs=64,
@@ -130,7 +131,7 @@ def __init__(
         if self.accelerator.is_main_process:
             print(f"Effective batch size is {effective_bs}")
 
-        self.g_optim = create_optimizer(self.model, weight_decay=0.05, learning_rate=lr, accelerator=self.accelerator)
+        self.g_optim = create_optimizer(self.model, weight_decay=0.05, learning_rate=lr,) # accelerator=self.accelerator)
 
         if warmup_epochs is not None:
             warmup_steps = warmup_epochs * len(self.train_dl)
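For readers outside the repo: the hunk above drops the accelerator keyword from create_optimizer. A minimal sketch of what such a factory typically looks like, assuming an AdamW optimizer with the usual weight-decay parameter split (the actual implementation in this repo may differ):

# Hypothetical sketch of create_optimizer; AdamW and the decay/no-decay
# split are assumptions, not the repository's confirmed implementation.
import torch

def create_optimizer(model, weight_decay=0.05, learning_rate=1e-4):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Common convention: no weight decay on biases and 1-D params (norms).
        if param.ndim <= 1 or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    param_groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(param_groups, lr=learning_rate)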
@@ -142,6 +143,7 @@ def __init__(
             lr_min,
             warmup_steps,
             warmup_lr_init,
+            decay_steps,
             cosine_lr
         )
         self.accelerator.register_for_checkpointing(self.g_sched)
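The new decay_steps argument is threaded into the scheduler factory. A hedged sketch of how a factory might consume it, assuming timm's schedulers and illustrative leading arguments for the optimizer and total step count (names here are assumptions, not the repo's actual code):

# Hypothetical scheduler factory; only the decay_steps handling is new.
from timm.scheduler import CosineLRScheduler, StepLRScheduler

def create_scheduler(optimizer, total_steps, lr_min, warmup_steps,
                     warmup_lr_init, decay_steps=None, cosine_lr=True):
    if cosine_lr:
        return CosineLRScheduler(
            optimizer,
            # Let decay_steps shorten the cosine horizon when given.
            t_initial=decay_steps if decay_steps is not None else total_steps,
            lr_min=lr_min,
            warmup_t=warmup_steps,
            warmup_lr_init=warmup_lr_init,
            t_in_epochs=False,  # stepped per update via step_update(self.steps)
        )
    # Otherwise decay the LR in steps of decay_steps updates.
    return StepLRScheduler(
        optimizer,
        decay_t=decay_steps or total_steps,
        decay_rate=0.1,
        warmup_t=warmup_steps,
        warmup_lr_init=warmup_lr_init,
        t_in_epochs=False,
    )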
@@ -232,7 +234,7 @@ def _load_checkpoint(self, ckpt_path=None):
             print(f"Loaded checkpoint from {ckpt_path}")
 
     def train(self, config=None):
-        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
         if self.accelerator.is_main_process:
             print(f"number of learnable parameters: {n_parameters // 1e6}M")
         if config is not None:
@@ -293,7 +295,6 @@ def train(self, config=None):
                 self.accelerator.backward(loss)
                 if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                     self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
-                self.accelerator.unwrap_model(self.model).cancel_gradients_encoder(epoch)
                 self.g_optim.step()
                 if self.g_sched is not None:
                     self.g_sched.step_update(self.steps)
@@ -355,7 +356,7 @@ def evaluate(self):
                 img = batch
 
             with self.accelerator.autocast():
-                rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=1.0)
+                rec = self.model(img, sample=True, inference_with_n_slots=self.test_num_slots, cfg=1.0)
             imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
             imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
             imgs_and_recs = imgs_and_recs.detach().cpu().float()
@@ -373,7 +374,7 @@ def evaluate(self):
 
             if self.cfg != 1.0:
                 with self.accelerator.autocast():
-                    rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=self.cfg)
+                    rec = self.model(img, sample=True, inference_with_n_slots=self.test_num_slots, cfg=self.cfg)
 
                 imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
                 imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
@@ -417,7 +418,7 @@ def process_batch(cfg_value, save_dir, header):
             targets = targets.to(self.device, non_blocking=True)
 
             with self.accelerator.autocast():
-                recs = self.model(imgs, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)
+                recs = self.model(imgs, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)
 
             psnr_val = psnr(recs, imgs, data_range=1.0)
             ssim_val = ssim(recs, imgs, data_range=1.0)
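The psnr/ssim calls above are consistent with torchmetrics' functional API; assuming that is the import used in this file (an assumption, not confirmed by the diff), the metric computation works like this:

# Standalone sketch of the metric calls, assuming torchmetrics.
import torch
from torchmetrics.functional import (
    peak_signal_noise_ratio as psnr,
    structural_similarity_index_measure as ssim,
)

recs = torch.rand(4, 3, 256, 256)  # reconstructions in [0, 1] (assumed shape)
imgs = torch.rand(4, 3, 256, 256)  # ground-truth images in [0, 1]

psnr_val = psnr(recs, imgs, data_range=1.0)  # scalar tensor, in dB
ssim_val = ssim(recs, imgs, data_range=1.0)  # scalar tensor, in [0, 1]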