diff --git a/README.md b/README.md
index 1fdb7c9..2ced1a0 100644
--- a/README.md
+++ b/README.md
@@ -21,15 +21,18 @@ Goals:
 * hackable
 * few external dependencies (currently only torch and torchvision)
 * ~world-record single-GPU training time (this repo holds the current world record at ~<12.38 seconds on an A100, down from ~18.1 seconds originally).
-* <2 seconds training time in <2 years
+* <2 seconds training time in <2 years (yep!)
 
-This is a neural network implementation that started from a painstaking reproduction from nearly the ground-up a hacking-friendly version of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/). This repository is meant to function primarily as a very human-friendly researcher's toolbench first, a benchmark a close second (ironically currently holding the world record), and a learning codebase third. We're now in the stage where the real fun begins -- the journey to <2 seconds. Some of the early progress was surprisingly easy, but it will likely get pretty crazy as we get closer and closer to our goal.
+This is an implementation of a very speedily-training neural network that originally started as a painstaking reproduction of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/), but was rewritten nearly from the ground up to be extremely rapid-experimentation-friendly. Part of the benefit of this is that we now hold the world record for single-GPU training speed on CIFAR10 (under 10 seconds on an A100!!!).
 
-This code took about 120-130 hours of work during the initial write from start to finish, about 80-90+ of which were mind-numbingly tedious debugging. Some strange things seem to really matter for performance (speed and accuracy), and some strangely do not seem to. To that end, I found it very educational to create (and may do a writeup someday if enough people and I have enough interest in it).
-
-
-I built this because I loved David's work but his code was difficult for my quick-experiment-and-hacking usecases. This code is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. It is excellent for rapid idea exploring -- almost everywhere in the pipeline is exposed and built to not be user-hostile. I truly enjoy personally using this code, and hope you do as well! :D Please let me know if you have any feedback. I hope to continue publishing updates to this in the future, so your support through word of mouth or otherwise is especially encouraged.
+What we've added:
+* squeeze and excite layers
+* way too much hyperparameter tuning
+* miscellaneous architecture trimmings (see the patch notes)
+* memory format changes (and more!) to better use tensor cores/etc
+* and more!
+
+This code, in comparison to David's original code, is in a single file and extremely flat, but is not as durable for long-term, production-level maintenance. You're meant to check out a fresh repo whenever you have a new idea. It is excellent for rapid idea exploration -- almost everywhere in the pipeline is exposed and built to be user-friendly. I truly enjoy personally using this code, and hope you do as well! :D Please let me know if you have any feedback. I hope to continue publishing updates to this in the future, so your support is encouraged. Share this repo with someone you know that might like it!
 
 Your support helps a lot -- even if it's a dollar as month. I have several more projects I'm in various stages on, and you can help me have the money and time to get this project (and the others) to the finish line! If you like what I'm doing, or this project has brought you some value, please consider subscribing on my [Patreon](https://www.patreon.com/user/posts?u=83632131). There's not too many extra rewards besides better software more frequently. Alternatively, if you want me to work up to a part-time amount of hours with you, feel free to reach out to me at hire.tysam@gmail.com. I'd love to hear from you.
diff --git a/main.py b/main.py
index 20423c8..faef86f 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,6 @@
-# BUG: Currently, I haven't found a good way to distinguish between python and ipython, so we're leaving the 'colab mode' to require a manual
-# uncomment of this code to fully guard against state errors. This lets people run the one-line download-and-train command appropriately.
-"""
+# Note: The one change we need to make if we're in Colab is to uncomment the block below.
 # If we are in an ipython session or a notebook, clear the state to avoid bugs
+"""
 try:
     _ = get_ipython().__class__.__name__
     ## we set -f below to avoid prompting the user before clearing the notebook state
@@ -34,33 +33,40 @@
 # ways this code could be improved and cleaned up, please do open a PR on the GitHub repo. Your support and help is much appreciated for this
 # project! :)
+
+# This is for testing that certain changes don't exceed some X% of the memory of the reference GPU (here an A100),
+# so we can help reduce the possibility that future releases take away the accessibility of this codebase.
+#torch.cuda.set_per_process_memory_fraction(fraction=8./40., device=0) ## 40. GB is the maximum memory of the base A100 GPU
+
 # set global defaults (in this particular file) for convolutions
 default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False}
 batchsize = 512
-bias_scaler = 64
+bias_scaler = 32
+# To replicate the ~95.77% accuracy in 188 seconds runs, simply change the base_depth from 64->128 and the num_epochs from 10->80
 hyp = {
     'opt': {
-        'bias_lr': 1.35 * 1. * bias_scaler/batchsize, # TODO: How we're expressing this information feels somewhat clunky, is there maybe a better way to do this? :'))))
-        'non_bias_lr': 1.35 * 1. / batchsize,
-        'bias_decay': 4.8e-4 * batchsize/bias_scaler,
-        'non_bias_decay': 4.8e-4 * batchsize,
-        'scaling_factor': 1./16,
+        'bias_lr': 1.15 * 1.35 * 1. * bias_scaler/batchsize, # TODO: How we're expressing this information feels somewhat clunky, is there maybe a better way to do this? :'))))
+        'non_bias_lr': 1.15 * 1.35 * 1. / batchsize,
+        'bias_decay': .85 * 4.8e-4 * batchsize/bias_scaler,
+        'non_bias_decay': .85 * 4.8e-4 * batchsize,
+        'scaling_factor': 1./10,
         'percent_start': .2,
     },
     'net': {
         'whitening': {
-            'kernel_size': 3,
-            'num_examples': 10000,
+            'kernel_size': 2,
+            'num_examples': 50000,
         },
-        'batch_norm_momentum': .4,
+        'batch_norm_momentum': .8,
         'cutout_size': 0,
-        'pad_amount': 4,
+        'pad_amount': 3,
+        'base_depth': 64 ## This should be a multiple of 8 to stay tensor core friendly
     },
     'misc': {
         'ema': {
-            'epochs': 3,
-            'decay_base': .987,
+            'epochs': 2,
+            'decay_base': .986,
             'every_n_steps': 2,
         },
         'train_epochs': 10,
@@ -155,10 +161,11 @@ def __init__(self, *args, **kwargs):
 
 # can hack any changes to each residual group that you want directly in here
 class ConvGroup(nn.Module):
-    def __init__(self, channels_in, channels_out, residual, short, pool):
+    def __init__(self, channels_in, channels_out, residual, short, pool, se):
         super().__init__()
         self.short = short
         self.pool = pool # todo: we can condense this later
+        self.se = se
         self.residual = residual
         self.channels_in = channels_in
@@ -167,7 +174,7 @@ def __init__(self, channels_in, channels_out, residual, short, pool):
         self.conv1 = Conv(channels_in, channels_out)
         self.pool1 = nn.MaxPool2d(2)
         self.norm1 = BatchNorm(channels_out)
-        self.activ = nn.CELU(alpha=.3)
+        self.activ = nn.GELU()
 
         # note: this has to be flat if we're jitting things.... we just might burn a bit of extra GPU mem if so
         if not short:
@@ -176,6 +183,9 @@ def __init__(self, channels_in, channels_out, residual, short, pool):
             self.norm2 = BatchNorm(channels_out)
             self.norm3 = BatchNorm(channels_out)
 
+            self.se1 = nn.Linear(channels_out, channels_out//16)
+            self.se2 = nn.Linear(channels_out//16, channels_out)
+
     def forward(self, x):
         x = self.conv1(x)
         if self.pool:
@@ -185,13 +195,18 @@ def forward(self, x):
         if self.short: # layer 2 doesn't necessarily need the residual, so we just return it.
             return x
         residual = x
+        if self.se:
+            mult = torch.sigmoid(self.se2(self.activ(self.se1(torch.mean(residual, dim=(2,3)))))).unsqueeze(-1).unsqueeze(-1)
         x = self.conv2(x)
         x = self.norm2(x)
         x = self.activ(x)
         x = self.conv3(x)
+        if self.se:
+            x = x * mult
+
         x = self.norm3(x)
         x = self.activ(x)
-        x = x + residual # haiku
+        x = x + residual # haiku
 
         return x
@@ -234,25 +249,37 @@ def get_whitening_parameters(patches):
     return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.t().reshape(c*h*w,c,h,w).flip(0)
 
 # Run this over the training set to calculate the patch statistics, then set the initial convolution as a non-learnable 'whitening' layer
-def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block_data=None, pad_amount=None):
+def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block_data=None, pad_amount=None, freeze=True, whiten_splits=None):
     if train_set is not None and previous_block_data is None:
         if pad_amount > 0:
             previous_block_data = train_set[:num_examples,:,pad_amount:-pad_amount,pad_amount:-pad_amount] # if it's none, we're at the beginning of our network.
         else:
             previous_block_data = train_set[:num_examples,:,:,:]
+    if whiten_splits is None:
+        previous_block_data_split = [previous_block_data] # list of length 1 so we can reuse the splitting code down below
+    else:
+        previous_block_data_split = previous_block_data.split(whiten_splits, dim=0)
+
+    eigenvalue_list, eigenvector_list = [], []
+    for data_split in previous_block_data_split:
+        eigenvalues, eigenvectors = get_whitening_parameters(get_patches(data_split, patch_shape=layer.weight.data.shape[2:])) # center crop to remove padding
+        eigenvalue_list.append(eigenvalues)
+        eigenvector_list.append(eigenvectors)
-    eigenvalues, eigenvectors = get_whitening_parameters(get_patches(previous_block_data, patch_shape=layer.weight.data.shape[2:])) # center crop to remove padding
+    eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0)
+    eigenvectors = torch.stack(eigenvector_list, dim=0).mean(0) # for some reason, the eigenvalues and eigenvectors seem to come out all in float32 for this? ! ?! ?!?!?!? :'((((
= ema_epoch_start and current_steps % hyp['misc']['ema']['every_n_steps'] == 0:
                     ## Initialize the ema from the network at this point in time if it does not already exist.... :D
-                    if net_ema is None:
+                    if net_ema is None or epoch_step < num_cooldown_before_freeze_steps: # don't snapshot the network yet if so!
                         net_ema = NetworkEMA(net, decay=projected_ema_decay_val)
+                        continue
                     net_ema.update(net)
             ender.record()
             torch.cuda.synchronize()
@@ -597,7 +633,10 @@ def main():
             # Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....)
             ## We also check to see if we're in our final epoch so we can print the 'bottom' of the table for each round.
             print_training_details(list(map(partial(format_for_table, locals=locals()), logging_columns_list)), is_final_entry=(epoch == hyp['misc']['train_epochs'] - 1))
+    return ema_val_acc # Return the final ema accuracy achieved (not using the 'best accuracy' selection strategy, which I think is okay here....)
 
 if __name__ == "__main__":
+    acc_list = []
     for run_num in range(25):
-        main()
+        acc_list.append(torch.tensor(main()))
+    print("Mean and variance:", (torch.mean(torch.stack(acc_list)).item(), torch.var(torch.stack(acc_list)).item()))
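
A few hedged sketches follow for readers skimming the patch. First, the squeeze-and-excite change to `ConvGroup` boils down to a channel-wise gate computed from globally pooled features (the new `se1`/`se2` linears plus the sigmoid in `forward`). The standalone module below illustrates that pattern; the `SqueezeExcite` name and `reduction` argument are illustrative only, and note that in the diff the gate is computed from the residual-branch input and applied after `conv3` rather than being wrapped in its own module.

```python
import torch
import torch.nn as nn

class SqueezeExcite(nn.Module):
    # Channel gate in the style of the se1/se2 layers added above:
    # global average pool -> bottleneck linear -> activation -> linear -> sigmoid gate.
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)
        self.activ = nn.GELU()

    def forward(self, x):
        pooled = torch.mean(x, dim=(2, 3))  # (N, C, H, W) -> (N, C)
        gate = torch.sigmoid(self.fc2(self.activ(self.fc1(pooled))))  # (N, C), values in (0, 1)
        return x * gate.unsqueeze(-1).unsqueeze(-1)  # broadcast the gate over H and W

if __name__ == "__main__":
    se = SqueezeExcite(channels=64)
    print(se(torch.randn(2, 64, 8, 8)).shape)  # torch.Size([2, 64, 8, 8])
```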
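Second, the `bias_lr`/`non_bias_lr` and `bias_decay`/`non_bias_decay` entries in `hyp['opt']` suggest that bias and non-bias parameters are optimized in separate parameter groups, with the learning rates and weight decays rescaled by `batchsize` and `bias_scaler`. The splitting code itself is not in this diff, so the sketch below is only one plausible wiring; the `make_optimizer` helper, the name-based split, and the choice of Nesterov SGD are assumptions rather than the repo's actual implementation.

```python
import torch

def make_optimizer(net, hyp):
    # Split parameters into bias / non-bias groups so each can get its own
    # learning rate and weight decay, mirroring the hyp['opt'] keys above.
    bias_params     = [p for name, p in net.named_parameters() if 'bias' in name]
    non_bias_params = [p for name, p in net.named_parameters() if 'bias' not in name]
    param_groups = [
        {'params': bias_params,     'lr': hyp['opt']['bias_lr'],     'weight_decay': hyp['opt']['bias_decay']},
        {'params': non_bias_params, 'lr': hyp['opt']['non_bias_lr'], 'weight_decay': hyp['opt']['non_bias_decay']},
    ]
    # The top-level lr is only a fallback; the per-group 'lr' values above take precedence.
    return torch.optim.SGD(param_groups, lr=hyp['opt']['non_bias_lr'], momentum=0.9, nesterov=True)
```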
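Finally, the EMA changes (`decay_base`, `every_n_steps`, and the delayed snapshot via `num_cooldown_before_freeze_steps`) rely on a `NetworkEMA` class whose internals are not shown here. As rough orientation only, a generic exponential moving average over parameters looks like the hypothetical stand-in below; a fuller version would also track buffers such as BatchNorm running statistics.

```python
import copy
import torch

class SimpleNetworkEMA:
    # Hypothetical stand-in for the NetworkEMA referenced in the diff: keeps a frozen
    # copy of the network and blends its parameters toward the live network's.
    def __init__(self, net, decay):
        self.net_ema = copy.deepcopy(net).eval()
        for p in self.net_ema.parameters():
            p.requires_grad_(False)
        self.decay = decay

    @torch.no_grad()
    def update(self, net):
        # ema_param <- decay * ema_param + (1 - decay) * live_param
        for ema_p, p in zip(self.net_ema.parameters(), net.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1. - self.decay)
```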