
Commit

updates to trainer
liopeer committed Sep 21, 2023
1 parent 8e8021e commit fe19f47
Showing 2 changed files with 45 additions and 25 deletions.
52 changes: 30 additions & 22 deletions diffusion_models/utils/trainer.py
@@ -11,7 +11,7 @@
from typing import Callable, Literal
import wandb

class DiscriminativeTrainer:
class Trainer:
"""Trainer Class that trains 1 model instance on 1 device."""
def __init__(
self,
@@ -74,31 +74,18 @@ def _setup_model(self, model):
def _setup_dataloader(self, dataset):
return DataLoader(dataset, batch_size=self.batch_size, pin_memory=True, shuffle=False, sampler=DistributedSampler(dataset))

def _run_batch(self, source, targets):
self.optimizer.zero_grad()
pred = self.model(source)
loss = self.loss_func(pred, targets)
loss.backward()
self.optimizer.step()
return loss.item()

def _run_epoch_nonCuda(self, epoch):
epoch_losses = []
time1 = time()
for source, targets in self.train_data:
source, targets = source.to(self.device_type), targets.to(self.device_type)
batch_loss = self._run_batch(source, targets)
epoch_losses.append(batch_loss)
if self.log_wandb:
wandb.log({"epoch": epoch, "loss": np.mean(epoch_losses), "epoch_time": time()-time1})
print(f"[{self.device_type}{self.gpu_id}] Epoch {epoch} | Batchsize: {self.batch_size} | Steps: {len(self.train_data)} | Loss: {np.mean(epoch_losses)} | Time: {time()-time1:.2f}s")
def _run_batch(self, data):
raise NotImplementedError("use dedicated subclass")

def _run_epoch(self, epoch):
epoch_losses = []
time1 = time()
for source, targets in self.train_data:
source, targets = source.to(self.gpu_id), targets.to(self.gpu_id)
batch_loss = self._run_batch(source, targets)
if self.device_type == "cuda":
    data = tuple(map(lambda x: x.to(self.gpu_id), (source, targets)))
else:
    data = tuple(map(lambda x: x.to(self.device_type), (source, targets)))
batch_loss = self._run_batch(data)
epoch_losses.append(batch_loss)
if self.log_wandb:
wandb.log({"epoch": epoch, "loss": np.mean(epoch_losses), "epoch_time": time()-time1})
@@ -126,4 +113,25 @@ def train(self, max_epochs: int):
else:
self._run_epoch(epoch)
if (self.gpu_id == 0) and (epoch % self.save_every == 0) and (epoch != 0):
self._save_checkpoint(epoch)
self._save_checkpoint(epoch)

class DiscriminativeTrainer(Trainer):
def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[..., Any], optimizer: Optimizer, gpu_id: int, batch_size: int, save_every: int, checkpoint_folder: str, device_type: Literal['cuda', 'mps', 'cpu'], log_wandb: bool = True) -> None:
super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)

def _run_batch(self, data):
source, targets = data
self.optimizer.zero_grad()
pred = self.model(source)
loss = self.loss_func(pred, targets)
loss.backward()
self.optimizer.step()
return loss.item()

class GenerativeTrainer(Trainer):
def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[..., Any], optimizer: Optimizer, gpu_id: int, batch_size: int, save_every: int, checkpoint_folder: str, device_type: Literal['cuda', 'mps', 'cpu'], log_wandb: bool = True) -> None:
super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)

def _run_batch(self, data):
self.optimizer.zero_grad()
raise NotImplementedError("not finished yet")
18 changes: 15 additions & 3 deletions tests/train_parallel.py
@@ -19,7 +19,7 @@
total_epochs = 5,
batch_size = 1000,
learning_rate = 0.001,
device_type = "cuda",
device_type = "cpu",
dataloader = MNISTTrainLoader,
architecture = MNISTEncoder,
out_classes = 10,
@@ -30,7 +30,8 @@
data_path = "/itet-stor/peerli/net_scratch",
checkpoint_folder = "/itet-stor/peerli/net_scratch/mnist_checkpoints",
save_every = 10,
loss_func = F.cross_entropy
loss_func = F.cross_entropy,
log_wandb = False
)

def load_train_objs(config):
@@ -43,7 +44,18 @@ def training(rank, world_size, config):
if rank == 0:
wandb.init(project="mnist_trials", config=config, save_code=True)
dataset, model, optimizer = load_train_objs(config)
trainer = DiscriminativeTrainer(model, dataset, config.loss_func, optimizer, rank, config.batch_size, config.save_every, config.checkpoint_folder, config.device_type)
trainer = DiscriminativeTrainer(
model,
dataset,
config.loss_func,
optimizer,
rank,
config.batch_size,
config.save_every,
config.checkpoint_folder,
config.device_type,
config.log_wandb
)
trainer.train(config.total_epochs)

if __name__ == "__main__":
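
Note: the diff truncates tests/train_parallel.py at the `if __name__ == "__main__":` guard, so the launcher itself is not shown. For context, a `training(rank, world_size, config)` worker of this shape is usually started with `torch.multiprocessing.spawn`; the snippet below is a generic sketch under that assumption, not the file's actual entry point.

import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    # one process per visible GPU, falling back to a single CPU process
    world_size = torch.cuda.device_count() if torch.cuda.is_available() else 1
    # spawn passes the process rank as the first argument to training()
    mp.spawn(training, args=(world_size, config), nprocs=world_size)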
