Commit dd9d3e3

added wandb

liopeer committed Sep 21, 2023
1 parent 4ce7505 commit dd9d3e3
Showing 10 changed files with 247 additions and 82 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -5,6 +5,9 @@ docs/token.txt
# datasets
tests/data

# logging
tests/wandb

# sphinx build
docs/source/_autosummary

21 changes: 21 additions & 0 deletions diffusion_models/models/mnist_enc.py
@@ -0,0 +1,21 @@
import torch.nn as nn
import torch

class MNISTEncoder(nn.Module):
    def __init__(self, out_classes=10, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        self.out_classes = out_classes
        # 4 conv blocks: 1 -> 2 -> 4 -> 8 -> 16 channels, each followed by BatchNorm, ReLU and 2x2 max pooling
        channels = [2**i for i in range(5)]
        self.encoder = []
        for i in range(4):
            self.encoder.append(nn.Conv2d(channels[i], channels[i+1], kernel_size=kernel_size, padding="same"))
            self.encoder.append(nn.BatchNorm2d(channels[i+1]))
            self.encoder.append(nn.ReLU())
            self.encoder.append(nn.MaxPool2d(2))
        self.conv = nn.Sequential(*self.encoder)
        # four poolings reduce a 28x28 MNIST image to a 1x1 feature map with 16 channels
        self.fc = nn.Linear(16, self.out_classes)

    def forward(self, x):
        x = self.conv(x)
        # squeeze the 1x1 spatial dimensions before the classification head
        return self.fc(x.squeeze())
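
A quick sanity check of the encoder's output shape — a minimal sketch, assuming the package is importable the same way tests/train_parallel.py imports it:

import torch
from models.mnist_enc import MNISTEncoder

model = MNISTEncoder(out_classes=10, kernel_size=3)
x = torch.randn(8, 1, 28, 28)   # a batch of 8 MNIST-sized images
logits = model(x)
print(logits.shape)             # torch.Size([8, 10])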
16 changes: 16 additions & 0 deletions diffusion_models/utils/dataloaders.py
@@ -0,0 +1,16 @@
from typing import Any, Callable
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

class MNISTTrainLoader(MNIST):
    def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
        # always apply the standard MNIST normalization and always download the training split
        transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
        download = True
        super().__init__(root, train, transform, target_transform, download)

class MNISTTestLoader(MNIST):
    def __init__(self, root: str, train: bool = False, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
        # test split: same normalization, but no re-download
        transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
        download = False
        super().__init__(root, train, transform, target_transform, download)
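
Since the subclasses hard-code the MNIST normalization and the download flag, usage only needs a data root; a minimal sketch (the path is an assumption):

from torch.utils.data import DataLoader
from utils.dataloaders import MNISTTrainLoader

train_set = MNISTTrainLoader("./data")                      # downloads MNIST and applies ToTensor + Normalize
loader = DataLoader(train_set, batch_size=64, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels.shape)                           # torch.Size([64, 1, 28, 28]) torch.Size([64])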
5 changes: 5 additions & 0 deletions diffusion_models/utils/helpers.py
@@ -0,0 +1,5 @@
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
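
A minimal sketch of what the helper provides — attribute-style access to plain dict keys:

cfg = dotdict(batch_size=32, learning_rate=1e-3)
print(cfg.batch_size)       # 32, equivalent to cfg["batch_size"]
cfg.device_type = "cpu"     # sets the "device_type" key
print(cfg.missing_key)      # None rather than an AttributeError, since __getattr__ is dict.get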
19 changes: 19 additions & 0 deletions diffusion_models/utils/mp_setup.py
@@ -0,0 +1,19 @@
import os
import torch
from torch.distributed import init_process_group, destroy_process_group

class DDP_Proc_Group:
    """Wraps a training function so each spawned process sets up and tears down its DDP process group."""
    def __init__(self, function) -> None:
        self.function = function

    def __call__(self, *args, **kwargs) -> None:
        # mp.spawn calls this with (rank, world_size, ...) as positional arguments
        self._ddp_setup(args[0], args[1])
        self.function(*args, **kwargs)
        destroy_process_group()

    def _ddp_setup(self, rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        init_process_group(backend="nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
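
The wrapper is meant to be handed to torch.multiprocessing.spawn, which prepends the process rank to the argument tuple; a minimal sketch mirroring tests/train_parallel.py:

import torch
import torch.multiprocessing as mp
from utils.mp_setup import DDP_Proc_Group

def training(rank, world_size):
    print(f"process {rank} of {world_size} initialised")   # the DDP process group is already set up here

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # the wrapper reads rank and world_size from the first two positional arguments
    mp.spawn(DDP_Proc_Group(training), args=(world_size,), nprocs=world_size)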
129 changes: 129 additions & 0 deletions diffusion_models/utils/trainer.py
@@ -0,0 +1,129 @@
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.optim import Optimizer
import numpy as np
from time import time
from typing import Callable, Literal
import wandb

class DiscriminativeTrainer:
    """Trainer Class that trains 1 model instance on 1 device."""
    def __init__(
        self,
        model: nn.Module,
        train_data: Dataset,
        loss_func: Callable,
        optimizer: Optimizer,
        gpu_id: int,
        batch_size: int,
        save_every: int,
        checkpoint_folder: str,
        device_type: Literal["cuda","mps","cpu"],
        log_wandb: bool=True
    ) -> None:
        """Constructor of Trainer Class.

        Parameters
        ----------
        model
            instance of nn.Module to be copied to a GPU
        train_data
            Dataset instance
        loss_func
            criterion to determine the loss
        optimizer
            torch.optim instance with model.parameters and learning rate passed
        gpu_id
            int in range [0, num_GPUs)
        batch_size
            batch size per device
        save_every
            how often (in epochs) to save a model checkpoint
        checkpoint_folder
            where to save checkpoints to
        device_type
            device to train on if no CUDA-capable device is available
        log_wandb
            whether to log to wandb; requires that wandb.init has already been called
        """
        self.device_type = device_type
        self.gpu_id = gpu_id
        self.batch_size = batch_size
        if device_type != "cuda":
            self.gpu_id = 0
            self.model = model.to(torch.device(device_type))
            self.train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        else:
            self.model = self._setup_model(model)
            self.train_data = self._setup_dataloader(train_data)
        self.loss_func = loss_func
        self.optimizer = optimizer
        self.save_every = save_every
        self.checkpoint_folder = checkpoint_folder
        self.log_wandb = log_wandb
        if log_wandb:
            wandb.watch(self.model, log="all", log_freq=save_every)

    def _setup_model(self, model):
        model = model.to(self.gpu_id)
        return DDP(model, device_ids=[self.gpu_id])

    def _setup_dataloader(self, dataset):
        return DataLoader(dataset, batch_size=self.batch_size, pin_memory=True, shuffle=False, sampler=DistributedSampler(dataset))

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        pred = self.model(source)
        loss = self.loss_func(pred, targets)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _run_epoch_nonCuda(self, epoch):
        epoch_losses = []
        time1 = time()
        for source, targets in self.train_data:
            source, targets = source.to(self.device_type), targets.to(self.device_type)
            batch_loss = self._run_batch(source, targets)
            epoch_losses.append(batch_loss)
        if self.log_wandb:
            wandb.log({"epoch": epoch, "loss": np.mean(epoch_losses)})
        print(f"[{self.device_type}{self.gpu_id}] Epoch {epoch} | Batchsize: {self.batch_size} | Steps: {len(self.train_data)} | Loss: {np.mean(epoch_losses)} | Time: {time()-time1:.2f}s")

    def _run_epoch(self, epoch):
        epoch_losses = []
        time1 = time()
        for source, targets in self.train_data:
            source, targets = source.to(self.gpu_id), targets.to(self.gpu_id)
            batch_loss = self._run_batch(source, targets)
            epoch_losses.append(batch_loss)
        if self.log_wandb:
            wandb.log({"epoch": epoch, "loss": np.mean(epoch_losses)})
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {self.batch_size} | Steps: {len(self.train_data)} | Loss: {np.mean(epoch_losses)} | Time: {time()-time1:.2f}s")

    def _save_checkpoint(self, epoch):
        ckp = self.model.state_dict()
        if not os.path.isdir(self.checkpoint_folder):
            os.makedirs(self.checkpoint_folder)
        path = os.path.join(self.checkpoint_folder, f"checkpoint{epoch}.pt")
        torch.save(ckp, path)
        print(f"Epoch {epoch} | Training checkpoint saved at {path}")

    def train(self, max_epochs: int):
        """Train method of Trainer class.

        Parameters
        ----------
        max_epochs
            how many epochs to train the model
        """
        for epoch in range(max_epochs):
            if self.device_type != "cuda":
                self._run_epoch_nonCuda(epoch)
            else:
                self._run_epoch(epoch)
            if (self.gpu_id == 0) and (epoch % self.save_every == 0) and (epoch != 0):
                self._save_checkpoint(epoch)
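
Because the constructor calls wandb.watch when log_wandb is True, the trainer assumes an active wandb run; a minimal single-device CPU sketch (project name, paths and hyperparameters are assumptions, mirroring tests/train_parallel.py):

import torch
import torch.nn.functional as F
import wandb
from models.mnist_enc import MNISTEncoder
from utils.dataloaders import MNISTTrainLoader
from utils.trainer import DiscriminativeTrainer

model = MNISTEncoder()
train_set = MNISTTrainLoader("./data")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

with wandb.init(project="mnist_trials"):
    trainer = DiscriminativeTrainer(
        model, train_set, F.cross_entropy, optimizer,
        gpu_id=0, batch_size=64, save_every=10,
        checkpoint_folder="./data/checkpoints", device_type="cpu",
    )
    trainer.train(max_epochs=5)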
1 change: 1 addition & 0 deletions examples/vae.py
@@ -1,3 +1,4 @@
"""Incomplete."""
import context
import torch
from diffusion_models.models.vae import VariationalAutoencoder, ResNet18Encoder, ResNetDecoderBlock
2 changes: 1 addition & 1 deletion tests/context.py
@@ -1,2 +1,2 @@
 import sys
-sys.path.append("..")
+sys.path.append("../diffusion_models")
81 changes: 0 additions & 81 deletions tests/test1.ipynb

This file was deleted.

52 changes: 52 additions & 0 deletions tests/train_parallel.py
@@ -0,0 +1,52 @@
import context
from torchvision.transforms import ToTensor, Compose, Normalize
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from models.mnist_enc import MNISTEncoder
import numpy as np
from time import time
from utils.trainer import DiscriminativeTrainer
import torch.multiprocessing as mp
import os
from utils.mp_setup import DDP_Proc_Group
from utils.dataloaders import MNISTTrainLoader
from utils.helpers import dotdict
import wandb
import torch.nn.functional as F

config = dotdict(
    total_epochs = 100,
    batch_size = 1000,
    learning_rate = 0.001,
    device_type = "cpu",
    dataloader = MNISTTrainLoader,
    architecture = MNISTEncoder,
    out_classes = 10,
    optimizer = torch.optim.Adam,
    kernel_size = 3,
    data_path = os.path.abspath("./data"),
    checkpoint_folder = os.path.abspath(os.path.join("./data/checkpoints")),
    save_every = 10,
    loss_func = F.cross_entropy
)

def load_train_objs(config):
    train_set = config.dataloader(config.data_path)
    model = config.architecture(config.out_classes, config.kernel_size)
    optimizer = config.optimizer(model.parameters(), lr=config.learning_rate)
    return train_set, model, optimizer

def training(rank, world_size, config):
    dataset, model, optimizer = load_train_objs(config)
    trainer = DiscriminativeTrainer(model, dataset, config.loss_func, optimizer, rank, config.batch_size, config.save_every, config.checkpoint_folder, config.device_type)
    trainer.train(config.total_epochs)

if __name__ == "__main__":
    with wandb.init(project="mnist_trials", config=config, save_code=True):
        if config.device_type == "cuda":
            world_size = torch.cuda.device_count()
            print("Device Count:", world_size)
            mp.spawn(DDP_Proc_Group(training), args=(world_size, config), nprocs=world_size)
        else:
            training(0, 0, config)
