From 353641ac654f365ea4beae303dc9bdaa1e68a38f Mon Sep 17 00:00:00 2001
From: TySam& B
Date: Sun, 12 Feb 2023 21:11:14 -0500
Subject: [PATCH] Commit changes for v0.4.0 release, basically finalized. See
 v0.4.0 patch notes for more details.

---
 main.py | 143 ++++++++++++++++++++++++--------------------------------
 1 file changed, 62 insertions(+), 81 deletions(-)

diff --git a/main.py b/main.py
index c26ba51..d2a2959 100644
--- a/main.py
+++ b/main.py
@@ -25,7 +25,7 @@
 # You can run 'sed -i.bak '/\#\#/d' ./main.py' to remove the teaching comments if they are in the way of your work. <3
 
 # This can go either way in terms of actually being helpful when it comes to execution speed.
-# torch.backends.cudnn.benchmark = True
+#torch.backends.cudnn.benchmark = True
 
 # This code was built from the ground up to be directly hackable and to support rapid experimentation, which is something you might see
 # reflected in what would otherwise seem to be odd design decisions. It also means that maybe some cleaning up is required before moving
@@ -34,40 +34,41 @@
 # project! :)
 
 
-# This is for testing that certain changes don't exceed X% portion of the reference GPU (here an A100)
+# This is for testing that certain changes don't exceed some X% portion of the reference GPU (here an A100)
 # so we can help reduce a possibility that future releases don't take away the accessibility of this codebase.
-#torch.cuda.set_per_process_memory_fraction(fraction=8./40., device=0) ## 40. GB is the maximum memory of the base A100 GPU
+#torch.cuda.set_per_process_memory_fraction(fraction=6.5/40., device=0) ## 40. GB is the maximum memory of the base A100 GPU
 
 # set global defaults (in this particular file) for convolutions
 default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False}
 
-batchsize = 512
+batchsize = 1024
 bias_scaler = 32
-# To replicate the ~95.77% accuracy in 188 seconds runs, simply change the base_depth from 64->128 and the num_epochs from 10->80
+# To replicate the ~95.84%-accuracy-in-172-seconds runs, you can change the base_depth from 64->128, num_epochs from 10->90, ['ema'] epochs 9->78, and cutout 0->11
 hyp = {
     'opt': {
-        'bias_lr':        1.15 * 1.35 * 1. * bias_scaler/batchsize, # TODO: How we're expressing this information feels somewhat clunky, is there maybe a better way to do this? :'))))
-        'non_bias_lr':    1.15 * 1.35 * 1. / batchsize,
-        'bias_decay':     .85 * 4.8e-4 * batchsize/bias_scaler,
-        'non_bias_decay': .85 * 4.8e-4 * batchsize,
+        'bias_lr':        2.1 * bias_scaler/512, # TODO: How we're expressing this information feels somewhat clunky, is there maybe a better way to do this? :'))))
+        'non_bias_lr':    2.1 / 512,
+        'bias_decay':     6.45e-4 * batchsize/bias_scaler,
+        'non_bias_decay': 6.45e-4 * batchsize,
         'scaling_factor': 1./10,
-        'percent_start': .2,
+        'percent_start': .15,
+        'loss_scale_scaler': 4., # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
     },
     'net': {
         'whitening': {
             'kernel_size': 2,
             'num_examples': 50000,
         },
-        'batch_norm_momentum': .8,
+        'batch_norm_momentum': .8, # Equivalent roughly to updating entirely every step, as momentum for batchnorm is represented in a different way (1 - momentum) due to a quirk of the original paper... ;'((((
         'cutout_size': 0,
-        'pad_amount': 3,
+        'pad_amount': 2,
         'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly
     },
     'misc': {
         'ema': {
-            'epochs': 2,
-            'decay_base': .986,
-            'every_n_steps': 2,
+            'epochs': 9,
+            'decay_base': .98,
+            'every_n_steps': 5,
         },
         'train_epochs': 10,
         'device': 'cuda',
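(Aside for readers skimming the hyp['opt'] block above: the learning rates are written against a fixed 512 in the denominator, which was the previous default batchsize, while the weight decay values scale with the actual batchsize, and bias parameters get a learning rate boosted by bias_scaler with a correspondingly reduced decay. Below is a minimal, hedged sketch of how values like these could be wired into two optimizer parameter groups; the tiny network, the name-based bias split, and the SGD momentum setting are illustrative assumptions, not the exact setup in main.py.)

# Hedged sketch (not main.py itself): wiring the batchsize-relative values from
# the hyp['opt'] block above into two SGD parameter groups.
import torch
import torch.nn as nn

batchsize, bias_scaler = 1024, 32
bias_lr        = 2.1 * bias_scaler / 512           # biases: lr scaled up by bias_scaler
non_bias_lr    = 2.1 / 512                         # lr written against the old 512 batchsize reference
bias_decay     = 6.45e-4 * batchsize / bias_scaler # biases: correspondingly reduced decay
non_bias_decay = 6.45e-4 * batchsize               # decay scales with the actual batchsize

# Stand-in network for illustration only.
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding='same', bias=False),
                    nn.BatchNorm2d(8, momentum=.8),  # PyTorch: running = (1-momentum)*running + momentum*batch_stat,
                    nn.GELU())                       # so .8 means "mostly trust the current batch" (the quirk noted above)

bias_params     = [p for name, p in net.named_parameters() if 'bias' in name]
non_bias_params = [p for name, p in net.named_parameters() if 'bias' not in name]

opt = torch.optim.SGD([{'params': non_bias_params, 'lr': non_bias_lr, 'weight_decay': non_bias_decay},
                       {'params': bias_params,     'lr': bias_lr,     'weight_decay': bias_decay}],
                      lr=non_bias_lr, momentum=0.9, nesterov=True)  # momentum value here is illustrative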
@@ -80,10 +81,6 @@
 #############################################
 
 if not os.path.exists(hyp['misc']['data_location']):
-    cifar10_mean, cifar10_std = [
-        torch.tensor([0.4913997551666284, 0.48215855929893703, 0.4465309133731618], device=hyp['misc']['device']),
-        torch.tensor([0.24703225141799082, 0.24348516474564, 0.26158783926049628], device=hyp['misc']['device'])
-    ]
 
     transform = transforms.Compose([
         transforms.ToTensor()])
@@ -103,6 +100,8 @@
 train_dataset_gpu['images'], train_dataset_gpu['targets'] = [item.to(device=hyp['misc']['device'], non_blocking=True) for item in next(iter(train_dataset_gpu_loader))]
 eval_dataset_gpu['images'],  eval_dataset_gpu['targets']  = [item.to(device=hyp['misc']['device'], non_blocking=True) for item in next(iter(eval_dataset_gpu_loader)) ]
 
+cifar10_std, cifar10_mean = torch.std_mean(train_dataset_gpu['images'], dim=(0, 2, 3)) # dynamically calculate the std and mean from the data. this shortens the code and should help us adapt to new datasets!
+
 def batch_normalize_images(input_images, mean, std):
     return (input_images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)
 
@@ -161,30 +160,20 @@ def __init__(self, *args, **kwargs):
 
 # can hack any changes to each residual group that you want directly in here
 class ConvGroup(nn.Module):
-    def __init__(self, channels_in, channels_out, residual, short, pool, se):
+    def __init__(self, channels_in, channels_out, pool):
         super().__init__()
-        self.short = short
-        self.pool = pool # todo: we can condense this later
-        self.se = se
+        self.pool = pool # todo: maybe we can condense this later
 
-        self.residual = residual
         self.channels_in = channels_in
         self.channels_out = channels_out
 
-        self.conv1 = Conv(channels_in, channels_out)
         self.pool1 = nn.MaxPool2d(2)
+        self.conv1 = Conv(channels_in, channels_out)
+        self.conv2 = Conv(channels_out, channels_out)
         self.norm1 = BatchNorm(channels_out)
-        self.activ = nn.GELU()
+        self.norm2 = BatchNorm(channels_out)
+        self.activ = nn.GELU()
 
-        # note: this has to be flat if we're jitting things.... we just might burn a bit of extra GPU mem if so
-        if not short:
-            self.conv2 = Conv(channels_out, channels_out)
-            self.conv3 = Conv(channels_out, channels_out)
-            self.norm2 = BatchNorm(channels_out)
-            self.norm3 = BatchNorm(channels_out)
-
-            self.se1 = nn.Linear(channels_out, channels_out//16)
-            self.se2 = nn.Linear(channels_out//16, channels_out)
 
     def forward(self, x):
         x = self.conv1(x)
@@ -192,25 +181,14 @@ def forward(self, x):
         x = self.pool1(x)
         x = self.norm1(x)
         x = self.activ(x)
-        if self.short: # layer 2 doesn't necessarily need the residual, so we just return it.
-            return x
 
         residual = x
-        if self.se:
-            mult = torch.sigmoid(self.se2(self.activ(self.se1(torch.mean(residual, dim=(2,3)))))).unsqueeze(-1).unsqueeze(-1)
 
         x = self.conv2(x)
         x = self.norm2(x)
         x = self.activ(x)
-        x = self.conv3(x)
-        if self.se:
-            x = x * mult
-
-        x = self.norm3(x)
-        x = self.activ(x)
 
         x = x + residual # haiku
 
         return x
 
-# Set to 1 for now just to debug a few things....
 class TemperatureScaler(nn.Module):
     def __init__(self, init_val):
         super().__init__()
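(Aside: for an at-a-glance view of the block structure after this patch, here is a self-contained sketch of the simplified residual group above. It uses stock nn.Conv2d/nn.BatchNorm2d in place of the repo's Conv/BatchNorm wrappers, so the exact conv kwargs and batchnorm momentum differ from main.py; the layer ordering mirrors the diff.)

import torch
import torch.nn as nn

class ConvGroupSketch(nn.Module):
    # conv -> maxpool -> norm -> act, then a single conv -> norm -> act block whose
    # output is added back to its input (the squeeze-excite branch, the 'short'
    # path, and the third conv from earlier versions are all gone after this patch).
    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.conv1 = nn.Conv2d(channels_in, channels_out, 3, padding='same', bias=False)
        self.conv2 = nn.Conv2d(channels_out, channels_out, 3, padding='same', bias=False)
        self.pool1 = nn.MaxPool2d(2)
        self.norm1 = nn.BatchNorm2d(channels_out)
        self.norm2 = nn.BatchNorm2d(channels_out)
        self.activ = nn.GELU()

    def forward(self, x):
        x = self.activ(self.norm1(self.pool1(self.conv1(x))))
        residual = x
        x = self.activ(self.norm2(self.conv2(x)))
        return x + residual

print(ConvGroupSketch(64, 128)(torch.randn(2, 64, 32, 32)).shape)  # torch.Size([2, 128, 16, 16])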
@@ -237,12 +215,17 @@ def forward(self, x):
 #############################################
 
 def get_patches(x, patch_shape=(3, 3), dtype=torch.float32):
-    # TODO: Annotate
+    # This uses the unfold operation (https://pytorch.org/docs/stable/generated/torch.nn.functional.unfold.html?highlight=unfold#torch.nn.functional.unfold)
+    # to extract a _view_ (i.e., there's no data copied here) of blocks in the input tensor. We have to do it twice -- once horizontally, once vertically. Then
+    # from that, we get our kernel_size*kernel_size patches to later calculate the statistics for the whitening tensor on :D
     c, (h, w) = x.shape[1], patch_shape
     return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).to(dtype) # TODO: Annotate?
 
 def get_whitening_parameters(patches):
-    # TODO: Let's annotate this, please! :'D / D':
+    # As a high-level summary, we're basically finding the high-dimensional oval that best fits the data here.
+    # We can then later use this information to map the input information to a nicely distributed sphere, where also
+    # the most significant features of the inputs each have their own axis. This significantly cleans things up for the
+    # rest of the neural network and speeds up training.
     n,c,h,w = patches.shape
     est_covariance = torch.cov(patches.view(n, c*h*w).t())
     eigenvalues, eigenvectors = torch.linalg.eigh(est_covariance, UPLO='U') # this is the same as saying we want our eigenvectors, with the specification that the matrix be an upper triangular matrix (instead of a lower-triangular matrix)
@@ -255,27 +238,29 @@ def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block
             previous_block_data = train_set[:num_examples,:,pad_amount:-pad_amount,pad_amount:-pad_amount] # if it's none, we're at the beginning of our network.
         else:
             previous_block_data = train_set[:num_examples,:,:,:]
+
+        # chunking code to save memory for smaller-memory-size (generally consumer) GPUs
         if whiten_splits is None:
-            previous_block_data_split = [previous_block_data] # list of length 1 so we can reuse the splitting code down below
+            previous_block_data_split = [previous_block_data] # If we're whitening in one go, then put it in a list for simplicity to reuse the logic below
         else:
-            previous_block_data_split = previous_block_data.split(whiten_splits, dim=0)
+            previous_block_data_split = previous_block_data.split(whiten_splits, dim=0) # Otherwise, we split this into different chunks to keep things manageable
 
         eigenvalue_list, eigenvector_list = [], []
         for data_split in previous_block_data_split:
-            eigenvalues, eigenvectors = get_whitening_parameters(get_patches(data_split, patch_shape=layer.weight.data.shape[2:])) # center crop to remove padding
+            eigenvalues, eigenvectors = get_whitening_parameters(get_patches(data_split, patch_shape=layer.weight.data.shape[2:]))
             eigenvalue_list.append(eigenvalues)
            eigenvector_list.append(eigenvectors)
 
         eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0)
         eigenvectors = torch.stack(eigenvector_list, dim=0).mean(0)
-        # for some reason, the eigenvalues and eigenvectors seem to come out all in float32 for this? ! ?! ?!?!?!? :'((((
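(Aside: the whitening comments above are the densest part of this patch, so here is a small self-contained sketch of the same idea end to end: extract kernel-sized patches with unfold, estimate their covariance, eigendecompose it, and build whitening filters from the eigenvectors scaled by 1/sqrt(eigenvalue + eps). The eps value, scaling, and weight layout below are illustrative assumptions; the code in main.py that turns the eigenpairs into conv weights falls outside the truncated hunk above and may differ.)

import torch
import torch.nn as nn

def get_patches_sketch(x, patch_shape=(2, 2)):
    # Two unfolds give a view of every patch position: result is (N*positions, C, h, w).
    c, (h, w) = x.shape[1], patch_shape
    return x.unfold(2, h, 1).unfold(3, w, 1).transpose(1, 3).reshape(-1, c, h, w)

def whitening_filters_sketch(patches, eps=1e-2):
    n, c, h, w = patches.shape
    cov = torch.cov(patches.reshape(n, c * h * w).t())   # (c*h*w, c*h*w) covariance over all patches
    eigenvalues, eigenvectors = torch.linalg.eigh(cov)   # symmetric eigendecomposition
    scale = (eigenvalues + eps).rsqrt()                  # 1/sqrt(eigenvalue + eps) per eigen-direction
    # One filter per eigenvector; convolving with these maps patches toward unit variance per direction.
    return (eigenvectors * scale).t().reshape(c * h * w, c, h, w)

images = torch.randn(512, 3, 32, 32)                     # random stand-in for the CIFAR-10 training images
filters = whitening_filters_sketch(get_patches_sketch(images))
whiten = nn.Conv2d(3, filters.shape[0], kernel_size=2, bias=False)
with torch.no_grad():
    whiten.weight.copy_(filters)
print(whiten(images[:8]).shape)                          # torch.Size([8, 12, 31, 31])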