Commit
Showing 10 changed files with 247 additions and 82 deletions.
.gitignore
@@ -5,6 +5,9 @@ docs/token.txt
 # datasets
 tests/data

+# logging
+tests/wandb
+
 # sphinx build
 docs/source/_autosummary

models/mnist_enc.py
@@ -0,0 +1,21 @@ | ||
import torch.nn as nn | ||
import torch | ||
|
||
class MNISTEncoder(nn.Module): | ||
def __init__(self, out_classes=10, kernel_size=3): | ||
super().__init__() | ||
self.kernel_size = kernel_size | ||
self.out_classes = out_classes | ||
channels = [2**i for i in range(5)] | ||
self.encoder = [] | ||
for i in range(4): | ||
self.encoder.append(nn.Conv2d(channels[i], channels[i+1], kernel_size=kernel_size, padding="same")) | ||
self.encoder.append(nn.BatchNorm2d(channels[i+1])) | ||
self.encoder.append(nn.ReLU()) | ||
self.encoder.append(nn.MaxPool2d(2)) | ||
self.conv = nn.Sequential(*self.encoder) | ||
self.fc = nn.Linear(16, self.out_classes) | ||
|
||
def forward(self, x): | ||
x = self.conv(x) | ||
return self.fc(x.squeeze()) |
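
As a quick sanity check (not part of the commit), the encoder can be probed with a random batch; four stride-2 poolings reduce 28x28 to 1x1, so the classifier head sees exactly 16 features:

import torch
from models.mnist_enc import MNISTEncoder

model = MNISTEncoder(out_classes=10, kernel_size=3)
x = torch.randn(8, 1, 28, 28)  # a fake batch of MNIST-sized images
logits = model(x)
print(logits.shape)  # torch.Size([8, 10])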
utils/dataloaders.py
@@ -0,0 +1,16 @@ | ||
from typing import Callable, Optional | ||
from torchvision.datasets import MNIST | ||
from torchvision.transforms import Compose, ToTensor, Normalize | ||
from typing import Any | ||
|
||
class MNISTTrainLoader(MNIST): | ||
def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None: | ||
transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) | ||
download = True | ||
super().__init__(root, train, transform, target_transform, download) | ||
|
||
class MNISTTestLoader(MNIST): | ||
def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None: | ||
transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) | ||
download = False | ||
super().__init__(root, train, transform, target_transform, download) |
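
A minimal usage sketch (not part of the commit; the import path follows the test script at the end of this diff). Since the train loader hard-codes normalization and the download flag, wiring it into a DataLoader is all that is needed:

from torch.utils.data import DataLoader
from utils.dataloaders import MNISTTrainLoader

train_set = MNISTTrainLoader("./data")  # downloads MNIST on first use
loader = DataLoader(train_set, batch_size=64, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])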
utils/helpers.py
@@ -0,0 +1,5 @@
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    # note: a missing key returns None (dict.get semantics) rather than raising AttributeError
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
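
A short illustration (hypothetical values) of what the helper buys over a plain dict:

cfg = dotdict(batch_size=64, learning_rate=1e-3)
print(cfg.batch_size)     # 64, equivalent to cfg["batch_size"]
cfg.device_type = "cpu"   # attribute writes go straight into the dict
print(cfg.missing_key)    # None rather than an AttributeError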
utils/mp_setup.py
@@ -0,0 +1,19 @@ | ||
import os | ||
from typing import Any | ||
import torch | ||
from torch.distributed import init_process_group, destroy_process_group | ||
|
||
class DDP_Proc_Group: | ||
def __init__(self, function) -> None: | ||
self.function = function | ||
|
||
def __call__(self, *args, **kwargs) -> None: | ||
self._ddp_setup(args[0], args[1]) | ||
self.function(*args, **kwargs) | ||
destroy_process_group() | ||
|
||
def _ddp_setup(self, rank, world_size): | ||
os.environ["MASTER_ADDR"] = "localhost" | ||
os.environ["MASTER_PORT"] = "12355" | ||
init_process_group(backend="nccl", rank=rank, world_size=world_size) | ||
torch.cuda.set_device(rank) |
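
A hedged usage sketch of the wrapper (it mirrors the test script at the end of this diff and assumes a machine with NVIDIA GPUs, since the process group uses the nccl backend):

import torch
import torch.multiprocessing as mp
from utils.mp_setup import DDP_Proc_Group

def demo(rank, world_size, message):
    # runs once per GPU, after init_process_group and before destroy_process_group
    print(f"[rank {rank}/{world_size}] {message}")

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # mp.spawn prepends the process rank, matching __call__'s (rank, world_size, ...) convention
    mp.spawn(DDP_Proc_Group(demo), args=(world_size, "hello"), nprocs=world_size)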
utils/trainer.py
@@ -0,0 +1,129 @@
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.optim import Optimizer
import numpy as np
from time import time
import wandb
from typing import Callable, Literal


class DiscriminativeTrainer:
    """Trainer class that trains one model instance on one device."""

    def __init__(
        self,
        model: nn.Module,
        train_data: Dataset,
        loss_func: Callable,
        optimizer: Optimizer,
        gpu_id: int,
        batch_size: int,
        save_every: int,
        checkpoint_folder: str,
        device_type: Literal["cuda", "mps", "cpu"],
        log_wandb: bool = True
    ) -> None:
        """Constructor of Trainer class.

        Parameters
        ----------
        model
            instance of nn.Module to be copied to a GPU
        train_data
            Dataset instance
        loss_func
            criterion to determine the loss
        optimizer
            torch.optim instance with model.parameters and learning rate passed
        gpu_id
            int in range [0, num_GPUs); ignored on non-CUDA devices
        batch_size
            number of samples per batch (per device)
        save_every
            how often (in epochs) to save a model checkpoint
        checkpoint_folder
            directory to save checkpoints to
        device_type
            pass "mps" or "cpu" when not training on a CUDA-capable device
        log_wandb
            whether to log to wandb; requires that wandb.init has already been called
        """
        self.device_type = device_type
        self.gpu_id = gpu_id
        self.batch_size = batch_size
        if device_type != "cuda":
            self.gpu_id = 0
            self.model = model.to(torch.device(device_type))
            self.train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        else:
            self.model = self._setup_model(model)
            self.train_data = self._setup_dataloader(train_data)
        self.loss_func = loss_func
        self.optimizer = optimizer
        self.save_every = save_every
        self.checkpoint_folder = checkpoint_folder
        self.log_wandb = log_wandb
        if log_wandb:
            wandb.watch(self.model, log="all", log_freq=save_every)

    def _setup_model(self, model):
        model = model.to(self.gpu_id)
        return DDP(model, device_ids=[self.gpu_id])

    def _setup_dataloader(self, dataset):
        # DistributedSampler shards the dataset across processes; shuffle must be False here
        return DataLoader(dataset, batch_size=self.batch_size, pin_memory=True, shuffle=False, sampler=DistributedSampler(dataset))

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        pred = self.model(source)
        loss = self.loss_func(pred, targets)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _run_epoch_nonCuda(self, epoch):
        epoch_losses = []
        time1 = time()
        for source, targets in self.train_data:
            source, targets = source.to(self.device_type), targets.to(self.device_type)
            batch_loss = self._run_batch(source, targets)
            epoch_losses.append(batch_loss)
        if self.log_wandb:
            wandb.log({"epoch": epoch, "loss": np.mean(epoch_losses)})
        print(f"[{self.device_type}{self.gpu_id}] Epoch {epoch} | Batchsize: {self.batch_size} | Steps: {len(self.train_data)} | Loss: {np.mean(epoch_losses)} | Time: {time()-time1:.2f}s")

    def _run_epoch(self, epoch):
        epoch_losses = []
        time1 = time()
        # re-seed the sampler so each epoch sees a different shard ordering
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source, targets = source.to(self.gpu_id), targets.to(self.gpu_id)
            batch_loss = self._run_batch(source, targets)
            epoch_losses.append(batch_loss)
        if self.log_wandb:
            wandb.log({"epoch": epoch, "loss": np.mean(epoch_losses)})
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {self.batch_size} | Steps: {len(self.train_data)} | Loss: {np.mean(epoch_losses)} | Time: {time()-time1:.2f}s")

    def _save_checkpoint(self, epoch):
        # unwrap the DDP container so checkpoints load without the "module." key prefix
        model = self.model.module if isinstance(self.model, DDP) else self.model
        ckp = model.state_dict()
        if not os.path.isdir(self.checkpoint_folder):
            os.makedirs(self.checkpoint_folder)
        path = os.path.join(self.checkpoint_folder, f"checkpoint{epoch}.pt")
        torch.save(ckp, path)
        print(f"Epoch {epoch} | Training checkpoint saved at {path}")

    def train(self, max_epochs: int):
        """Train method of Trainer class.

        Parameters
        ----------
        max_epochs
            how many epochs to train the model
        """
        for epoch in range(max_epochs):
            if self.device_type != "cuda":
                self._run_epoch_nonCuda(epoch)
            else:
                self._run_epoch(epoch)
            if (self.gpu_id == 0) and (epoch % self.save_every == 0) and (epoch != 0):
                self._save_checkpoint(epoch)
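
For completeness, a minimal sketch of restoring one of these checkpoints; the file name is an assumption based on the save_every=10 and checkpoint_folder values in the test script below:

import torch
from models.mnist_enc import MNISTEncoder

model = MNISTEncoder(out_classes=10, kernel_size=3)
state = torch.load("./data/checkpoints/checkpoint10.pt", map_location="cpu")  # assumed path
model.load_state_dict(state)  # keys match because _save_checkpoint stores the unwrapped model
model.eval()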
tests/context.py
@@ -1,2 +1,2 @@
 import sys
-sys.path.append("..")
+sys.path.append("../diffusion_models")
This file was deleted.
MNIST training script (tests/)
@@ -0,0 +1,52 @@ | ||
import context | ||
from torchvision.transforms import ToTensor, Compose, Normalize | ||
from torch.utils.data import DataLoader | ||
import torch | ||
import torch.nn as nn | ||
from models.mnist_enc import MNISTEncoder | ||
import numpy as np | ||
from time import time | ||
from utils.trainer import DiscriminativeTrainer | ||
import torch.multiprocessing as mp | ||
import os | ||
from utils.mp_setup import DDP_Proc_Group | ||
from utils.dataloaders import MNISTTrainLoader | ||
from utils.helpers import dotdict | ||
import wandb | ||
import torch.nn.functional as F | ||
|
||
config = dotdict( | ||
total_epochs = 100, | ||
batch_size = 1000, | ||
learning_rate = 0.001, | ||
device_type = "cpu", | ||
dataloader = MNISTTrainLoader, | ||
architecture = MNISTEncoder, | ||
out_classes = 10, | ||
optimizer = torch.optim.Adam, | ||
kernel_size = 3, | ||
data_path = os.path.abspath("./data"), | ||
checkpoint_folder = os.path.abspath(os.path.join("./data/checkpoints")), | ||
save_every = 10, | ||
loss_func = F.cross_entropy | ||
) | ||
|
||
def load_train_objs(config): | ||
train_set = config.dataloader(config.data_path) | ||
model = config.architecture(config.out_classes, config.kernel_size) | ||
optimizer = config.optimizer(model.parameters(), lr=config.learning_rate) | ||
return train_set, model, optimizer | ||
|
||
def training(rank, world_size, config): | ||
dataset, model, optimizer = load_train_objs(config) | ||
trainer = DiscriminativeTrainer(model, dataset, config.loss_func, optimizer, rank, config.batch_size, config.save_every, config.checkpoint_folder, config.device_type) | ||
trainer.train(config.total_epochs) | ||
|
||
if __name__ == "__main__": | ||
with wandb.init(project="mnist_trials", config=config, save_code=True): | ||
if config.device_type == "cuda": | ||
world_size = torch.cuda.device_count() | ||
print("Device Count:", world_size) | ||
mp.spawn(DDP_Proc_Group(training), args=(world_size, config), nprocs=world_size) | ||
else: | ||
training(0, 0, config) |
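
The commit defines MNISTTestLoader but never exercises it; a possible evaluation sketch (the checkpoint name is an assumption, and the test split must already be on disk because MNISTTestLoader does not download):

import context
import torch
from torch.utils.data import DataLoader
from models.mnist_enc import MNISTEncoder
from utils.dataloaders import MNISTTestLoader

model = MNISTEncoder(out_classes=10, kernel_size=3)
model.load_state_dict(torch.load("./data/checkpoints/checkpoint90.pt", map_location="cpu"))  # assumed checkpoint
model.eval()

test_loader = DataLoader(MNISTTestLoader("./data"), batch_size=1000)
correct = total = 0
with torch.no_grad():
    for images, labels in test_loader:
        pred = model(images).argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += labels.numel()
print(f"test accuracy: {correct / total:.4f}")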