
Commit fe19f47: updates to trainer

1 parent 8e8021e

2 files changed (+45 / -25 lines)

diffusion_models/utils/trainer.py (30 additions, 22 deletions)
@@ -11,7 +11,7 @@
 from typing import Callable, Literal
 import wandb
 
-class DiscriminativeTrainer:
+class Trainer:
     """Trainer Class that trains 1 model instance on 1 device."""
     def __init__(
         self,
@@ -74,31 +74,18 @@ def _setup_model(self, model):
     def _setup_dataloader(self, dataset):
         return DataLoader(dataset, batch_size=self.batch_size, pin_memory=True, shuffle=False, sampler=DistributedSampler(dataset))
 
-    def _run_batch(self, source, targets):
-        self.optimizer.zero_grad()
-        pred = self.model(source)
-        loss = self.loss_func(pred, targets)
-        loss.backward()
-        self.optimizer.step()
-        return loss.item()
-
-    def _run_epoch_nonCuda(self, epoch):
-        epoch_losses = []
-        time1 = time()
-        for source, targets in self.train_data:
-            source, targets = source.to(self.device_type), targets.to(self.device_type)
-            batch_loss = self._run_batch(source, targets)
-            epoch_losses.append(batch_loss)
-        if self.log_wandb:
-            wandb.log({"epoch": epoch, "loss": np.mean(epoch_losses), "epoch_time": time()-time1})
-        print(f"[{self.device_type}{self.gpu_id}] Epoch {epoch} | Batchsize: {self.batch_size} | Steps: {len(self.train_data)} | Loss: {np.mean(epoch_losses)} | Time: {time()-time1:.2f}s")
+    def _run_batch(self, data):
+        raise NotImplementedError("use dedicated subclass")
 
     def _run_epoch(self, epoch):
         epoch_losses = []
         time1 = time()
         for source, targets in self.train_data:
-            source, targets = source.to(self.gpu_id), targets.to(self.gpu_id)
-            batch_loss = self._run_batch(source, targets)
+            if self.device_type == "cuda":
+                data = map(lambda x: x.to(self.gpu_id), (source, targets))  # move both tensors to this rank's GPU
+            else:
+                data = map(lambda x: x.to(self.device_type), (source, targets))  # e.g. "mps" or "cpu"
+            batch_loss = self._run_batch(data)
             epoch_losses.append(batch_loss)
         if self.log_wandb:
             wandb.log({"epoch": epoch, "loss": np.mean(epoch_losses), "epoch_time": time()-time1})
@@ -126,4 +113,25 @@ def train(self, max_epochs: int):
             else:
                 self._run_epoch(epoch)
             if (self.gpu_id == 0) and (epoch % self.save_every == 0) and (epoch != 0):
-                self._save_checkpoint(epoch)
+                self._save_checkpoint(epoch)
+
+class DiscriminativeTrainer(Trainer):
+    def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[..., Any], optimizer: Optimizer, gpu_id: int, batch_size: int, save_every: int, checkpoint_folder: str, device_type: Literal['cuda', 'mps', 'cpu'], log_wandb: bool = True) -> None:
+        super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)
+
+    def _run_batch(self, data):
+        source, targets = data
+        self.optimizer.zero_grad()
+        pred = self.model(source)
+        loss = self.loss_func(pred, targets)
+        loss.backward()
+        self.optimizer.step()
+        return loss.item()
+
+class GenerativeTrainer(Trainer):
+    def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[..., Any], optimizer: Optimizer, gpu_id: int, batch_size: int, save_every: int, checkpoint_folder: str, device_type: Literal['cuda', 'mps', 'cpu'], log_wandb: bool = True) -> None:
+        super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)
+
+    def _run_batch(self, data):
+        self.optimizer.zero_grad()
+        raise NotImplementedError("not finished yet")
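
With this refactor, Trainer._run_batch becomes a template-method hook: the base class owns the epoch loop, device placement, wandb logging, and checkpointing, while each subclass supplies only the per-batch optimization step. A minimal sketch of how a new objective would plug in (the MSETrainer name and the F.mse_loss choice are illustrative, not part of this commit):

import torch.nn.functional as F
from diffusion_models.utils.trainer import Trainer  # assumed import path

class MSETrainer(Trainer):
    """Hypothetical subclass: trains with a mean-squared-error objective."""
    def _run_batch(self, data):
        source, targets = data            # tensors already moved to the device by _run_epoch
        self.optimizer.zero_grad()
        pred = self.model(source)
        loss = F.mse_loss(pred, targets)  # only the loss step is subclass-specific
        loss.backward()
        self.optimizer.step()
        return loss.item()                # _run_epoch averages these values for logging

Keeping _run_batch abstract in the base class, rather than defaulting to the discriminative step, makes the contract explicit and lets GenerativeTrainer diverge later without touching the shared loop.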

tests/train_parallel.py (15 additions, 3 deletions)
@@ -19,7 +19,7 @@
     total_epochs = 5,
     batch_size = 1000,
     learning_rate = 0.001,
-    device_type = "cuda",
+    device_type = "cpu",
     dataloader = MNISTTrainLoader,
     architecture = MNISTEncoder,
     out_classes = 10,
@@ -30,7 +30,8 @@
     data_path = "/itet-stor/peerli/net_scratch",
     checkpoint_folder = "/itet-stor/peerli/net_scratch/mnist_checkpoints",
     save_every = 10,
-    loss_func = F.cross_entropy
+    loss_func = F.cross_entropy,
+    log_wandb = False
 )
 
 def load_train_objs(config):
@@ -43,7 +44,18 @@ def training(rank, world_size, config):
     if rank == 0:
         wandb.init(project="mnist_trials", config=config, save_code=True)
     dataset, model, optimizer = load_train_objs(config)
-    trainer = DiscriminativeTrainer(model, dataset, config.loss_func, optimizer, rank, config.batch_size, config.save_every, config.checkpoint_folder, config.device_type)
+    trainer = DiscriminativeTrainer(
+        model,
+        dataset,
+        config.loss_func,
+        optimizer,
+        rank,
+        config.batch_size,
+        config.save_every,
+        config.checkpoint_folder,
+        config.device_type,
+        config.log_wandb
+    )
     trainer.train(config.total_epochs)
 
 if __name__ == "__main__":
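
The training(rank, world_size, config) signature matches the convention of torch.multiprocessing.spawn, which passes the process index as the first argument and the args tuple as the rest. The __main__ block itself is not visible in this diff; a launcher along these lines would fit, but the world_size value and the process-group setup shown here are assumptions:

import os
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group

def spawn_worker(rank, world_size, config):
    # Assumed setup; the actual __main__ block is not shown in this diff.
    # DistributedSampler (used in Trainer._setup_dataloader) needs an
    # initialized process group, whether it happens here or inside training().
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    init_process_group(backend="gloo", rank=rank, world_size=world_size)  # "gloo" suits the cpu device_type above
    training(rank, world_size, config)
    destroy_process_group()

if __name__ == "__main__":
    world_size = 2  # illustrative process count
    mp.spawn(spawn_worker, args=(world_size, config), nprocs=world_size)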
