diff --git a/README.md b/README.md
index 853891b..a064b94 100644
--- a/README.md
+++ b/README.md
@@ -21,10 +21,10 @@ Goals:
 * torch- and python-idiomatic
 * hackable
 * few external dependencies (currently only torch and torchvision)
-* ~world-record single-GPU training time (this repo holds the current world record at ~<10 seconds on an A100, down from ~18.1 seconds originally).
+* ~world-record single-GPU training time (this repo holds the current world record at ~<7 (!!) seconds on an A100, down from ~18.1 seconds originally).
 * <2 seconds training time in <2 years (yep!)

-This is a neural network implementation of a very speedily-training network that originally started as a painstaking reproduction of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/), but written nearly from the ground-up to be extremely rapid-experimentation-friendly. Part of the benefit of this is that we now hold the world record for single GPU training speeds on CIFAR10 (under 10 seconds on an A100!!!)
+This is a neural network implementation of a very speedily-training network that originally started as a painstaking reproduction of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/), but written nearly from the ground-up to be extremely rapid-experimentation-friendly. Part of the benefit of this is that we now hold the world record for single GPU training speeds on CIFAR10, though it will likely get to be _much_ harder for us to continue to improve our speeds significantly from here on out.

 What we've added:
 * squeeze and excite layers
diff --git a/main.py b/main.py
index d2a2959..7a55a92 100644
--- a/main.py
+++ b/main.py
@@ -10,6 +10,7 @@
 """
 import functools
 from functools import partial
+import math
 import os
 import copy
@@ -42,35 +43,37 @@ default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False}
 batchsize = 1024
-bias_scaler = 32
-# To replicate the ~95.84%-accuracy-in-172-seconds runs, you can change the base_depth from 64->128, num_epochs from 10->90, ['ema'] epochs 9->78, and cutout 0->11
+bias_scaler = 48
+# To replicate the ~95.80%-accuracy-in-120-seconds runs, you can change the base_depth from 64->128, num_epochs from 10->90, ['ema'] epochs 9->78, and cutmix_size 0->9
 hyp = {
     'opt': {
-        'bias_lr': 2.1 * bias_scaler/512, # TODO: How we're expressing this information feels somewhat clunky, is there maybe a better way to do this? :'))))
-        'non_bias_lr': 2.1 / 512,
-        'bias_decay': 6.45e-4 * batchsize/bias_scaler,
-        'non_bias_decay': 6.45e-4 * batchsize,
-        'scaling_factor': 1./10,
-        'percent_start': .15,
-        'loss_scale_scaler': 4., # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
+        'bias_lr': 1.64 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
+        'non_bias_lr': 1.64 / 512,
+        'bias_decay': 1.05 * 6.45e-4 * batchsize/bias_scaler,
+        'non_bias_decay': 1.05 * 6.45e-4 * batchsize,
+        'scaling_factor': 1./9,
+        'percent_start': .23,
+        'loss_scale_scaler': 1./128, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
     },
     'net': {
         'whitening': {
             'kernel_size': 2,
             'num_examples': 50000,
         },
-        'batch_norm_momentum': .8, # Equivalent roughly to updating entirely every step, as momentum for batchnorm is represented in a different way (1 - momentum) due to a quirk of the original paper... ;'((((
-        'cutout_size': 0,
+        'batch_norm_momentum': .5, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
+        'cutmix_size': 3,
+        'cutmix_epochs': 5,
         'pad_amount': 2,
         'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly
     },
     'misc': {
         'ema': {
-            'epochs': 9,
-            'decay_base': .98,
+            'epochs': 10, # Slight bug in that this counts only full epochs and then additionally runs the EMA for any fractional epochs at the end too
+            'decay_base': .95,
+            'decay_pow': 3,
             'every_n_steps': 5,
         },
-        'train_epochs': 10,
+        'train_epochs': 12.6,
         'device': 'cuda',
         'data_location': 'data.pt',
     }
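# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): what the bias_scaler bookkeeping in
# the hyp dict above works out to numerically. With batchsize=1024 and
# bias_scaler=48 as set here, the bias/batchnorm parameter group simply gets a
# ~48x larger learning rate and a ~48x smaller weight decay than the other
# parameters. Plain arithmetic only; nothing below is repo API.
batchsize, bias_scaler = 1024, 48
bias_lr = 1.64 * bias_scaler / 512                        # ~0.154
non_bias_lr = 1.64 / 512                                  # ~0.0032
bias_decay = 1.05 * 6.45e-4 * batchsize / bias_scaler     # ~0.0144
non_bias_decay = 1.05 * 6.45e-4 * batchsize               # ~0.69
print(bias_lr / non_bias_lr, non_bias_decay / bias_decay) # both ratios are ~48 (= bias_scaler)
# ---------------------------------------------------------------------------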
@@ -118,8 +121,12 @@ def batch_normalize_images(input_images, mean, std):
     }

     ## Convert dataset to FP16 now for the rest of the process....
-    data['train']['images'] = data['train']['images'].half()
-    data['eval']['images'] = data['eval']['images'].half()
+    data['train']['images'] = data['train']['images'].half().requires_grad_(False)
+    data['eval']['images'] = data['eval']['images'].half().requires_grad_(False)
+
+    # Convert this to one-hot to support the usage of cutmix (or whatever strange label tricks/magic you desire!)
+    data['train']['targets'] = F.one_hot(data['train']['targets']).half()
+    data['eval']['targets'] = F.one_hot(data['eval']['targets']).half()

     torch.save(data, hyp['misc']['data_location'])
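# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): why the targets are converted to
# one-hot above. Integer labels cannot express a "99% class A / 1% class B"
# mix, so the labels are stored as (half-precision) one-hot vectors up front
# and cutmix can then blend them linearly. Toy CPU tensors below; num_classes
# is spelled out here for clarity even though the patch relies on the default.
import torch
import torch.nn.functional as F

targets_int = torch.tensor([3, 7, 1])                        # CIFAR-10-style integer labels
targets_1h = F.one_hot(targets_int, num_classes=10).float()  # the patch uses .half() on GPU
print(targets_1h.shape)       # torch.Size([3, 10])
print(targets_1h.argmax(-1))  # tensor([3, 7, 1]) -- this is why the accuracy checks switch to argmax(-1)
# ---------------------------------------------------------------------------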
@@ -129,7 +136,6 @@ def batch_normalize_images(input_images, mean, std):
     ## hyp dictionary, then we should be good. :)
     data = torch.load(hyp['misc']['data_location'])

-
 ## As you'll note above and below, one difference is that we don't count loading the raw data to GPU since it's such a variable operation, and can sort of get in the way
 ## of measuring other things. That said, measuring the preprocessing (outside of the padding) is still important to us.

@@ -152,33 +158,55 @@ def __init__(self, num_features, eps=1e-12, momentum=hyp['net']['batch_norm_mome
         self.bias.requires_grad = bias

 # Allows us to set default arguments for the whole convolution itself.
+# Having an outer class like this does add space and complexity but offers us
+# a ton of freedom when it comes to hacking in unique functionality for each layer type
 class Conv(nn.Conv2d):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, norm=False, **kwargs):
         kwargs = {**default_conv_kwargs, **kwargs}
         super().__init__(*args, **kwargs)
         self.kwargs = kwargs
+        self.norm = norm
+
+    def forward(self, x):
+        if self.training and self.norm:
+            # TODO: Do/should we always normalize along dimension 1 of the weight vector(s), or the height x width dims too?
+            with torch.no_grad():
+                F.normalize(self.weight.data, p=self.norm)
+        return super().forward(x)
+
+class Linear(nn.Linear):
+    def __init__(self, *args, norm=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.kwargs = kwargs
+        self.norm = norm
+
+    def forward(self, x):
+        if self.training and self.norm:
+            # TODO: Normalize on dim 1 or dim 0 for this guy?
+            with torch.no_grad():
+                F.normalize(self.weight.data, p=self.norm)
+        return super().forward(x)

 # can hack any changes to each residual group that you want directly in here
 class ConvGroup(nn.Module):
-    def __init__(self, channels_in, channels_out, pool):
+    def __init__(self, channels_in, channels_out):
         super().__init__()
-        self.pool = pool # todo: maybe we can condense this later
-
         self.channels_in = channels_in
         self.channels_out = channels_out

         self.pool1 = nn.MaxPool2d(2)
         self.conv1 = Conv(channels_in, channels_out)
         self.conv2 = Conv(channels_out, channels_out)
+
         self.norm1 = BatchNorm(channels_out)
         self.norm2 = BatchNorm(channels_out)
+
         self.activ = nn.GELU()

     def forward(self, x):
         x = self.conv1(x)
-        if self.pool:
-            x = self.pool1(x)
+        x = self.pool1(x)
         x = self.norm1(x)
         x = self.activ(x)
         residual = x
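# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): what the new `norm` argument feeds
# into. F.normalize divides each slice along dim=1 (its default) by that
# slice's p-norm, with p taken from `norm`, so each output filter/row ends up
# with unit p-norm. Toy weight tensor below; p=5.0 chosen arbitrarily.
import torch
import torch.nn.functional as F

w = torch.randn(4, 8)                   # e.g. a tiny linear weight: out_features x in_features
w_unit = F.normalize(w, p=5.0, dim=1)   # each of the 4 rows now has 5-norm == 1
print(w_unit.norm(p=5.0, dim=1))        # ~tensor([1., 1., 1., 1.])
# Note that F.normalize is out-of-place: assign its result (or pass out=) if
# the weights themselves should be updated.
# ---------------------------------------------------------------------------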
@@ -186,7 +214,6 @@ def forward(self, x):
         x = self.norm2(x)
         x = self.activ(x)
         x = x + residual # haiku
-
         return x

 class TemperatureScaler(nn.Module):
@@ -195,10 +222,6 @@ def __init__(self, init_val):
         self.scaler = torch.tensor(init_val)

     def forward(self, x):
-        x.float() ## save precision for the gradients in the backwards pass
-        ## I personally believe from experience that this is important
-        ## for a few reasons. I believe this is the main functional difference between
-        ## my implementation, and David's implementation...
         return x.mul(self.scaler)

 class FastGlobalMaxPooling(nn.Module):
@@ -274,7 +297,7 @@ def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=T
 scaler = 2. ## You can play with this on your own if you want, for the first beta I wanted to keep things simple (for now) and leave it out of the hyperparams dict
 depths = {
     'init': round(scaler**-1*hyp['net']['base_depth']), # 32 w/ scaler at base value
-    'block1': round(scaler** 1*hyp['net']['base_depth']), # 128 w/ scaler at base value
+    'block1': round(scaler** 0*hyp['net']['base_depth']), # 64 w/ scaler at base value
     'block2': round(scaler** 2*hyp['net']['base_depth']), # 256 w/ scaler at base value
     'block3': round(scaler** 3*hyp['net']['base_depth']), # 512 w/ scaler at base value
     'num_classes': 10
@@ -291,7 +314,6 @@ def forward(self, x):
         x = torch.cat((x, torch.flip(x, (-1,))))
         x = self.net_dict['initial_block']['whiten'](x)
         x = self.net_dict['initial_block']['project'](x)
-        x = self.net_dict['initial_block']['norm'](x)
         x = self.net_dict['initial_block']['activation'](x)
         x = self.net_dict['residual1'](x)
         x = self.net_dict['residual2'](x)
@@ -312,15 +334,14 @@ def make_net():
     network_dict = nn.ModuleDict({
         'initial_block': nn.ModuleDict({
             'whiten': Conv(3, whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size'], padding=0),
-            'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1),
-            'norm': BatchNorm(depths['init'], weight=False),
+            'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1, norm=2.2), # The norm argument means we renormalize the weights to be length 1 for this as the power for the norm, each step
             'activation': nn.GELU(),
         }),
-        'residual1': ConvGroup(depths['init'], depths['block1'], pool=True),
-        'residual2': ConvGroup(depths['block1'], depths['block2'], pool=True),
-        'residual3': ConvGroup(depths['block2'], depths['block3'], pool=True),
+        'residual1': ConvGroup(depths['init'], depths['block1']),
+        'residual2': ConvGroup(depths['block1'], depths['block2']),
+        'residual3': ConvGroup(depths['block2'], depths['block3']),
         'pooling': FastGlobalMaxPooling(),
-        'linear': nn.Linear(depths['block3'], depths['num_classes'], bias=False),
+        'linear': Linear(depths['block3'], depths['num_classes'], bias=False, norm=5.),
         'temperature': TemperatureScaler(hyp['opt']['scaling_factor'])
     })
@@ -330,6 +351,7 @@ def make_net():
     net.train()
     net.half() # Convert network to half before initializing the initial whitening layer.

+    ## Initialize the whitening convolution
     with torch.no_grad():
         # Initialize the first layer to be fixed weights that whiten the expected input values of the network be on the unit hypersphere. (i.e. their...average vector length is 1.?, IIRC)

@@ -373,14 +395,20 @@ def make_random_square_masks(inputs, mask_size):

     return final_mask

-def batch_cutout(inputs, patch_size):
+
+def batch_cutmix(inputs, targets, patch_size):
     with torch.no_grad():
-        cutout_batch_mask = make_random_square_masks(inputs, patch_size)
-        if cutout_batch_mask is None:
-            return inputs # if the mask is None, then that's because the patch size was set to 0 and we will not be using cutout today.
-        # TODO: Could be fused with the crop operation for sheer speeeeeds. :D <3 :))))
-        cutout_batch = torch.where(cutout_batch_mask, torch.zeros_like(inputs), inputs)
-        return cutout_batch
+        batch_permuted = torch.randperm(inputs.shape[0], device='cuda')
+        cutmix_batch_mask = make_random_square_masks(inputs, patch_size)
+        if cutmix_batch_mask is None:
+            return inputs, targets # if the mask is None, then that's because the patch size was set to 0 and we will not be using cutmix today.
+        # We draw other samples from inside of the same batch
+        cutmix_batch = torch.where(cutmix_batch_mask, torch.index_select(inputs, 0, batch_permuted), inputs)
+        cutmix_targets = torch.index_select(targets, 0, batch_permuted)
+        # Get the percentage of each target to mix for the labels by the % proportion of pixels in the mix
+        portion_mixed = float(patch_size**2)/(inputs.shape[-2]*inputs.shape[-1])
+        cutmix_labels = portion_mixed * cutmix_targets + (1. - portion_mixed) * targets
+        return cutmix_batch, cutmix_labels

 def batch_crop(inputs, crop_size):
     with torch.no_grad():
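# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): the label arithmetic batch_cutmix
# uses above. The pasted square covers patch_size**2 of the H*W pixels, and the
# one-hot labels are mixed with exactly that pixel fraction. Toy tensors below;
# with cutmix_size=3 on 32x32 CIFAR-10 images the mix is only ~0.9%.
import torch

patch_size, H, W = 3, 32, 32
portion_mixed = float(patch_size**2) / (H * W)    # 9/1024 ~= 0.0088

targets = torch.eye(10)[torch.tensor([3, 7])]     # two one-hot labels (classes 3 and 7)
permuted = targets.flip(0)                        # stand-in for index_select with a randperm
mixed_labels = portion_mixed * permuted + (1. - portion_mixed) * targets
print(mixed_labels[0, 3].item(), mixed_labels[0, 7].item())   # ~0.991 and ~0.009
# ---------------------------------------------------------------------------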
@@ -399,18 +427,18 @@ def batch_flip_lr(batch_images, flip_chance=.5):
 ########################################

 class NetworkEMA(nn.Module):
-    def __init__(self, net, decay):
+    def __init__(self, net):
         super().__init__() # init the parent module so this module is registered properly
         self.net_ema = copy.deepcopy(net).eval().requires_grad_(False) # copy the model
-        self.decay = decay ## you can update/hack this as necessary for update scheduling purposes :3

-    def update(self, current_net):
+    def update(self, current_net, decay):
         with torch.no_grad():
             for ema_net_parameter, (parameter_name, incoming_net_parameter) in zip(self.net_ema.state_dict().values(), current_net.state_dict().items()): # potential bug: assumes that the network architectures don't change during training (!!!!)
                 if incoming_net_parameter.dtype in (torch.half, torch.float):
-                    ema_net_parameter.mul_(self.decay).add_(incoming_net_parameter.detach().mul(1. - self.decay)) # update the ema values in place, similar to how optimizer momentum is coded
-                    if not 'running' in parameter_name:
-                        incoming_net_parameter = ema_net_parameter.detach()
+                    ema_net_parameter.mul_(decay).add_(incoming_net_parameter.detach().mul(1. - decay)) # update the ema values in place, similar to how optimizer momentum is coded
+                    # And then we also copy the parameters back to the network, similarly to the Lookahead optimizer (but with a much more aggressive-at-the-end schedule)
+                    if not ('norm' in parameter_name and 'weight' in parameter_name) and not 'whiten' in parameter_name:
+                        incoming_net_parameter.copy_(ema_net_parameter.detach())

     def forward(self, inputs):
         with torch.no_grad():
@@ -418,9 +446,12 @@ def forward(self, inputs):

 # TODO: Could we jit this in the (more distant) future? :)
 @torch.no_grad()
-def get_batches(data_dict, key, batchsize):
+def get_batches(data_dict, key, batchsize, epoch_fraction=1., cutmix_size=None):
     num_epoch_examples = len(data_dict[key]['images'])
     shuffled = torch.randperm(num_epoch_examples, device='cuda')
+    if epoch_fraction < 1:
+        shuffled = shuffled[:batchsize * round(epoch_fraction * shuffled.shape[0]/batchsize)] # TODO: Might be slightly inaccurate, let's fix this later... :) :D :confetti: :fireworks:
+        num_epoch_examples = shuffled.shape[0]
     crop_size = 32
     ## Here, we prep the dataset by applying all data augmentations in batches ahead of time before each epoch, then we return an iterator below
     ## that iterates in chunks over with a random derangement (i.e. shuffled indices) of the individual examples. So we get perfectly-shuffled
@@ -428,21 +459,22 @@ def get_batches(data_dict, key, batchsize):
     if key == 'train':
         images = batch_crop(data_dict[key]['images'], crop_size) # TODO: hardcoded image size for now?
         images = batch_flip_lr(images)
-        images = batch_cutout(images, patch_size=hyp['net']['cutout_size'])
+        images, targets = batch_cutmix(images, data_dict[key]['targets'], patch_size=cutmix_size)
     else:
         images = data_dict[key]['images']
+        targets = data_dict[key]['targets']

     # Send the images to an (in beta) channels_last to help improve tensor core occupancy (and reduce NCHW <-> NHWC thrash) during training
     images = images.to(memory_format=torch.channels_last)
     for idx in range(num_epoch_examples // batchsize):
         if not (idx+1)*batchsize > num_epoch_examples: ## Use the shuffled randperm to assemble individual items into a minibatch
             yield images.index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]), \
-                  data_dict[key]['targets'].index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]) ## Each item is only used/accessed by the network once per epoch. :D
+                  targets.index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]) ## Each item is only used/accessed by the network once per epoch. :D

 def init_split_parameter_dictionaries(network):
-    params_non_bias = {'params': [], 'lr': hyp['opt']['non_bias_lr'], 'momentum': .85, 'nesterov': True, 'weight_decay': hyp['opt']['non_bias_decay']}
-    params_bias = {'params': [], 'lr': hyp['opt']['bias_lr'], 'momentum': .85, 'nesterov': True, 'weight_decay': hyp['opt']['bias_decay']}
+    params_non_bias = {'params': [], 'lr': hyp['opt']['non_bias_lr'], 'momentum': .85, 'nesterov': True, 'weight_decay': hyp['opt']['non_bias_decay'], 'foreach': True}
+    params_bias = {'params': [], 'lr': hyp['opt']['bias_lr'], 'momentum': .85, 'nesterov': True, 'weight_decay': hyp['opt']['bias_decay'], 'foreach': True}

     for name, p in network.named_parameters():
         if p.requires_grad:
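# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): how the epoch_fraction logic added
# to get_batches above sizes the final, partial epoch. With train_epochs=12.6
# the last pass uses epoch_fraction ~= 0.6, and the shuffled index tensor is
# truncated to a batchsize-aligned prefix. CPU stand-in (no device='cuda').
import torch

batchsize, num_epoch_examples, epoch_fraction = 1024, 50000, 12.6 % 1
shuffled = torch.randperm(num_epoch_examples)
if epoch_fraction < 1:
    shuffled = shuffled[:batchsize * round(epoch_fraction * shuffled.shape[0] / batchsize)]
print(round(epoch_fraction, 3), shuffled.shape[0], shuffled.shape[0] // batchsize)  # 0.6 29696 29
# ---------------------------------------------------------------------------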
@@ -491,10 +523,8 @@ def main():

     # TODO: Doesn't currently account for partial epochs really (since we're not doing "real" epochs across the whole batchsize)....
     num_steps_per_epoch = len(data['train']['images']) // batchsize
-    total_train_steps = num_steps_per_epoch * hyp['misc']['train_epochs']
-    ema_epoch_start = hyp['misc']['train_epochs'] - hyp['misc']['ema']['epochs']
-    num_cooldown_before_freeze_steps = 0
-    num_low_lr_steps_for_ema = hyp['misc']['ema']['epochs'] * num_steps_per_epoch
+    total_train_steps = math.ceil(num_steps_per_epoch * hyp['misc']['train_epochs'])
+    ema_epoch_start = math.floor(hyp['misc']['train_epochs']) - hyp['misc']['ema']['epochs']

     ## I believe this wasn't logged, but the EMA update power is adjusted by being raised to the power of the number of "every n" steps
     ## to somewhat accomodate for whatever the expected information intake rate is. The tradeoff I believe, though, is that this is to some degree noisier as we
@@ -517,8 +547,9 @@ def main():
     ## Not the most intuitive, but this basically takes us from ~0 to max_lr at the point pct_start, then down to .1 * max_lr at the end (since 1e16 * 1e-15 = .1 --
     ## This quirk is because the final lr value is calculated from the starting lr value and not from the maximum lr value set during training)
     initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D
-    lr_sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=non_bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*1e-24), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False)
-    lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*1e-24), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False)
+    final_lr_ratio = .05 # Actually pretty important, apparently!
+    lr_sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=non_bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False)
+    lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False)

     ## For accurately timing GPU code
     starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
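# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): what final_lr_ratio does to the
# OneCycleLR schedules above. OneCycleLR documents initial_lr = max_lr/div_factor
# and min_lr = initial_lr/final_div_factor, so passing
# final_div_factor = 1/(initial_div_factor*final_lr_ratio) makes the schedule
# end at final_lr_ratio * max_lr. max_lr=0.1 below is an arbitrary stand-in.
max_lr, initial_div_factor, final_lr_ratio = 0.1, 1e16, .05
initial_lr = max_lr / initial_div_factor                      # ~1e-17, i.e. effectively zero
final_div_factor = 1. / (initial_div_factor * final_lr_ratio)
min_lr = initial_lr / final_div_factor
print(min_lr)                                                 # ~0.005 == final_lr_ratio * max_lr
# ---------------------------------------------------------------------------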
@@ -526,7 +557,7 @@ def main():

     if True: ## Sometimes we need a conditional/for loop here, this is placed to save the trouble of needing to indent
-        for epoch in range(hyp['misc']['train_epochs']):
+        for epoch in range(math.ceil(hyp['misc']['train_epochs'])):
             #################
             # Training Mode #
             #################
@@ -537,10 +568,13 @@ def main():
             loss_train = None
             accuracy_train = None

-            for epoch_step, (inputs, targets) in enumerate(get_batches(data, key='train', batchsize=batchsize)):
+            cutmix_size = hyp['net']['cutmix_size'] if epoch >= hyp['misc']['train_epochs'] - hyp['net']['cutmix_epochs'] else 0
+            epoch_fraction = 1 if epoch + 1 < hyp['misc']['train_epochs'] else hyp['misc']['train_epochs'] % 1 # We need to know if we're running a partial epoch or not.
+
+            for epoch_step, (inputs, targets) in enumerate(get_batches(data, key='train', batchsize=batchsize, epoch_fraction=epoch_fraction, cutmix_size=cutmix_size)):
                 ## Run everything through the network
                 outputs = net(inputs)
-
+                loss_batchsize_scaler = 512/batchsize # to scale to keep things at a relatively similar amount of regularization when we change our batchsize since we're summing over the whole batch
                 ## If you want to add other losses or hack around with the loss, you can do that here.
                 loss = loss_fn(outputs, targets).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) ## Note, as noted in the original blog posts, the summing here does a kind of loss scaling
@@ -548,7 +582,7 @@ def main():

                 # we only take the last-saved accs and losses from train
                 if epoch_step % 50 == 0:
-                    train_acc = (outputs.detach().argmax(-1) == targets).float().mean().item()
+                    train_acc = (outputs.detach().argmax(-1) == targets.argmax(-1)).float().mean().item()
                     train_loss = loss.detach().cpu().item()/(batchsize*loss_batchsize_scaler)

                 loss.backward()
@@ -568,10 +602,12 @@ def main():

                 if epoch >= ema_epoch_start and current_steps % hyp['misc']['ema']['every_n_steps'] == 0:
                     ## Initialize the ema from the network at this point in time if it does not already exist.... :D
-                    if net_ema is None or epoch_step < num_cooldown_before_freeze_steps: # don't snapshot the network yet if so!
-                        net_ema = NetworkEMA(net, decay=projected_ema_decay_val)
+                    if net_ema is None: # don't snapshot the network yet if so!
+                        net_ema = NetworkEMA(net)
                         continue
-                    net_ema.update(net)
+                    # We warm up our ema's decay/momentum value over training exponentially according to the hyp config dictionary (this lets us move fast, then average strongly at the end).
+                    net_ema.update(net, decay=projected_ema_decay_val*(current_steps/total_train_steps)**hyp['misc']['ema']['decay_pow'])
+
             ender.record()
             torch.cuda.synchronize()
             total_time_seconds += 1e-3 * starter.elapsed_time(ender)
@@ -581,7 +617,7 @@ def main():
             ####################

             net.eval()
-            eval_batchsize = 1000
+            eval_batchsize = 2500
             assert data['eval']['images'].shape[0] % eval_batchsize == 0, "Error: The eval batchsize must evenly divide the eval dataset (for now, we don't have drop_remainder implemented yet)."
             loss_list_val, acc_list, acc_list_ema = [], [], []
@@ -589,10 +625,10 @@ def main():
             for inputs, targets in get_batches(data, key='eval', batchsize=eval_batchsize):
                 if epoch >= ema_epoch_start:
                     outputs = net_ema(inputs)
-                    acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
+                    acc_list_ema.append((outputs.argmax(-1) == targets.argmax(-1)).float().mean())
                 outputs = net(inputs)
                 loss_list_val.append(loss_fn(outputs, targets).float().mean())
-                acc_list.append((outputs.argmax(-1) == targets).float().mean())
+                acc_list.append((outputs.argmax(-1) == targets.argmax(-1)).float().mean())

             val_acc = torch.stack(acc_list).mean().item()
             ema_val_acc = None
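# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): the EMA decay warm-up used above,
# assuming projected_ema_decay_val = decay_base ** every_n_steps as the earlier
# comment in the file describes (that assignment sits outside this hunk). The
# (step/total)**decay_pow factor keeps the decay near zero early on (the EMA
# hugs the live network) and lets it rise to ~0.77 by the end (strong averaging).
decay_base, every_n_steps, decay_pow = .95, 5, 3
projected_ema_decay_val = decay_base ** every_n_steps   # ~0.774
for frac in (0.25, 0.5, 0.8, 1.0):
    print(frac, round(projected_ema_decay_val * frac**decay_pow, 4))
# -> 0.25 0.0121 / 0.5 0.0967 / 0.8 0.3962 / 1.0 0.7738
# ---------------------------------------------------------------------------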
@@ -611,10 +647,8 @@ def main():
             # Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....)
             ## We also check to see if we're in our final epoch so we can print the 'bottom' of the table for each round.
-            print_training_details(list(map(partial(format_for_table, locals=locals()), logging_columns_list)), is_final_entry=(epoch == hyp['misc']['train_epochs'] - 1))
-    return val_acc # Return the final non-ema accuracy achieved (not using the 'best accuracy' selection strategy, which I think is okay here....)
-    # Note: For longer runs with much larer models, you may want to switch to the 'val_ema_acc' metric. This is because
-    # that metric does much better outside of these extremely rapid training runs.
+            print_training_details(list(map(partial(format_for_table, locals=locals()), logging_columns_list)), is_final_entry=(epoch >= math.ceil(hyp['misc']['train_epochs'] - 1)))
+    return ema_val_acc # Return the final ema accuracy achieved (not using the 'best accuracy' selection strategy, which I think is okay here....)

 if __name__ == "__main__":
     acc_list = []
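# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): the final-epoch bookkeeping with a
# fractional train_epochs. With train_epochs=12.6 the training loop above runs
# for epoch = 0..12 (13 passes, the last one partial), and the new
# is_final_entry condition fires exactly on that last pass.
import math

train_epochs = 12.6
epochs = list(range(math.ceil(train_epochs)))                            # [0, 1, ..., 12]
print(len(epochs))                                                       # 13
print([epoch >= math.ceil(train_epochs - 1) for epoch in epochs][-2:])   # [False, True]
# ---------------------------------------------------------------------------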