diff --git a/README.md b/README.md index 78f0b65..bf43f70 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Goals: * torch- and python-idiomatic * hackable * few external dependencies (currently only torch and torchvision) -* ~world-record single-GPU training time (this repo holds the current world record at ~<7 (!!!) seconds on an A100, down from ~18.1 seconds originally). +* ~world-record single-GPU training time (this repo holds the current world record at ~<6.3 (!!!) seconds on an A100, down from ~18.1 seconds originally). * <2 seconds training time in <2 years (yep!) This is a neural network implementation of a very speedily-training network that originally started as a painstaking reproduction of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/), but written nearly from the ground-up to be extremely rapid-experimentation-friendly. Part of the benefit of this is that we now hold the world record for single GPU training speeds on CIFAR10, for example. @@ -39,6 +39,9 @@ What we've added: * dirac initializations on non-depth-transitional layers (information passthrough on init) * and more! +What we've removed: +* explicit residual layers. yep. + This code, in comparison to David's original code, is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. It is excellent for rapid idea exploring -- almost everywhere in the pipeline is exposed and built to be user-friendly. I truly enjoy personally using this code, and hope you do as well! :D Please let me know if you have any feedback. I hope to continue publishing updates to this in the future, so your support is encouraged. Share this repo with someone you know that might like it! Feel free to check out my[Patreon](https://www.patreon.com/user/posts?u=83632131) if you like what I'm doing here and want more!. Additionally, if you want me to work up to a part-time amount of hours with you, feel free to reach out to me at hire.tysam@gmail.com. I'd love to hear from you. diff --git a/main.py b/main.py index 1e5bb12..82976a5 100644 --- a/main.py +++ b/main.py @@ -43,25 +43,24 @@ default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False} batchsize = 1024 -bias_scaler = 56 -# To replicate the ~95.78%-accuracy-in-113-seconds runs, you can change the base_depth from 64->128, train_epochs from 12.1->85, ['ema'] epochs 10->75, cutmix_size 3->9, and cutmix_epochs 6->75 +bias_scaler = 64 +# To replicate the ~95.79%-accuracy-in-110-seconds runs, you can change the base_depth from 64->128, train_epochs from 12.1->90, ['ema'] epochs 10->80, cutmix_size 3->10, and cutmix_epochs 6->80 hyp = { 'opt': { - 'bias_lr': 1.64 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :')))) - 'non_bias_lr': 1.64 / 512, - 'bias_decay': 1.08 * 6.45e-4 * batchsize/bias_scaler, - 'non_bias_decay': 1.08 * 6.45e-4 * batchsize, + 'bias_lr': 1.525 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :')))) + 'non_bias_lr': 1.525 / 512, + 'bias_decay': 6.687e-4 * batchsize/bias_scaler, + 'non_bias_decay': 6.687e-4 * batchsize, 'scaling_factor': 1./9, 'percent_start': .23, - 'loss_scale_scaler': 1./128, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :) + 'loss_scale_scaler': 1./32, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :) }, 'net': { 'whitening': { 'kernel_size': 2, 'num_examples': 50000, }, - 'batch_norm_momentum': .5, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( ) - 'conv_norm_pow': 2.6, + 'batch_norm_momentum': .4, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( ) 'cutmix_size': 3, 'cutmix_epochs': 6, 'pad_amount': 2, @@ -162,42 +161,34 @@ def __init__(self, num_features, eps=1e-12, momentum=hyp['net']['batch_norm_mome # Having an outer class like this does add space and complexity but offers us # a ton of freedom when it comes to hacking in unique functionality for each layer type class Conv(nn.Conv2d): - def __init__(self, *args, norm=False, **kwargs): + def __init__(self, *args, **kwargs): kwargs = {**default_conv_kwargs, **kwargs} super().__init__(*args, **kwargs) self.kwargs = kwargs - self.norm = norm - - def forward(self, x): - if self.training and self.norm: - # TODO: Do/should we always normalize along dimension 1 of the weight vector(s), or the height x width dims too? - with torch.no_grad(): - F.normalize(self.weight.data, p=self.norm) - return super().forward(x) class Linear(nn.Linear): - def __init__(self, *args, norm=False, **kwargs): + def __init__(self, *args, temperature=None, **kwargs): super().__init__(*args, **kwargs) self.kwargs = kwargs - self.norm = norm + self.temperature = temperature def forward(self, x): - if self.training and self.norm: - # TODO: Normalize on dim 1 or dim 0 for this guy? - with torch.no_grad(): - F.normalize(self.weight.data, p=self.norm) - return super().forward(x) + if self.temperature is not None: + weight = self.weight * self.temperature + else: + weight = self.weight + return x @ weight.T -# can hack any changes to each residual group that you want directly in here +# can hack any changes to each convolution group that you want directly in here class ConvGroup(nn.Module): - def __init__(self, channels_in, channels_out, norm): + def __init__(self, channels_in, channels_out): super().__init__() - self.channels_in = channels_in + self.channels_in = channels_in self.channels_out = channels_out self.pool1 = nn.MaxPool2d(2) - self.conv1 = Conv(channels_in, channels_out, norm=norm) - self.conv2 = Conv(channels_out, channels_out, norm=norm) + self.conv1 = Conv(channels_in, channels_out) + self.conv2 = Conv(channels_out, channels_out) self.norm1 = BatchNorm(channels_out) self.norm2 = BatchNorm(channels_out) @@ -210,20 +201,11 @@ def forward(self, x): x = self.pool1(x) x = self.norm1(x) x = self.activ(x) - residual = x x = self.conv2(x) x = self.norm2(x) x = self.activ(x) - x = x + residual # haiku - return x -class TemperatureScaler(nn.Module): - def __init__(self, init_val): - super().__init__() - self.scaler = torch.tensor(init_val) - - def forward(self, x): - return x.mul(self.scaler) + return x class FastGlobalMaxPooling(nn.Module): def __init__(self): @@ -275,7 +257,7 @@ def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block eigenvalue_list.append(eigenvalues) eigenvector_list.append(eigenvectors) - eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0) + eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0) eigenvectors = torch.stack(eigenvector_list, dim=0).mean(0) # i believe the eigenvalues and eigenvectors come out in float32 for this because we implicitly cast it to float32 in the patches function (for numerical stability) set_whitening_conv(layer, eigenvalues.to(dtype=layer.weight.dtype), eigenvectors.to(dtype=layer.weight.dtype), freeze=freeze) @@ -284,7 +266,8 @@ def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=True): shape = conv_layer.weight.data.shape - conv_layer.weight.data[-eigenvectors.shape[0]:, :, :, :] = (eigenvectors/torch.sqrt(eigenvalues+eps))[-shape[0]:, :, :, :] # set the first n filters of the weight data to the top n significant (sorted by importance) filters from the eigenvectors + eigenvectors_sliced = (eigenvectors/torch.sqrt(eigenvalues+eps))[-shape[0]:, :, :, :] # set the first n filters of the weight data to the top n significant (sorted by importance) filters from the eigenvectors + conv_layer.weight.data = torch.cat((eigenvectors_sliced, -eigenvectors_sliced), dim=0) ## We don't want to train this, since this is implicitly whitening over the whole dataset ## For more info, see David Page's original blogposts (link in the README.md as of this commit.) if freeze: @@ -304,7 +287,7 @@ def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=T 'num_classes': 10 } -class SpeedyResNet(nn.Module): +class SpeedyConvNet(nn.Module): def __init__(self, network_dict): super().__init__() self.net_dict = network_dict # flexible, defined in the make_net function @@ -314,14 +297,12 @@ def forward(self, x): if not self.training: x = torch.cat((x, torch.flip(x, (-1,)))) x = self.net_dict['initial_block']['whiten'](x) - x = self.net_dict['initial_block']['project'](x) x = self.net_dict['initial_block']['activation'](x) - x = self.net_dict['residual1'](x) - x = self.net_dict['residual2'](x) - x = self.net_dict['residual3'](x) + x = self.net_dict['conv_group_1'](x) + x = self.net_dict['conv_group_2'](x) + x = self.net_dict['conv_group_3'](x) x = self.net_dict['pooling'](x) x = self.net_dict['linear'](x) - x = self.net_dict['temperature'](x) if not self.training: # Average the predictions from the lr-flipped inputs during eval orig, flipped = x.split(x.shape[0]//2, dim=0) @@ -335,18 +316,16 @@ def make_net(): network_dict = nn.ModuleDict({ 'initial_block': nn.ModuleDict({ 'whiten': Conv(3, whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size'], padding=0), - 'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1, norm=2.2), # The norm argument means we renormalize the weights to be length 1 for this as the power for the norm, each step 'activation': nn.GELU(), }), - 'residual1': ConvGroup(depths['init'], depths['block1'], hyp['net']['conv_norm_pow']), - 'residual2': ConvGroup(depths['block1'], depths['block2'], hyp['net']['conv_norm_pow']), - 'residual3': ConvGroup(depths['block2'], depths['block3'], hyp['net']['conv_norm_pow']), + 'conv_group_1': ConvGroup(2*whiten_conv_depth, depths['block1']), + 'conv_group_2': ConvGroup(depths['block1'], depths['block2']), + 'conv_group_3': ConvGroup(depths['block2'], depths['block3']), 'pooling': FastGlobalMaxPooling(), - 'linear': Linear(depths['block3'], depths['num_classes'], bias=False, norm=5.), - 'temperature': TemperatureScaler(hyp['opt']['scaling_factor']) + 'linear': Linear(depths['block3'], depths['num_classes'], bias=False, temperature=hyp['opt']['scaling_factor']), }) - net = SpeedyResNet(network_dict) + net = SpeedyConvNet(network_dict) net = net.to(hyp['misc']['device']) net = net.to(memory_format=torch.channels_last) # to appropriately use tensor cores/avoid thrash while training net.train() @@ -365,18 +344,35 @@ def make_net(): ## the index lookup in the dataloader may give you some trouble depending ## upon exactly how memory-limited you are - ## We initialize the projections layer to return exactly the spatial inputs, this way we start - ## at a nice clean place (the whitened image in feature space, directly) and can iterate directly from there. - torch.nn.init.dirac_(net.net_dict['initial_block']['project'].weight) for layer_name in net.net_dict.keys(): - if 'residual' in layer_name: - ## We do the same for the second layer in each residual block, since this only + if 'conv_group' in layer_name: + # Create an implicit residual via a dirac-initialized tensor + dirac_weights_in = torch.nn.init.dirac_(torch.empty_like(net.net_dict[layer_name].conv1.weight)) + + # Add the implicit residual to the already-initialized convolutional transition layer. + # One can use more sophisticated initializations, but this one appeared worked best in testing. + # What this does is brings up the features from the previous residual block virtually, so not only + # do we have residual information flow within each block, we have a nearly direct connection from + # the early layers of the network to the loss function. + std_pre, mean_pre = torch.std_mean(net.net_dict[layer_name].conv1.weight.data) + net.net_dict[layer_name].conv1.weight.data = net.net_dict[layer_name].conv1.weight.data + dirac_weights_in + std_post, mean_post = torch.std_mean(net.net_dict[layer_name].conv1.weight.data) + + # Renormalize the weights to match the original initialization statistics + net.net_dict[layer_name].conv1.weight.data.sub_(mean_post).div_(std_post).mul_(std_pre).add_(mean_pre) + + ## We do the same for the second layer in each convolution group block, since this only ## adds a simple multiplier to the inputs instead of the noise of a randomly-initialized ## convolution. This can be easily scaled down by the network, and the weights can more easily ## pivot in whichever direction they need to go now. + ## The reason that I believe that this works so well is because a combination of MaxPool2d + ## and the nn.GeLU function's positive bias encouraging values towards the nearly-linear + ## region of the GeLU activation function at network initialization. I am not currently + ## sure about this, however, it will require some more investigation. For now -- it works! D: torch.nn.init.dirac_(net.net_dict[layer_name].conv2.weight) + return net #############################################