Committing changes -- memory format and max pooling for a change from ~18.1 to ~12.31-12.38s, a new world record! 🎆 🎆 🐧 🎆
tysam-code committed Jan 15, 2023
1 parent 4923fdd commit 01603a8
Showing 2 changed files with 44 additions and 65 deletions.
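The two changes named in the commit message are (1) moving the training images and the network to the channels_last memory format so the half-precision convolutions map better onto tensor cores, and (2) collapsing the nn.AdaptiveMaxPool2d((1, 1)) + nn.Flatten() head into a single torch.amax reduction. A minimal sketch of both ideas, separate from the diff below (the tensor shapes are illustrative, not taken from the repository):

import torch
import torch.nn as nn

# An illustrative (N, C, H, W) activation, e.g. the output of the last residual group.
x = torch.randn(512, 64, 8, 8)

# (1) channels_last keeps the same logical NCHW shape but strides the data as NHWC,
#     the layout the tensor-core convolution kernels prefer (and it avoids NCHW <-> NHWC thrash).
x_cl = x.to(memory_format=torch.channels_last)
assert x_cl.shape == x.shape and x_cl.is_contiguous(memory_format=torch.channels_last)

# (2) Global max pooling as one spatial reduction instead of AdaptiveMaxPool2d + Flatten.
pooled_old = nn.Flatten()(nn.AdaptiveMaxPool2d((1, 1))(x))  # shape (512, 64)
pooled_new = torch.amax(x, dim=(2, 3))                      # shape (512, 64)
assert torch.equal(pooled_old, pooled_new)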
106 changes: 44 additions & 62 deletions main.py
@@ -41,10 +41,10 @@
bias_scaler = 64
hyp = {
'opt': {
'bias_lr': 1. * bias_scaler/batchsize, # TODO: How we're expressing this information feels somewhat clunky -- is there maybe a better way to do this? :'))))
'non_bias_lr': 1. / batchsize,
'bias_decay': 5e-4 * batchsize/bias_scaler,
'non_bias_decay': 5e-4 * batchsize,
'bias_lr': 1.35 * 1. * bias_scaler/batchsize, # TODO: How we're expressing this information feels somewhat clunky -- is there maybe a better way to do this? :'))))
'non_bias_lr': 1.35 * 1. / batchsize,
'bias_decay': 4.8e-4 * batchsize/bias_scaler,
'non_bias_decay': 4.8e-4 * batchsize,
'scaling_factor': 1./16,
'percent_start': .2,
},
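# One way to read the scaling above: the product lr * decay comes out to 1.35 * 4.8e-4 for both
# parameter groups (the batchsize and bias_scaler factors cancel), so bias_scaler effectively only
# boosts how strongly the biases follow their gradients, not how strongly they get decayed.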
@@ -53,15 +53,15 @@
'kernel_size': 3,
'num_examples': 10000,
},
'ghost_norm_group_size': 64, ## Regularization
'batch_norm_momentum': .4,
'cutout_size': 0,
'pad_amount': 4,
},
'misc': {
'ema': {
'epochs': 2,
'decay_base': .99,
'every_n_steps': 5,
'epochs': 3,
'decay_base': .987,
'every_n_steps': 2,
},
'train_epochs': 10,
'device': 'cuda',
@@ -124,6 +124,7 @@ def batch_normalize_images(input_images, mean, std):
## hyp dictionary, then we should be good. :)
data = torch.load(hyp['misc']['data_location'])


## As you'll note above and below, one difference is that we don't count loading the raw data to GPU since it's such a variable operation, and can sort of get in the way
## of measuring other things. That said, measuring the preprocessing (outside of the padding) is still important to us.

@@ -136,48 +137,22 @@ def batch_normalize_images(input_images, mean, std):
# Network Components #
#############################################

# We might be able to fuse this weight and save some memory/runtime/etc, since the fast version of the network doesn't need it, I think...
# We might be able to fuse this weight and save some memory/runtime/etc, since the fast version of the network might be able to do without it somehow....
class BatchNorm(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-05, momentum=0.1, weight=False, bias=True):
def __init__(self, num_features, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'], weight=False, bias=True):
super().__init__(num_features, eps=eps, momentum=momentum)
self.weight.data.fill_(1.0)
self.bias.data.fill_(0.0)
self.weight.requires_grad = weight
self.bias.requires_grad = bias

class GhostNorm(BatchNorm):
def __init__(self, num_features, num_splits=batchsize//hyp['net']['ghost_norm_group_size'], **kw):
super().__init__(num_features, **kw)
self.num_splits = num_splits
self.register_buffer('running_mean', torch.zeros(num_features*self.num_splits))
self.register_buffer('running_var', torch.ones(num_features*self.num_splits))

def train(self, mode=True):
if (self.training is True) and (mode is False): #lazily collate stats when we are going to use them, i.e., when we switch from the 'train' to 'eval' modes
self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0).repeat(self.num_splits)
self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0).repeat(self.num_splits)
return super().train(mode)

def forward(self, input):
N, C, H, W = input.shape
if self.training or not self.track_running_stats:
return torch.nn.functional.batch_norm(
input.view(-1, C*self.num_splits, H, W), self.running_mean, self.running_var,
self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
True, self.momentum, self.eps).view(N, C, H, W)
else:
return torch.nn.functional.batch_norm(
input, self.running_mean[:self.num_features], self.running_var[:self.num_features],
self.weight, self.bias, False, self.momentum, self.eps)

# Allows us to set default arguments for the whole convolution itself.
class Conv(nn.Conv2d):
def __init__(self, *args, **kwargs):
kwargs = {**default_conv_kwargs, **kwargs}
super().__init__(*args, **kwargs)
self.kwargs = kwargs
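# For example, if default_conv_kwargs were {'kernel_size': 3, 'padding': 'same', 'bias': False}
# (illustrative values -- the real defaults are defined elsewhere in this file), then Conv(32, 64)
# would pick all of them up, while Conv(32, 64, kernel_size=1) would override just the kernel size.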


# can hack any changes to each residual group that you want directly in here
class ConvGroup(nn.Module):
def __init__(self, channels_in, channels_out, residual, short, pool):
@@ -191,16 +166,15 @@ def __init__(self, channels_in, channels_out, residual, short, pool):

self.conv1 = Conv(channels_in, channels_out)
self.pool1 = nn.MaxPool2d(2)
self.norm1 = GhostNorm(channels_out)
self.norm1 = BatchNorm(channels_out)
self.activ = nn.CELU(alpha=.3)

# note: this has to be flat if we're jitting things.... we just might burn a bit of extra GPU mem if so
if not short:
self.conv2 = Conv(channels_out, channels_out)
self.conv3 = Conv(channels_out, channels_out)
self.norm2 = GhostNorm(channels_out)
self.norm3 = GhostNorm(channels_out)

self.norm2 = BatchNorm(channels_out)
self.norm3 = BatchNorm(channels_out)

def forward(self, x):
x = self.conv1(x)
@@ -234,6 +208,15 @@ def forward(self, x):
## my implementation, and David's implementation...
return x.mul(self.scaler)

class FastGlobalMaxPooling(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
# Previously this was chained torch.max calls.
# Requires less time than nn.AdaptiveMaxPool2d -- about ~.3s over the entire run, in fact (which is pretty significant! :O :D :O :O <3 <3 <3 <3)
return torch.amax(x, dim=(2,3)) # Global maximum pooling
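# Note that amax over dims (2, 3) already returns a flat (N, C) tensor, which is why the separate
# 'reshape'/nn.Flatten() stage can be dropped from the network dict further down.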

#############################################
# Init Helper Functions #
#############################################
@@ -288,7 +271,7 @@ def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2):
class SpeedyResNet(nn.Module):
def __init__(self, network_dict):
super().__init__()
self.net_dict = network_dict # flexible, defined in the make_network function
self.net_dict = network_dict # flexible, defined in the make_net function

# This allows you to customize/change the execution order of the network as needed.
def forward(self, x):
@@ -302,7 +285,6 @@ def forward(self, x):
x = self.net_dict['residual2'](x)
x = self.net_dict['residual3'](x)
x = self.net_dict['pooling'](x)
x = self.net_dict['reshape'](x)
x = self.net_dict['linear'](x)
x = self.net_dict['temperature'](x)
if not self.training:
@@ -319,26 +301,26 @@ def make_net():
'initial_block': nn.ModuleDict({
'whiten': Conv(3, whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size']),
'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1),
'norm': GhostNorm(depths['init'], weight=False),
'norm': BatchNorm(depths['init'], weight=False),
'activation': nn.CELU(alpha=.3),
}),
'residual1': ConvGroup(depths['init'], depths['block1'], residual=True, short=False, pool=True),
'residual2': ConvGroup(depths['block1'], depths['block2'], residual=True, short=True, pool=True),
'residual3': ConvGroup(depths['block2'], depths['block3'], residual=True, short=False, pool=True),
'pooling': nn.AdaptiveMaxPool2d((1, 1)),
'reshape': nn.Flatten(),
'pooling': FastGlobalMaxPooling(),
'linear': nn.Linear(depths['block3'], depths['num_classes'], bias=False),
'temperature': TemperatureScaler(hyp['opt']['scaling_factor'])
})

net = SpeedyResNet(network_dict)
net = net.to(hyp['misc']['device'])
net = net.to(memory_format=torch.channels_last) # to appropriately use tensor cores/avoid thrash while training
net.train()
net.half() # Convert network to half before initializing the initial whitening layer.

## Initialize the whitening convolution
with torch.no_grad():
# Initialize the first layer to be fixed weights that whiten the expected input values of the network so that they lie on the unit hypersphere. (i.e. their vector length is 1., IIRC)
# Initialize the first layer to be fixed weights that whiten the expected input values of the network so that they lie on the unit hypersphere. (i.e. their...average vector length is 1.?, IIRC)
init_whitening_conv(net.net_dict['initial_block']['whiten'],
data['train']['images'].index_select(0, torch.randperm(data['train']['images'].shape[0], device=data['train']['images'].device)),
num_examples=hyp['net']['whitening']['num_examples'],
Expand Down Expand Up @@ -392,7 +374,7 @@ def batch_crop(inputs, crop_size):

def batch_flip_lr(batch_images, flip_chance=.5):
with torch.no_grad():
# TODO: More elegant way to do this? :') :'((((
# TODO: Is there a more elegant way to do this? :') :'((((
return torch.where(torch.rand_like(batch_images[:, 0, 0, 0].view(-1, 1, 1, 1)) < flip_chance, torch.flip(batch_images, (-1,)), batch_images)


@@ -416,7 +398,8 @@ def forward(self, inputs):
with torch.no_grad():
return self.net_ema(inputs)

# TODO: Can we jit this in the (more distant) future? :)
# TODO: Could we jit this in the (more distant) future? :)
@torch.no_grad()
def get_batches(data_dict, key, batchsize):
num_epoch_examples = len(data_dict[key]['images'])
shuffled = torch.randperm(num_epoch_examples, device='cuda')
@@ -431,6 +414,8 @@ def get_batches(data_dict, key, batchsize):
else:
images = data_dict[key]['images']

# Convert the images to the (still in beta) channels_last memory format to help improve tensor core occupancy (and reduce NCHW <-> NHWC thrash) during training
images = images.to(memory_format=torch.channels_last)
for idx in range(num_epoch_examples // batchsize):
if not (idx+1)*batchsize > num_epoch_examples: ## Use the shuffled randperm to assemble individual items into a minibatch
yield images.index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]), \
@@ -472,13 +457,12 @@ def print_training_details(columns_list, separator_left='| ', separator_right='
if is_final_entry:
print('-'*(len(print_string))) # print the final output bar

print_training_details(logging_columns_list, column_heads_only=True) # print out the training column heads.
print_training_details(logging_columns_list, column_heads_only=True) ## print out the training column heads before we print the actual content for each run.

########################################
# Train and Eval #
########################################

# to do cast to fp16 precision for training
def main():
# Initializing constants for the whole run.
net_ema = None ## Reset any existing network emas, we want to have _something_ to check for existence so we can initialize the EMA right from where the network is during training
@@ -488,13 +472,14 @@ def main():
current_steps = 0.

# TODO: Doesn't currently account for partial epochs really (since we're not doing "real" epochs across the whole batchsize)....
num_steps_per_epoch = len(data['train']['images']) // batchsize # todo: a bit of a tad of cleanup here. ::::))) :>>>
num_steps_per_epoch = len(data['train']['images']) // batchsize
total_train_steps = num_steps_per_epoch * hyp['misc']['train_epochs']
ema_epoch_start = hyp['misc']['train_epochs'] - hyp['misc']['ema']['epochs']
num_low_lr_steps_for_ema = hyp['misc']['ema']['epochs'] * num_steps_per_epoch
## TODO: (check? <# :)))) ) I believe this wasn't logged, but the EMA update power is adjusted by being raised to the power of the number of "every n" steps

## I believe this wasn't logged, but the EMA update power is adjusted by being raised to the power of the number of "every n" steps
## to somewhat accommodate whatever the expected information intake rate is. The tradeoff, I believe, is that this is to some degree noisier as we
## are intaking fewer samples of our distribution-over-time, with a higher individual weight each.
## are intaking fewer samples of our distribution-over-time, with a higher individual weight each. This can be good or bad depending upon what we want.
projected_ema_decay_val = hyp['misc']['ema']['decay_base'] ** hyp['misc']['ema']['every_n_steps']
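## For example, with the settings above (decay_base = .987, every_n_steps = 2), this works out to
## .987 ** 2 ≈ .974 -- each every-2-steps EMA update weights the incoming network snapshot by ~2.6%
## instead of ~1.3%, roughly matching the smoothing of an every-step EMA at .987.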

# Adjust pct_start based upon how many epochs we need to finetune the ema at a low lr for
@@ -540,7 +525,7 @@ def main():
for epoch_step, (inputs, targets) in enumerate(get_batches(data, key='train', batchsize=batchsize)):
## Run everything through the network
outputs = net(inputs)


## If you want to add other losses or hack around with the loss, you can do that here.
loss = loss_fn(outputs, targets).sum() ## Note, as noted in the original blog posts, the summing here does a kind of loss scaling
@@ -561,19 +546,17 @@ def main():
# We only want to step the lr_schedulers while we have training steps to consume. Otherwise we get a not-so-friendly error from PyTorch
lr_sched.step()
lr_sched_bias.step()

## Using 'set_to_none' is, I believe, slightly faster (albeit riskier w/ funky gradient update workflows) than the default 'set to zero' method
opt.zero_grad(set_to_none=True)
opt_bias.zero_grad(set_to_none=True)

current_steps += 1

if epoch >= ema_epoch_start and current_steps % hyp['misc']['ema']['every_n_steps'] == 0:
## Initialize the ema from the network at this point in time if it does not already exist.... :D
if net_ema is None:
net_ema = NetworkEMA(net, decay=projected_ema_decay_val)
net_ema.update(net)

ender.record()
torch.cuda.synchronize()
total_time_seconds += 1e-3 * starter.elapsed_time(ender)
@@ -588,11 +571,10 @@ def main():
loss_list_val, acc_list, acc_list_ema = [], [], []

with torch.no_grad():
# TODO: Copy is probably slow, we can def avoid this somehow, I think....
for inputs, targets in get_batches(data, key='eval', batchsize=eval_batchsize):
if epoch >= ema_epoch_start:
outputs = net_ema(inputs)
acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
outputs = net_ema(inputs)
acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
outputs = net(inputs)
loss_list_val.append(loss_fn(outputs, targets).float().mean())
acc_list.append((outputs.argmax(-1) == targets).float().mean())
@@ -610,12 +592,12 @@ def main():
format_for_table = lambda x, locals: (f"{locals[x]}".rjust(len(x))) \
if type(locals[x]) == int else "{:0.4f}".format(locals[x]).rjust(len(x)) \
if locals[x] is not None \
else " "*len(x)
else " "*len(x)

# Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....)
## We also check to see if we're in our final epoch so we can print the 'bottom' of the table for each round.
print_training_details(list(map(partial(format_for_table, locals=locals()), logging_columns_list)), is_final_entry=(epoch == hyp['misc']['train_epochs'] - 1))

if __name__ == "__main__":
for run_num in range(5):
for run_num in range(25):
main()
3 changes: 0 additions & 3 deletions requirements.txt
@@ -1,5 +1,2 @@
torch
torchvision
numpy
ipython
rich
