diff --git a/S3-Training/README.md b/S3-Training/README.md
index 8ebfa6e0..8d06a5a7 100644
--- a/S3-Training/README.md
+++ b/S3-Training/README.md
@@ -1,11 +1,12 @@
-# S3: Sign-Sparse-Shift Reparametrization for Effective Training of Low-bit Shift Networks
+# DenseShift: Towards Accurate and Efficient Low-Bit Power-of-Two Quantization
-This repository is the DEMO code of the NeurIPS 2021 paper [S3: Sign-Sparse-Shift Reparametrization for Effective Training of Low-bit Shift Networks](https://proceedings.neurips.cc/paper/2021/file/7a1d9028a78f418cb8f01909a348d9b2-Paper.pdf).
+The DEMO code of the ICCV-2023 paper [DenseShift: Towards Accurate and Efficient Low-Bit Power-of-Two Quantization](https://openaccess.thecvf.com/content/ICCV2023/html/Li_DenseShift_Towards_Accurate_and_Efficient_Low-Bit_Power-of-Two_Quantization_ICCV_2023_paper.html).
-Shift neural networks (Power-of-Two quantization) reduce computation complexity by removing expensive multiplication operations and quantizing continuous weights into low-bit discrete values, which are fast and energy-efficient compared to conventional neural networks. However, existing shift networks are sensitive to the weight initialization and yield a degraded performance caused by vanishing gradient and weight sign freezing problem. To address these issues, we propose S3 re-parameterization, a novel technique for training low-bit shift networks. Our method decomposes a discrete parameter in a sign-sparse-shift 3-fold manner. This way, it efficiently learns a low-bit network with weight dynamics similar to full-precision networks and insensitive to weight initialization. Our proposed training method pushes the boundaries of shift neural networks and shows 3-bit shift networks compete with their full-precision counterparts in terms of top-1 accuracy on ImageNet.
+## Abstract
+Efficiently deploying deep neural networks on low-resource edge devices is challenging due to their ever-increasing resource requirements. To address this issue, researchers have proposed multiplication-free neural networks, such as Power-of-Two quantization, or also known as Shift networks, which aim to reduce memory usage and simplify computation. However, existing low-bit Shift networks are not as accurate as their full-precision counterparts, typically suffering from limited weight range encoding schemes and quantization loss. In this paper, we propose the DenseShift network, which significantly improves the accuracy of Shift networks, achieving competitive performance to full-precision networks for vision and speech applications. In addition, we introduce a method to deploy an efficient DenseShift network using non-quantized floating-point activations, while obtaining 1.6X speed-up over existing methods. To achieve this, we demonstrate that zero-weight values in low-bit Shift networks do not contribute to model capacity and negatively impact inference computation. To address this issue, we propose a zero-free shifting mechanism that simplifies inference and increases model capacity. We further propose a sign-scale decomposition design to enhance training efficiency and a low-variance random initialization strategy to improve the model's transfer learning performance. Our extensive experiments on various computer vision and speech tasks demonstrate that DenseShift outperforms existing low-bit multiplication-free networks and achieves competitive performance compared to full-precision networks. Furthermore, our proposed approach exhibits strong transfer learning performance without a drop in accuracy.
-
+
## Requirements
@@ -18,206 +19,11 @@ Shift neural networks (Power-of-Two quantization) reduce computation complexity
- Download the ImageNet dataset from http://www.image-net.org/
- Then, and move validation images to labeled subfolders, using [the following shell script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh)
-## Training from scratch
+## Run
-### 3-bit Shift Network on ResNet-18 ImageNet
+### 3-bit DenseShift Network on ResNet-18 ImageNet
-To train a 3bit S3 re-parameterized shift network with ResNet18 on ImageNet from scratch, run:
+To train a 3bit DenseShift ResNet-18 on ImageNet from scratch, run:
```train
python main.py /path/to/imagenet
```
-
-## Pre-trained checkpoints
-
-Two pre-trained checkpoints corresponding to the 3-bit results reported in the paper can be [downloaded here](./pre-trained-ckpt-evaluation/). The checkpoints convert into a format compatible with the PyTorch official ImageNet training example so that the standard implementation code can evaluate the validation accuracy of the checkpoints.
-
-### 3-bit Shift Network on ResNet-18 ImageNet
-
-To evaluate the pre-trained checkpoint of 3bit S3 re-parameterized shift network with ResNet-18 on ImageNet, run:
-
-```eval
-python main_eval.py --evaluate --resume s3-3bit-resnet18-pytorch-imagenet.pth.tar --arch resnet18 /path/to/imagenet
-```
-
-Outputs:
-```example_output
-=> creating model 'resnet18'
-=> loading checkpoint 's3-3bit-resnet18-pytorch-imagenet.pth.tar'
-=> loaded checkpoint 's3-3bit-resnet18-pytorch-imagenet.pth.tar' (epoch 199)
-Test: [ 0/196] Time 4.506 ( 4.506) Loss 6.7311e-01 (6.7311e-01) Acc@1 82.81 ( 82.81) Acc@5 96.09 ( 96.09)
-Test: [ 10/196] Time 0.072 ( 0.986) Loss 1.2426e+00 (9.1206e-01) Acc@1 69.14 ( 77.73) Acc@5 88.67 ( 92.68)
-Test: [ 20/196] Time 2.218 ( 0.904) Loss 8.9948e-01 (9.1190e-01) Acc@1 81.64 ( 77.40) Acc@5 91.02 ( 92.47)
-Test: [ 30/196] Time 0.072 ( 0.809) Loss 8.3274e-01 (8.8309e-01) Acc@1 80.47 ( 77.95) Acc@5 94.92 ( 92.97)
-Test: [ 40/196] Time 2.417 ( 0.815) Loss 9.7517e-01 (9.2135e-01) Acc@1 75.78 ( 76.59) Acc@5 94.14 ( 93.16)
-Test: [ 50/196] Time 0.072 ( 0.775) Loss 6.4144e-01 (9.1274e-01) Acc@1 83.59 ( 76.49) Acc@5 96.88 ( 93.35)
-Test: [ 60/196] Time 2.033 ( 0.795) Loss 1.1415e+00 (9.1855e-01) Acc@1 70.31 ( 76.14) Acc@5 91.02 ( 93.45)
-Test: [ 70/196] Time 0.072 ( 0.791) Loss 9.0677e-01 (9.0656e-01) Acc@1 73.83 ( 76.42) Acc@5 94.14 ( 93.54)
-Test: [ 80/196] Time 0.848 ( 0.784) Loss 1.7171e+00 (9.2979e-01) Acc@1 57.03 ( 75.85) Acc@5 85.55 ( 93.26)
-Test: [ 90/196] Time 1.443 ( 0.785) Loss 2.2276e+00 (9.9383e-01) Acc@1 51.56 ( 74.53) Acc@5 75.78 ( 92.41)
-Test: [100/196] Time 1.244 ( 0.767) Loss 1.7705e+00 (1.0593e+00) Acc@1 55.47 ( 73.16) Acc@5 82.03 ( 91.58)
-Test: [110/196] Time 0.449 ( 0.771) Loss 1.2247e+00 (1.0864e+00) Acc@1 70.70 ( 72.59) Acc@5 89.84 ( 91.17)
-Test: [120/196] Time 0.074 ( 0.763) Loss 1.9402e+00 (1.1115e+00) Acc@1 55.86 ( 72.25) Acc@5 76.95 ( 90.76)
-Test: [130/196] Time 0.071 ( 0.774) Loss 1.0368e+00 (1.1486e+00) Acc@1 74.61 ( 71.40) Acc@5 92.97 ( 90.29)
-Test: [140/196] Time 0.072 ( 0.754) Loss 1.4686e+00 (1.1709e+00) Acc@1 65.23 ( 70.97) Acc@5 83.98 ( 90.00)
-Test: [150/196] Time 0.073 ( 0.763) Loss 1.4905e+00 (1.1954e+00) Acc@1 69.92 ( 70.51) Acc@5 85.16 ( 89.62)
-Test: [160/196] Time 0.073 ( 0.754) Loss 1.1636e+00 (1.2138e+00) Acc@1 73.44 ( 70.19) Acc@5 89.84 ( 89.35)
-Test: [170/196] Time 0.072 ( 0.755) Loss 7.5062e-01 (1.2348e+00) Acc@1 77.73 ( 69.75) Acc@5 96.48 ( 89.05)
-Test: [180/196] Time 0.073 ( 0.745) Loss 1.3958e+00 (1.2521e+00) Acc@1 64.06 ( 69.37) Acc@5 88.67 ( 88.84)
-Test: [190/196] Time 0.072 ( 0.748) Loss 1.2849e+00 (1.2503e+00) Acc@1 64.84 ( 69.33) Acc@5 92.58 ( 88.89)
- * Acc@1 69.508 Acc@5 88.968
-```
-
-The elements of following weight tensors in the checkpoint are restricted to the discrete weight values of 3-bit shift network {-4, -2, -1, 0, 1, 2, 4}
-
- Quantized tensor name in ResNet-18 checkpoint
-module.layer1.0.conv1.weight
-module.layer1.0.conv2.weight
-module.layer1.1.conv1.weight
-module.layer1.1.conv2.weight
-module.layer2.0.conv1.weight
-module.layer2.0.conv2.weight
-module.layer2.0.downsample.0.weight
-module.layer2.1.conv1.weight
-module.layer2.1.conv2.weight
-module.layer3.0.conv1.weight
-module.layer3.0.conv2.weight
-module.layer3.0.downsample.0.weight
-module.layer3.1.conv1.weight
-module.layer3.1.conv2.weight
-module.layer4.0.conv1.weight
-module.layer4.0.conv2.weight
-module.layer4.0.downsample.0.weight
-module.layer4.1.conv1.weight
-module.layer4.1.conv2.weight
-
-
-The following code snippet can load a discrete weight tensor from the checkpoint and output the unique discrete values in this tensor.
-```eval
-import torch
-TENSOR_NAME = "module.layer1.0.conv1.weight"
-CKPT_NAME = "s3-3bit-resnet18-pytorch-imagenet.pth.tar"
-
-checkpoint = torch.load(CKPT_NAME)
-model_state_dict = checkpoint['state_dict']
-discrete_weight = model_state_dict[TENSOR_NAME]
-print(torch.unique(discrete_weight))
-```
-
-Outputs:
-```example_output
-tensor([-4., -2., -1., -0., 1., 2., 4.], device='cuda:0')
-```
-
-
-### 3-bit Shift Network on ResNet-50 ImageNet
-
-To evaluate the pre-trained checkpoint of 3bit S3 re-parameterized shift network with ResNet-50 on ImageNet, run:
-
-```eval
-python main_eval.py --evaluate --resume s3-3bit-resnet50-pytorch-imagenet.pth.tar --arch resnet50 /path/to/imagenet
-```
-
-Outputs:
-```example_output
-=> creating model 'resnet50'
-=> loading checkpoint 's3-3bit-resnet50-pytorch-imagenet.pth.tar'
-=> loaded checkpoint 's3-3bit-resnet50-pytorch-imagenet.pth.tar' (epoch 199)
-Test: [ 0/196] Time 4.976 ( 4.976) Loss 4.9636e-01 (4.9636e-01) Acc@1 86.33 ( 86.33) Acc@5 97.27 ( 97.27)
-Test: [ 10/196] Time 0.221 ( 0.972) Loss 1.0587e+00 (6.8706e-01) Acc@1 75.39 ( 82.07) Acc@5 92.19 ( 95.63)
-Test: [ 20/196] Time 1.160 ( 0.907) Loss 7.0471e-01 (6.8882e-01) Acc@1 86.33 ( 81.99) Acc@5 92.58 ( 95.48)
-Test: [ 30/196] Time 0.221 ( 0.873) Loss 8.0941e-01 (6.5377e-01) Acc@1 78.91 ( 83.09) Acc@5 94.92 ( 95.60)
-Test: [ 40/196] Time 2.344 ( 0.906) Loss 6.5837e-01 (6.8861e-01) Acc@1 82.03 ( 81.85) Acc@5 96.88 ( 95.61)
-Test: [ 50/196] Time 0.223 ( 0.829) Loss 4.6707e-01 (6.8241e-01) Acc@1 88.67 ( 81.78) Acc@5 96.88 ( 95.76)
-Test: [ 60/196] Time 1.323 ( 0.812) Loss 8.7407e-01 (6.9512e-01) Acc@1 74.22 ( 81.40) Acc@5 96.48 ( 95.87)
-Test: [ 70/196] Time 2.609 ( 0.832) Loss 7.4790e-01 (6.8027e-01) Acc@1 76.95 ( 81.63) Acc@5 96.88 ( 96.06)
-Test: [ 80/196] Time 0.221 ( 0.810) Loss 1.4313e+00 (7.0608e-01) Acc@1 65.23 ( 81.13) Acc@5 87.11 ( 95.75)
-Test: [ 90/196] Time 3.314 ( 0.842) Loss 1.8285e+00 (7.5399e-01) Acc@1 58.20 ( 80.08) Acc@5 85.94 ( 95.25)
-Test: [100/196] Time 0.219 ( 0.825) Loss 1.2244e+00 (8.0642e-01) Acc@1 66.80 ( 78.93) Acc@5 89.84 ( 94.59)
-Test: [110/196] Time 3.015 ( 0.847) Loss 8.3800e-01 (8.3314e-01) Acc@1 78.91 ( 78.41) Acc@5 94.92 ( 94.27)
-Test: [120/196] Time 0.219 ( 0.844) Loss 1.2821e+00 (8.4899e-01) Acc@1 71.48 ( 78.15) Acc@5 88.28 ( 94.02)
-Test: [130/196] Time 2.935 ( 0.857) Loss 6.7108e-01 (8.8153e-01) Acc@1 81.64 ( 77.40) Acc@5 95.31 ( 93.68)
-Test: [140/196] Time 0.222 ( 0.852) Loss 1.1377e+00 (8.9882e-01) Acc@1 72.27 ( 77.09) Acc@5 91.80 ( 93.49)
-Test: [150/196] Time 2.446 ( 0.858) Loss 1.1069e+00 (9.1730e-01) Acc@1 76.17 ( 76.76) Acc@5 90.62 ( 93.22)
-Test: [160/196] Time 0.220 ( 0.847) Loss 7.7915e-01 (9.3251e-01) Acc@1 83.20 ( 76.46) Acc@5 93.36 ( 93.00)
-Test: [170/196] Time 2.340 ( 0.852) Loss 5.5731e-01 (9.4940e-01) Acc@1 84.77 ( 76.01) Acc@5 96.88 ( 92.81)
-Test: [180/196] Time 0.221 ( 0.845) Loss 1.2214e+00 (9.6362e-01) Acc@1 67.97 ( 75.69) Acc@5 93.75 ( 92.70)
-Test: [190/196] Time 2.750 ( 0.848) Loss 1.1438e+00 (9.6272e-01) Acc@1 69.92 ( 75.63) Acc@5 94.53 ( 92.75)
- * Acc@1 75.748 Acc@5 92.800
-```
-
-The elements of following weight tensors in the checkpoint are restricted to the discrete weight values of 3-bit shift network {-4, -2, -1, 0, 1, 2, 4}
-
- Quantized tensor name in ResNet-50 checkpoint
-module.layer1.0.conv1.weight
-module.layer1.0.conv2.weight
-module.layer1.0.conv3.weight
-module.layer1.0.downsample.0.weight
-module.layer1.1.conv1.weight
-module.layer1.1.conv2.weight
-module.layer1.1.conv3.weight
-module.layer1.2.conv1.weight
-module.layer1.2.conv2.weight
-module.layer1.2.conv3.weight
-module.layer2.0.conv1.weight
-module.layer2.0.conv2.weight
-module.layer2.0.conv3.weight
-module.layer2.0.downsample.0.weight
-module.layer2.1.conv1.weight
-module.layer2.1.conv2.weight
-module.layer2.1.conv3.weight
-module.layer2.2.conv1.weight
-module.layer2.2.conv2.weight
-module.layer2.2.conv3.weight
-module.layer2.3.conv1.weight
-module.layer2.3.conv2.weight
-module.layer2.3.conv3.weight
-module.layer3.0.conv1.weight
-module.layer3.0.conv2.weight
-module.layer3.0.conv3.weight
-module.layer3.0.downsample.0.weight
-module.layer3.1.conv1.weight
-module.layer3.1.conv2.weight
-module.layer3.1.conv3.weight
-module.layer3.2.conv1.weight
-module.layer3.2.conv2.weight
-module.layer3.2.conv3.weight
-module.layer3.3.conv1.weight
-module.layer3.3.conv2.weight
-module.layer3.3.conv3.weight
-module.layer3.4.conv1.weight
-module.layer3.4.conv2.weight
-module.layer3.4.conv3.weight
-module.layer3.5.conv1.weight
-module.layer3.5.conv2.weight
-module.layer3.5.conv3.weight
-module.layer4.0.conv1.weight
-module.layer4.0.conv2.weight
-module.layer4.0.conv3.weight
-module.layer4.0.downsample.0.weight
-module.layer4.1.conv1.weight
-module.layer4.1.conv2.weight
-module.layer4.1.conv3.weight
-module.layer4.2.conv1.weight
-module.layer4.2.conv2.weight
-module.layer4.2.conv3.weight
-
-
-
-## Results
-
-Our model achieves the following performance on :
-
-### Image Classification on ImageNet
-
-#### Results in the paper
-
-
-
-
-#### Evaluation code output
-| Model name | Top 1 Accuracy | Top 5 Accuracy |
-| ------------------ |---------------- | -------------- |
-| 3-bit Shift ResNet-18 | 69.508% | 88.968% |
-| 3-bit Shift ResNet-50 | 75.748% | 92.800% |
-
-The minor accuracy difference (~0.3%) between Table 1 and the evaluation code output may cause by the difference between our implementation and the PyTorch official ImageNet training example.
\ No newline at end of file
diff --git a/S3-Training/dsconv2d.py b/S3-Training/dsconv2d.py
new file mode 100755
index 00000000..8de727f7
--- /dev/null
+++ b/S3-Training/dsconv2d.py
@@ -0,0 +1,73 @@
+'''
+Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.
+This program is free software; you can redistribute it and/or modify
+it under the terms of BSD 3-Clause License.
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+BSD 3-Clause License for more details.
+'''
+
+import torch
+import torch.nn as nn
+from torch.nn.parameter import Parameter
+import torch.nn.functional as F
+from torch.autograd import Function
+
+class STEBinarize01F(Function):
+ @staticmethod
+ def forward(ctx, inputs):
+ return (inputs.sign() - (inputs == 0).float() + 1) * 0.5
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output, None
+
+ste_binarize01 = STEBinarize01F.apply
+
+class STBinarizeF(Function):
+ @staticmethod
+ def forward(ctx, inputs):
+ return inputs.sign() + (inputs == 0).float()
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output, None
+
+ste_binarize = STBinarizeF.apply
+
+class DenseShiftConv2d3bit(nn.Conv2d):
+ """
+ 3bit DenseShift Conv2d module.
+ """
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
+ super(DenseShiftConv2d3bit, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
+ bias, padding_mode)
+
+ self.weight_sign = self.weight
+ self.weight_t1 = Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
+ self.weight_t2 = Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
+ self.weight_t3 = Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
+
+ def _conv_forward(self, input, weight):
+ if self.padding_mode != 'zeros':
+ return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+ weight, self.bias, self.stride,
+ _pair(0), self.dilation, self.groups)
+ return F.conv2d(input, weight, self.bias, self.stride,
+ self.padding, self.dilation, self.groups)
+
+ def forward(self, input):
+ shift_bits = ste_binarize01(self.weight_t1)
+ shift_bits = (shift_bits + 1) * ste_binarize01(self.weight_t2)
+ shift_bits = (shift_bits + 1) * ste_binarize01(self.weight_t3)
+ base = torch.ones_like(self.weight_sign) * 2
+
+ w_sign = ste_binarize(self.weight_sign)
+
+ bw_w_shift_3bit = w_sign * torch.sqrt(shift_bits + 1)
+ w_shift_3bit = bw_w_shift_3bit
+ with torch.no_grad():
+ fw_w_shift_3bit = w_sign * torch.pow(base, shift_bits)
+ w_shift_3bit += fw_w_shift_3bit - bw_w_shift_3bit
+
+ return self._conv_forward(input, w_shift_3bit)
\ No newline at end of file
diff --git a/S3-Training/figures/DenseShift3bit-Training.png b/S3-Training/figures/DenseShift3bit-Training.png
new file mode 100755
index 00000000..be61d22e
Binary files /dev/null and b/S3-Training/figures/DenseShift3bit-Training.png differ
diff --git a/S3-Training/main.py b/S3-Training/main.py
old mode 100644
new mode 100755
index e32c2406..e200f3aa
--- a/S3-Training/main.py
+++ b/S3-Training/main.py
@@ -1,323 +1,323 @@
-import argparse
-import os
-import random
-import shutil
-import time
-import warnings
-
-import torch
-import torch.nn as nn
-import torch.nn.parallel
-import torch.backends.cudnn as cudnn
-import torch.optim as optim
-import torch.utils.data
-import torch.utils.data.distributed
-import torchvision.transforms as transforms
-import torchvision.datasets as datasets
-
-from resnet import resnet
-from resnet import add_reg_sparse_to_loss
-
-parser = argparse.ArgumentParser(description='Sign-Sparse-Shift Reparameterization ImageNet Training')
-parser.add_argument('data', metavar='DIR',
- help='path to dataset')
-parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
- help='number of data loading workers (default: 4)')
-parser.add_argument('--epochs', default=200, type=int, metavar='N',
- help='number of total epochs to run')
-parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
- help='manual epoch number (useful on restarts)')
-parser.add_argument('-b', '--batch-size', default=256, type=int,
- metavar='N',
- help='mini-batch size (default: 256), this is the total '
- 'batch size of all GPUs on the current node when '
- 'using Data Parallel or Distributed Data Parallel')
-parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
- metavar='LR', help='initial learning rate', dest='lr')
-parser.add_argument('--rs', '--reg-sparse', default=1e-5, type=float,
- metavar='RS', help='dense weight regularizer', dest='rs')
-parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
- help='momentum')
-parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
- metavar='W', help='weight decay (default: 1e-4)',
- dest='weight_decay')
-parser.add_argument('-p', '--print-freq', default=100, type=int,
- metavar='N', help='print frequency (default: 100)')
-parser.add_argument('--resume', default='', type=str, metavar='PATH',
- help='path to latest checkpoint (default: none)')
-parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
- help='evaluate model on validation set')
-parser.add_argument('--pretrained', dest='pretrained', action='store_true',
- help='use pre-trained model')
-parser.add_argument('--seed', default=None, type=int,
- help='seed for initializing training. ')
-
-best_acc1 = 0
-
-
-def main():
- global best_acc1
- args = parser.parse_args()
- if args.seed is not None:
- random.seed(args.seed)
- torch.manual_seed(args.seed)
- cudnn.deterministic = True
- warnings.warn('You have chosen to seed training. '
- 'This will turn on the CUDNN deterministic setting, '
- 'which can slow down your training considerably! '
- 'You may see unexpected behavior when restarting '
- 'from checkpoints.')
- # create model
- if args.pretrained:
- print("=> using pre-trained model '{}'".format('resnet18'))
- model = resnet(num_classes=1000, depth=18, dataset='imagenet', pretrained=True)
- else:
- print("=> creating model '{}'".format('resnet18'))
- model = resnet(num_classes=1000, depth=18, dataset='imagenet', pretrained=False)
- model.init_model()
-
- if not torch.cuda.is_available():
- print('using CPU, this will be slow')
- else:
- model = torch.nn.DataParallel(model).cuda()
-
- # define loss function (criterion) and optimizer
- criterion = nn.CrossEntropyLoss().cuda()
-
- optimizer = torch.optim.SGD(model.parameters(), args.lr,
- momentum=args.momentum,
- weight_decay=args.weight_decay)
- scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
-
- # optionally resume from a checkpoint
- if args.resume:
- if os.path.isfile(args.resume):
- print("=> loading checkpoint '{}'".format(args.resume))
- checkpoint = torch.load(args.resume)
- args.start_epoch = checkpoint['epoch']
- best_acc1 = checkpoint['best_acc1']
- model.load_state_dict(checkpoint['state_dict'])
- optimizer.load_state_dict(checkpoint['optimizer'])
- print("=> loaded checkpoint '{}' (epoch {})"
- .format(args.resume, checkpoint['epoch']))
- else:
- print("=> no checkpoint found at '{}'".format(args.resume))
-
- cudnn.benchmark = True
-
- # Data loading code
- traindir = os.path.join(args.data, 'train')
- valdir = os.path.join(args.data, 'val')
- normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
- std=[0.229, 0.224, 0.225])
-
- train_dataset = datasets.ImageFolder(
- traindir,
- transforms.Compose([
- transforms.RandomResizedCrop(224),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- normalize,
- ]))
-
- train_loader = torch.utils.data.DataLoader(
- train_dataset, batch_size=args.batch_size, shuffle=True,
- num_workers=args.workers, pin_memory=True)
-
- val_loader = torch.utils.data.DataLoader(
- datasets.ImageFolder(valdir, transforms.Compose([
- transforms.Resize(256),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- normalize,
- ])),
- batch_size=args.batch_size, shuffle=False,
- num_workers=args.workers, pin_memory=True)
-
- if args.evaluate:
- validate(val_loader, model, criterion, args)
- return
-
- for epoch in range(args.start_epoch, args.epochs):
- # adjust_learning_rate(optimizer, epoch, args)
- scheduler.step()
-
- # train for one epoch
- train(train_loader, model, criterion, optimizer, epoch, args)
-
- # evaluate on validation set
- acc1 = validate(val_loader, model, criterion, args)
-
- # remember best acc@1 and save checkpoint
- is_best = acc1 > best_acc1
- best_acc1 = max(acc1, best_acc1)
-
- save_checkpoint({
- 'epoch': epoch + 1,
- 'arch': 'resnet18',
- 'state_dict': model.state_dict(),
- 'best_acc1': best_acc1,
- 'optimizer': optimizer.state_dict(),
- }, is_best)
-
-
-def train(train_loader, model, criterion, optimizer, epoch, args):
- batch_time = AverageMeter('Time', ':6.3f')
- data_time = AverageMeter('Data', ':6.3f')
- losses = AverageMeter('Loss', ':.4e')
- top1 = AverageMeter('Acc@1', ':6.2f')
- top5 = AverageMeter('Acc@5', ':6.2f')
- progress = ProgressMeter(
- len(train_loader),
- [batch_time, data_time, losses, top1, top5],
- prefix="Epoch: [{}]".format(epoch))
-
- # switch to train mode
- model.train()
-
- end = time.time()
- for i, (images, target) in enumerate(train_loader):
- # measure data loading time
- data_time.update(time.time() - end)
-
- if torch.cuda.is_available():
- images = images.cuda(non_blocking=True)
- target = target.cuda(non_blocking=True)
-
- # compute output
- output = model(images)
- loss = criterion(output, target)
-
- # measure accuracy and record loss
- acc1, acc5 = accuracy(output, target, topk=(1, 5))
- losses.update(loss.item(), images.size(0))
- top1.update(acc1[0], images.size(0))
- top5.update(acc5[0], images.size(0))
-
- # add regularizer after loss measurement and before backprop
- loss = add_reg_sparse_to_loss(model, loss, alpha=args.rs)
-
- # compute gradient and do SGD step
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- # measure elapsed time
- batch_time.update(time.time() - end)
- end = time.time()
-
- if i % args.print_freq == 0:
- progress.display(i)
-
-
-def validate(val_loader, model, criterion, args):
- batch_time = AverageMeter('Time', ':6.3f')
- losses = AverageMeter('Loss', ':.4e')
- top1 = AverageMeter('Acc@1', ':6.2f')
- top5 = AverageMeter('Acc@5', ':6.2f')
- progress = ProgressMeter(
- len(val_loader),
- [batch_time, losses, top1, top5],
- prefix='Test: ')
-
- # switch to evaluate mode
- model.eval()
-
- with torch.no_grad():
- end = time.time()
- for i, (images, target) in enumerate(val_loader):
-
- if torch.cuda.is_available():
- images = images.cuda(non_blocking=True)
- target = target.cuda(non_blocking=True)
-
- # compute output
- output = model(images)
- loss = criterion(output, target)
-
- # measure accuracy and record loss
- acc1, acc5 = accuracy(output, target, topk=(1, 5))
- losses.update(loss.item(), images.size(0))
- top1.update(acc1[0], images.size(0))
- top5.update(acc5[0], images.size(0))
-
- # measure elapsed time
- batch_time.update(time.time() - end)
- end = time.time()
-
- if i % args.print_freq == 0:
- progress.display(i)
-
- print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
- .format(top1=top1, top5=top5))
-
- return top1.avg
-
-
-def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
- torch.save(state, filename)
- if is_best:
- shutil.copyfile(filename, 'model_best.pth.tar')
-
-
-class AverageMeter(object):
- """Computes and stores the average and current value"""
-
- def __init__(self, name, fmt=':f'):
- self.name = name
- self.fmt = fmt
- self.reset()
-
- def reset(self):
- self.val = 0
- self.avg = 0
- self.sum = 0
- self.count = 0
-
- def update(self, val, n=1):
- self.val = val
- self.sum += val * n
- self.count += n
- self.avg = self.sum / self.count
-
- def __str__(self):
- fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
- return fmtstr.format(**self.__dict__)
-
-
-class ProgressMeter(object):
- def __init__(self, num_batches, meters, prefix=""):
- self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
- self.meters = meters
- self.prefix = prefix
-
- def display(self, batch):
- entries = [self.prefix + self.batch_fmtstr.format(batch)]
- entries += [str(meter) for meter in self.meters]
- print('\t'.join(entries))
-
- def _get_batch_fmtstr(self, num_batches):
- num_digits = len(str(num_batches // 1))
- fmt = '{:' + str(num_digits) + 'd}'
- return '[' + fmt + '/' + fmt.format(num_batches) + ']'
-
-
-def accuracy(output, target, topk=(1,)):
- """Computes the accuracy over the k top predictions for the specified values of k"""
- with torch.no_grad():
- maxk = max(topk)
- batch_size = target.size(0)
-
- _, pred = output.topk(maxk, 1, True, True)
- pred = pred.t()
- correct = pred.eq(target.view(1, -1).expand_as(pred))
-
- res = []
- for k in topk:
- correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
- res.append(correct_k.mul_(100.0 / batch_size))
- return res
-
-
-if __name__ == '__main__':
- main()
+import argparse
+import os
+import random
+import shutil
+import time
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim as optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+
+from resnet import resnet
+
+model_names = ['resnet18', 'resnet50']
+
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
+parser.add_argument('data', metavar='DIR',
+ help='path to dataset')
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+ help='number of data loading workers (default: 4)')
+parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
+ choices=model_names,
+ help='model architecture: ' +
+ ' | '.join(model_names) +
+ ' (default: resnet18)')
+parser.add_argument('--epochs', default=200, type=int, metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+ help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+ metavar='N',
+ help='mini-batch size (default: 256), this is the total '
+ 'batch size of all GPUs on the current node when '
+ 'using Data Parallel or Distributed Data Parallel')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+ metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+ help='momentum')
+parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+ metavar='W', help='weight decay (default: 1e-4)',
+ dest='weight_decay')
+parser.add_argument('-p', '--print-freq', default=100, type=int,
+ metavar='N', help='print frequency (default: 100)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+ help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+ help='evaluate model on validation set')
+parser.add_argument('--seed', default=None, type=int,
+ help='seed for initializing training. ')
+
+best_acc1 = 0
+
+
+def main():
+ global best_acc1
+ args = parser.parse_args()
+ if args.seed is not None:
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ cudnn.deterministic = True
+ warnings.warn('You have chosen to seed training. '
+ 'This will turn on the CUDNN deterministic setting, '
+ 'which can slow down your training considerably! '
+ 'You may see unexpected behavior when restarting '
+ 'from checkpoints.')
+
+ resnet_depth = {
+ 'resnet18': 18,
+ 'resnet50': 50,
+ }
+ # create model
+ print("=> creating model '{}'".format('resnet18'))
+ model = resnet(num_classes=1000, depth=resnet_depth[args.arch], dataset='imagenet')
+ model.init_model()
+
+ if not torch.cuda.is_available():
+ print('using CPU, this will be slow')
+ else:
+ model = torch.nn.DataParallel(model).cuda()
+
+ # define loss function (criterion) and optimizer
+ criterion = nn.CrossEntropyLoss().cuda()
+
+ optimizer = torch.optim.SGD(model.parameters(), args.lr,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+
+ # optionally resume from a checkpoint
+ if args.resume:
+ if os.path.isfile(args.resume):
+ print("=> loading checkpoint '{}'".format(args.resume))
+ checkpoint = torch.load(args.resume)
+ args.start_epoch = checkpoint['epoch']
+ best_acc1 = checkpoint['best_acc1']
+ model.load_state_dict(checkpoint['state_dict'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ print("=> loaded checkpoint '{}' (epoch {})"
+ .format(args.resume, checkpoint['epoch']))
+ else:
+ print("=> no checkpoint found at '{}'".format(args.resume))
+
+ cudnn.benchmark = True
+
+ # Data loading code
+ traindir = os.path.join(args.data, 'train')
+ valdir = os.path.join(args.data, 'val')
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+
+ train_dataset = datasets.ImageFolder(
+ traindir,
+ transforms.Compose([
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ normalize,
+ ]))
+
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.batch_size, shuffle=True,
+ num_workers=args.workers, pin_memory=True)
+
+ val_loader = torch.utils.data.DataLoader(
+ datasets.ImageFolder(valdir, transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ normalize,
+ ])),
+ batch_size=args.batch_size, shuffle=False,
+ num_workers=args.workers, pin_memory=True)
+
+ if args.evaluate:
+ validate(val_loader, model, criterion, args)
+ return
+
+ for epoch in range(args.start_epoch, args.epochs):
+ # adjust_learning_rate(optimizer, epoch, args)
+ scheduler.step()
+
+ # train for one epoch
+ train(train_loader, model, criterion, optimizer, epoch, args)
+
+ # evaluate on validation set
+ acc1 = validate(val_loader, model, criterion, args)
+
+ # remember best acc@1 and save checkpoint
+ is_best = acc1 > best_acc1
+ best_acc1 = max(acc1, best_acc1)
+
+ save_checkpoint({
+ 'epoch': epoch + 1,
+ 'arch': args.arch,
+ 'state_dict': model.state_dict(),
+ 'best_acc1': best_acc1,
+ 'optimizer': optimizer.state_dict(),
+ }, is_best)
+
+
+def train(train_loader, model, criterion, optimizer, epoch, args):
+ batch_time = AverageMeter('Time', ':6.3f')
+ data_time = AverageMeter('Data', ':6.3f')
+ losses = AverageMeter('Loss', ':.4e')
+ top1 = AverageMeter('Acc@1', ':6.2f')
+ top5 = AverageMeter('Acc@5', ':6.2f')
+ progress = ProgressMeter(
+ len(train_loader),
+ [batch_time, data_time, losses, top1, top5],
+ prefix="Epoch: [{}]".format(epoch))
+
+ # switch to train mode
+ model.train()
+
+ end = time.time()
+ for i, (images, target) in enumerate(train_loader):
+ # measure data loading time
+ data_time.update(time.time() - end)
+
+ if torch.cuda.is_available():
+ images = images.cuda(non_blocking=True)
+ target = target.cuda(non_blocking=True)
+
+ # compute output
+ output = model(images)
+ loss = criterion(output, target)
+
+ # measure accuracy and record loss
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+ losses.update(loss.item(), images.size(0))
+ top1.update(acc1[0], images.size(0))
+ top5.update(acc5[0], images.size(0))
+
+ # compute gradient and do SGD step
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ # measure elapsed time
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if i % args.print_freq == 0:
+ progress.display(i)
+
+
+def validate(val_loader, model, criterion, args):
+ batch_time = AverageMeter('Time', ':6.3f')
+ losses = AverageMeter('Loss', ':.4e')
+ top1 = AverageMeter('Acc@1', ':6.2f')
+ top5 = AverageMeter('Acc@5', ':6.2f')
+ progress = ProgressMeter(
+ len(val_loader),
+ [batch_time, losses, top1, top5],
+ prefix='Test: ')
+
+ # switch to evaluate mode
+ model.eval()
+
+ with torch.no_grad():
+ end = time.time()
+ for i, (images, target) in enumerate(val_loader):
+
+ if torch.cuda.is_available():
+ images = images.cuda(non_blocking=True)
+ target = target.cuda(non_blocking=True)
+
+ # compute output
+ output = model(images)
+ loss = criterion(output, target)
+
+ # measure accuracy and record loss
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+ losses.update(loss.item(), images.size(0))
+ top1.update(acc1[0], images.size(0))
+ top5.update(acc5[0], images.size(0))
+
+ # measure elapsed time
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if i % args.print_freq == 0:
+ progress.display(i)
+
+ print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
+ .format(top1=top1, top5=top5))
+
+ return top1.avg
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+ torch.save(state, filename)
+ if is_best:
+ shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self, name, fmt=':f'):
+ self.name = name
+ self.fmt = fmt
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __str__(self):
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+ return fmtstr.format(**self.__dict__)
+
+
+class ProgressMeter(object):
+ def __init__(self, num_batches, meters, prefix=""):
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+ self.meters = meters
+ self.prefix = prefix
+
+ def display(self, batch):
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
+ entries += [str(meter) for meter in self.meters]
+ print('\t'.join(entries))
+
+ def _get_batch_fmtstr(self, num_batches):
+ num_digits = len(str(num_batches // 1))
+ fmt = '{:' + str(num_digits) + 'd}'
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ with torch.no_grad():
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+if __name__ == '__main__':
+ main()
diff --git a/S3-Training/resnet.py b/S3-Training/resnet.py
old mode 100644
new mode 100755
index f481a1f9..294c312e
--- a/S3-Training/resnet.py
+++ b/S3-Training/resnet.py
@@ -2,16 +2,11 @@
import torchvision.transforms as transforms
import math
-from s3conv2dshift import S3Conv2dShift3bit, add_reg_sparse_to_loss
+from dsconv2d import DenseShiftConv2d3bit
-def conv3x3_fp(in_planes, out_planes, stride=1):
+def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
- padding=1, bias=False)
-
-def conv3x3_s3(in_planes, out_planes, stride=1):
- "3x3 convolution with padding"
- return S3Conv2dShift3bit(in_planes, out_planes, kernel_size=3, stride=stride,
+ return DenseShiftConv2d3bit(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
@@ -20,10 +15,10 @@ class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
- self.conv1 = conv3x3_s3(inplanes, planes, stride)
+ self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
- self.conv2 = conv3x3_s3(planes, planes)
+ self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
@@ -53,7 +48,7 @@ class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
- conv2d = S3Conv2dShift3bit
+ conv2d = DenseShiftConv2d3bit
self.conv1 = conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
@@ -99,7 +94,7 @@ def _make_layer(self, block, planes, blocks, stride=1):
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
- S3Conv2dShift3bit(self.inplanes, planes * block.expansion,
+ DenseShiftConv2d3bit(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
@@ -175,12 +170,12 @@ def init_model(self):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
- if isinstance(m, S3Conv2dShift3bit):
- n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- m.weight.data.normal_(0, math.sqrt(2. / n))
- m.weight_val.data.normal_(0, math.sqrt(2. / n))
- m.weight_shift.data.normal_(0, math.sqrt(2. / n))
- m.weight_shift2.data.normal_(0, math.sqrt(2. / n))
+ # Low-variance initialization
+ if isinstance(m, DenseShiftConv2d3bit):
+ m.weight_sign.data.normal_(0, 1e-3)
+ m.weight_t1.data.normal_(0, 1e-3)
+ m.weight_t2.data.normal_(0, 1e-3)
+ m.weight_t3.data.normal_(0, 1e-3)
if isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
@@ -212,12 +207,12 @@ def init_model(self):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
- if isinstance(m, S3Conv2dShift3bit):
- n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- m.weight.data.normal_(0, math.sqrt(2. / n))
- m.weight_val.data.normal_(0, math.sqrt(2. / n))
- m.weight_shift.data.normal_(0, math.sqrt(2. / n))
- m.weight_shift2.data.normal_(0, math.sqrt(2. / n))
+ # Low-variance initialization
+ if isinstance(m, DenseShiftConv2d3bit):
+ m.weight_sign.data.normal_(0, 1e-3)
+ m.weight_t1.data.normal_(0, 1e-3)
+ m.weight_t2.data.normal_(0, 1e-3)
+ m.weight_t3.data.normal_(0, 1e-3)
if isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
diff --git a/S3-Training/s3-training-neurips-2021/README.md b/S3-Training/s3-training-neurips-2021/README.md
new file mode 100644
index 00000000..8ebfa6e0
--- /dev/null
+++ b/S3-Training/s3-training-neurips-2021/README.md
@@ -0,0 +1,223 @@
+# S3: Sign-Sparse-Shift Reparametrization for Effective Training of Low-bit Shift Networks
+
+This repository is the DEMO code of the NeurIPS 2021 paper [S3: Sign-Sparse-Shift Reparametrization for Effective Training of Low-bit Shift Networks](https://proceedings.neurips.cc/paper/2021/file/7a1d9028a78f418cb8f01909a348d9b2-Paper.pdf).
+
+Shift neural networks (Power-of-Two quantization) reduce computation complexity by removing expensive multiplication operations and quantizing continuous weights into low-bit discrete values, which are fast and energy-efficient compared to conventional neural networks. However, existing shift networks are sensitive to the weight initialization and yield a degraded performance caused by vanishing gradient and weight sign freezing problem. To address these issues, we propose S3 re-parameterization, a novel technique for training low-bit shift networks. Our method decomposes a discrete parameter in a sign-sparse-shift 3-fold manner. This way, it efficiently learns a low-bit network with weight dynamics similar to full-precision networks and insensitive to weight initialization. Our proposed training method pushes the boundaries of shift neural networks and shows 3-bit shift networks compete with their full-precision counterparts in terms of top-1 accuracy on ImageNet.
+
+
+
+
+
+## Requirements
+
+- Install PyTorch ([pytorch.org](http://pytorch.org))
+- Download [PyTorch official ImageNet training example](https://github.com/pytorch/examples/tree/master/imagenet).
+ - `wget https://raw.githubusercontent.com/pytorch/examples/master/imagenet/main.py`
+ - `wget https://raw.githubusercontent.com/pytorch/examples/master/imagenet/requirements.txt`
+- `pip install -r requirements.txt`
+- Download the ImageNet dataset from http://www.image-net.org/
+ - Then, and move validation images to labeled subfolders, using [the following shell script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh)
+
+## Training from scratch
+
+### 3-bit Shift Network on ResNet-18 ImageNet
+
+To train a 3bit S3 re-parameterized shift network with ResNet18 on ImageNet from scratch, run:
+```train
+python main.py /path/to/imagenet
+```
+
+## Pre-trained checkpoints
+
+Two pre-trained checkpoints corresponding to the 3-bit results reported in the paper can be [downloaded here](./pre-trained-ckpt-evaluation/). The checkpoints convert into a format compatible with the PyTorch official ImageNet training example so that the standard implementation code can evaluate the validation accuracy of the checkpoints.
+
+### 3-bit Shift Network on ResNet-18 ImageNet
+
+To evaluate the pre-trained checkpoint of 3bit S3 re-parameterized shift network with ResNet-18 on ImageNet, run:
+
+```eval
+python main_eval.py --evaluate --resume s3-3bit-resnet18-pytorch-imagenet.pth.tar --arch resnet18 /path/to/imagenet
+```
+
+Outputs:
+```example_output
+=> creating model 'resnet18'
+=> loading checkpoint 's3-3bit-resnet18-pytorch-imagenet.pth.tar'
+=> loaded checkpoint 's3-3bit-resnet18-pytorch-imagenet.pth.tar' (epoch 199)
+Test: [ 0/196] Time 4.506 ( 4.506) Loss 6.7311e-01 (6.7311e-01) Acc@1 82.81 ( 82.81) Acc@5 96.09 ( 96.09)
+Test: [ 10/196] Time 0.072 ( 0.986) Loss 1.2426e+00 (9.1206e-01) Acc@1 69.14 ( 77.73) Acc@5 88.67 ( 92.68)
+Test: [ 20/196] Time 2.218 ( 0.904) Loss 8.9948e-01 (9.1190e-01) Acc@1 81.64 ( 77.40) Acc@5 91.02 ( 92.47)
+Test: [ 30/196] Time 0.072 ( 0.809) Loss 8.3274e-01 (8.8309e-01) Acc@1 80.47 ( 77.95) Acc@5 94.92 ( 92.97)
+Test: [ 40/196] Time 2.417 ( 0.815) Loss 9.7517e-01 (9.2135e-01) Acc@1 75.78 ( 76.59) Acc@5 94.14 ( 93.16)
+Test: [ 50/196] Time 0.072 ( 0.775) Loss 6.4144e-01 (9.1274e-01) Acc@1 83.59 ( 76.49) Acc@5 96.88 ( 93.35)
+Test: [ 60/196] Time 2.033 ( 0.795) Loss 1.1415e+00 (9.1855e-01) Acc@1 70.31 ( 76.14) Acc@5 91.02 ( 93.45)
+Test: [ 70/196] Time 0.072 ( 0.791) Loss 9.0677e-01 (9.0656e-01) Acc@1 73.83 ( 76.42) Acc@5 94.14 ( 93.54)
+Test: [ 80/196] Time 0.848 ( 0.784) Loss 1.7171e+00 (9.2979e-01) Acc@1 57.03 ( 75.85) Acc@5 85.55 ( 93.26)
+Test: [ 90/196] Time 1.443 ( 0.785) Loss 2.2276e+00 (9.9383e-01) Acc@1 51.56 ( 74.53) Acc@5 75.78 ( 92.41)
+Test: [100/196] Time 1.244 ( 0.767) Loss 1.7705e+00 (1.0593e+00) Acc@1 55.47 ( 73.16) Acc@5 82.03 ( 91.58)
+Test: [110/196] Time 0.449 ( 0.771) Loss 1.2247e+00 (1.0864e+00) Acc@1 70.70 ( 72.59) Acc@5 89.84 ( 91.17)
+Test: [120/196] Time 0.074 ( 0.763) Loss 1.9402e+00 (1.1115e+00) Acc@1 55.86 ( 72.25) Acc@5 76.95 ( 90.76)
+Test: [130/196] Time 0.071 ( 0.774) Loss 1.0368e+00 (1.1486e+00) Acc@1 74.61 ( 71.40) Acc@5 92.97 ( 90.29)
+Test: [140/196] Time 0.072 ( 0.754) Loss 1.4686e+00 (1.1709e+00) Acc@1 65.23 ( 70.97) Acc@5 83.98 ( 90.00)
+Test: [150/196] Time 0.073 ( 0.763) Loss 1.4905e+00 (1.1954e+00) Acc@1 69.92 ( 70.51) Acc@5 85.16 ( 89.62)
+Test: [160/196] Time 0.073 ( 0.754) Loss 1.1636e+00 (1.2138e+00) Acc@1 73.44 ( 70.19) Acc@5 89.84 ( 89.35)
+Test: [170/196] Time 0.072 ( 0.755) Loss 7.5062e-01 (1.2348e+00) Acc@1 77.73 ( 69.75) Acc@5 96.48 ( 89.05)
+Test: [180/196] Time 0.073 ( 0.745) Loss 1.3958e+00 (1.2521e+00) Acc@1 64.06 ( 69.37) Acc@5 88.67 ( 88.84)
+Test: [190/196] Time 0.072 ( 0.748) Loss 1.2849e+00 (1.2503e+00) Acc@1 64.84 ( 69.33) Acc@5 92.58 ( 88.89)
+ * Acc@1 69.508 Acc@5 88.968
+```
+
+The elements of following weight tensors in the checkpoint are restricted to the discrete weight values of 3-bit shift network {-4, -2, -1, 0, 1, 2, 4}
+
+ Quantized tensor name in ResNet-18 checkpoint
+module.layer1.0.conv1.weight
+module.layer1.0.conv2.weight
+module.layer1.1.conv1.weight
+module.layer1.1.conv2.weight
+module.layer2.0.conv1.weight
+module.layer2.0.conv2.weight
+module.layer2.0.downsample.0.weight
+module.layer2.1.conv1.weight
+module.layer2.1.conv2.weight
+module.layer3.0.conv1.weight
+module.layer3.0.conv2.weight
+module.layer3.0.downsample.0.weight
+module.layer3.1.conv1.weight
+module.layer3.1.conv2.weight
+module.layer4.0.conv1.weight
+module.layer4.0.conv2.weight
+module.layer4.0.downsample.0.weight
+module.layer4.1.conv1.weight
+module.layer4.1.conv2.weight
+
+
+The following code snippet can load a discrete weight tensor from the checkpoint and output the unique discrete values in this tensor.
+```eval
+import torch
+TENSOR_NAME = "module.layer1.0.conv1.weight"
+CKPT_NAME = "s3-3bit-resnet18-pytorch-imagenet.pth.tar"
+
+checkpoint = torch.load(CKPT_NAME)
+model_state_dict = checkpoint['state_dict']
+discrete_weight = model_state_dict[TENSOR_NAME]
+print(torch.unique(discrete_weight))
+```
+
+Outputs:
+```example_output
+tensor([-4., -2., -1., -0., 1., 2., 4.], device='cuda:0')
+```
+
+
+### 3-bit Shift Network on ResNet-50 ImageNet
+
+To evaluate the pre-trained checkpoint of 3bit S3 re-parameterized shift network with ResNet-50 on ImageNet, run:
+
+```eval
+python main_eval.py --evaluate --resume s3-3bit-resnet50-pytorch-imagenet.pth.tar --arch resnet50 /path/to/imagenet
+```
+
+Outputs:
+```example_output
+=> creating model 'resnet50'
+=> loading checkpoint 's3-3bit-resnet50-pytorch-imagenet.pth.tar'
+=> loaded checkpoint 's3-3bit-resnet50-pytorch-imagenet.pth.tar' (epoch 199)
+Test: [ 0/196] Time 4.976 ( 4.976) Loss 4.9636e-01 (4.9636e-01) Acc@1 86.33 ( 86.33) Acc@5 97.27 ( 97.27)
+Test: [ 10/196] Time 0.221 ( 0.972) Loss 1.0587e+00 (6.8706e-01) Acc@1 75.39 ( 82.07) Acc@5 92.19 ( 95.63)
+Test: [ 20/196] Time 1.160 ( 0.907) Loss 7.0471e-01 (6.8882e-01) Acc@1 86.33 ( 81.99) Acc@5 92.58 ( 95.48)
+Test: [ 30/196] Time 0.221 ( 0.873) Loss 8.0941e-01 (6.5377e-01) Acc@1 78.91 ( 83.09) Acc@5 94.92 ( 95.60)
+Test: [ 40/196] Time 2.344 ( 0.906) Loss 6.5837e-01 (6.8861e-01) Acc@1 82.03 ( 81.85) Acc@5 96.88 ( 95.61)
+Test: [ 50/196] Time 0.223 ( 0.829) Loss 4.6707e-01 (6.8241e-01) Acc@1 88.67 ( 81.78) Acc@5 96.88 ( 95.76)
+Test: [ 60/196] Time 1.323 ( 0.812) Loss 8.7407e-01 (6.9512e-01) Acc@1 74.22 ( 81.40) Acc@5 96.48 ( 95.87)
+Test: [ 70/196] Time 2.609 ( 0.832) Loss 7.4790e-01 (6.8027e-01) Acc@1 76.95 ( 81.63) Acc@5 96.88 ( 96.06)
+Test: [ 80/196] Time 0.221 ( 0.810) Loss 1.4313e+00 (7.0608e-01) Acc@1 65.23 ( 81.13) Acc@5 87.11 ( 95.75)
+Test: [ 90/196] Time 3.314 ( 0.842) Loss 1.8285e+00 (7.5399e-01) Acc@1 58.20 ( 80.08) Acc@5 85.94 ( 95.25)
+Test: [100/196] Time 0.219 ( 0.825) Loss 1.2244e+00 (8.0642e-01) Acc@1 66.80 ( 78.93) Acc@5 89.84 ( 94.59)
+Test: [110/196] Time 3.015 ( 0.847) Loss 8.3800e-01 (8.3314e-01) Acc@1 78.91 ( 78.41) Acc@5 94.92 ( 94.27)
+Test: [120/196] Time 0.219 ( 0.844) Loss 1.2821e+00 (8.4899e-01) Acc@1 71.48 ( 78.15) Acc@5 88.28 ( 94.02)
+Test: [130/196] Time 2.935 ( 0.857) Loss 6.7108e-01 (8.8153e-01) Acc@1 81.64 ( 77.40) Acc@5 95.31 ( 93.68)
+Test: [140/196] Time 0.222 ( 0.852) Loss 1.1377e+00 (8.9882e-01) Acc@1 72.27 ( 77.09) Acc@5 91.80 ( 93.49)
+Test: [150/196] Time 2.446 ( 0.858) Loss 1.1069e+00 (9.1730e-01) Acc@1 76.17 ( 76.76) Acc@5 90.62 ( 93.22)
+Test: [160/196] Time 0.220 ( 0.847) Loss 7.7915e-01 (9.3251e-01) Acc@1 83.20 ( 76.46) Acc@5 93.36 ( 93.00)
+Test: [170/196] Time 2.340 ( 0.852) Loss 5.5731e-01 (9.4940e-01) Acc@1 84.77 ( 76.01) Acc@5 96.88 ( 92.81)
+Test: [180/196] Time 0.221 ( 0.845) Loss 1.2214e+00 (9.6362e-01) Acc@1 67.97 ( 75.69) Acc@5 93.75 ( 92.70)
+Test: [190/196] Time 2.750 ( 0.848) Loss 1.1438e+00 (9.6272e-01) Acc@1 69.92 ( 75.63) Acc@5 94.53 ( 92.75)
+ * Acc@1 75.748 Acc@5 92.800
+```
+
+The elements of following weight tensors in the checkpoint are restricted to the discrete weight values of 3-bit shift network {-4, -2, -1, 0, 1, 2, 4}
+
+ Quantized tensor name in ResNet-50 checkpoint
+module.layer1.0.conv1.weight
+module.layer1.0.conv2.weight
+module.layer1.0.conv3.weight
+module.layer1.0.downsample.0.weight
+module.layer1.1.conv1.weight
+module.layer1.1.conv2.weight
+module.layer1.1.conv3.weight
+module.layer1.2.conv1.weight
+module.layer1.2.conv2.weight
+module.layer1.2.conv3.weight
+module.layer2.0.conv1.weight
+module.layer2.0.conv2.weight
+module.layer2.0.conv3.weight
+module.layer2.0.downsample.0.weight
+module.layer2.1.conv1.weight
+module.layer2.1.conv2.weight
+module.layer2.1.conv3.weight
+module.layer2.2.conv1.weight
+module.layer2.2.conv2.weight
+module.layer2.2.conv3.weight
+module.layer2.3.conv1.weight
+module.layer2.3.conv2.weight
+module.layer2.3.conv3.weight
+module.layer3.0.conv1.weight
+module.layer3.0.conv2.weight
+module.layer3.0.conv3.weight
+module.layer3.0.downsample.0.weight
+module.layer3.1.conv1.weight
+module.layer3.1.conv2.weight
+module.layer3.1.conv3.weight
+module.layer3.2.conv1.weight
+module.layer3.2.conv2.weight
+module.layer3.2.conv3.weight
+module.layer3.3.conv1.weight
+module.layer3.3.conv2.weight
+module.layer3.3.conv3.weight
+module.layer3.4.conv1.weight
+module.layer3.4.conv2.weight
+module.layer3.4.conv3.weight
+module.layer3.5.conv1.weight
+module.layer3.5.conv2.weight
+module.layer3.5.conv3.weight
+module.layer4.0.conv1.weight
+module.layer4.0.conv2.weight
+module.layer4.0.conv3.weight
+module.layer4.0.downsample.0.weight
+module.layer4.1.conv1.weight
+module.layer4.1.conv2.weight
+module.layer4.1.conv3.weight
+module.layer4.2.conv1.weight
+module.layer4.2.conv2.weight
+module.layer4.2.conv3.weight
+
+
+
+## Results
+
+Our model achieves the following performance on :
+
+### Image Classification on ImageNet
+
+#### Results in the paper
+
+
+
+
+#### Evaluation code output
+| Model name | Top 1 Accuracy | Top 5 Accuracy |
+| ------------------ |---------------- | -------------- |
+| 3-bit Shift ResNet-18 | 69.508% | 88.968% |
+| 3-bit Shift ResNet-50 | 75.748% | 92.800% |
+
+The minor accuracy difference (~0.3%) between Table 1 and the evaluation code output may cause by the difference between our implementation and the PyTorch official ImageNet training example.
\ No newline at end of file
diff --git a/S3-Training/figures/S3-Shift3bit-Training.png b/S3-Training/s3-training-neurips-2021/figures/S3-Shift3bit-Training.png
similarity index 100%
rename from S3-Training/figures/S3-Shift3bit-Training.png
rename to S3-Training/s3-training-neurips-2021/figures/S3-Shift3bit-Training.png
diff --git a/S3-Training/figures/tables2.png b/S3-Training/s3-training-neurips-2021/figures/tables2.png
similarity index 100%
rename from S3-Training/figures/tables2.png
rename to S3-Training/s3-training-neurips-2021/figures/tables2.png
diff --git a/S3-Training/s3-training-neurips-2021/main.py b/S3-Training/s3-training-neurips-2021/main.py
new file mode 100644
index 00000000..e32c2406
--- /dev/null
+++ b/S3-Training/s3-training-neurips-2021/main.py
@@ -0,0 +1,323 @@
+import argparse
+import os
+import random
+import shutil
+import time
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim as optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+
+from resnet import resnet
+from resnet import add_reg_sparse_to_loss
+
+parser = argparse.ArgumentParser(description='Sign-Sparse-Shift Reparameterization ImageNet Training')
+parser.add_argument('data', metavar='DIR',
+ help='path to dataset')
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+ help='number of data loading workers (default: 4)')
+parser.add_argument('--epochs', default=200, type=int, metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+ help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+ metavar='N',
+ help='mini-batch size (default: 256), this is the total '
+ 'batch size of all GPUs on the current node when '
+ 'using Data Parallel or Distributed Data Parallel')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+ metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--rs', '--reg-sparse', default=1e-5, type=float,
+ metavar='RS', help='dense weight regularizer', dest='rs')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+ help='momentum')
+parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+ metavar='W', help='weight decay (default: 1e-4)',
+ dest='weight_decay')
+parser.add_argument('-p', '--print-freq', default=100, type=int,
+ metavar='N', help='print frequency (default: 100)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+ help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+ help='evaluate model on validation set')
+parser.add_argument('--pretrained', dest='pretrained', action='store_true',
+ help='use pre-trained model')
+parser.add_argument('--seed', default=None, type=int,
+ help='seed for initializing training. ')
+
+best_acc1 = 0
+
+
+def main():
+ global best_acc1
+ args = parser.parse_args()
+ if args.seed is not None:
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ cudnn.deterministic = True
+ warnings.warn('You have chosen to seed training. '
+ 'This will turn on the CUDNN deterministic setting, '
+ 'which can slow down your training considerably! '
+ 'You may see unexpected behavior when restarting '
+ 'from checkpoints.')
+ # create model
+ if args.pretrained:
+ print("=> using pre-trained model '{}'".format('resnet18'))
+ model = resnet(num_classes=1000, depth=18, dataset='imagenet', pretrained=True)
+ else:
+ print("=> creating model '{}'".format('resnet18'))
+ model = resnet(num_classes=1000, depth=18, dataset='imagenet', pretrained=False)
+ model.init_model()
+
+ if not torch.cuda.is_available():
+ print('using CPU, this will be slow')
+ else:
+ model = torch.nn.DataParallel(model).cuda()
+
+ # define loss function (criterion) and optimizer
+ criterion = nn.CrossEntropyLoss().cuda()
+
+ optimizer = torch.optim.SGD(model.parameters(), args.lr,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+
+ # optionally resume from a checkpoint
+ if args.resume:
+ if os.path.isfile(args.resume):
+ print("=> loading checkpoint '{}'".format(args.resume))
+ checkpoint = torch.load(args.resume)
+ args.start_epoch = checkpoint['epoch']
+ best_acc1 = checkpoint['best_acc1']
+ model.load_state_dict(checkpoint['state_dict'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ print("=> loaded checkpoint '{}' (epoch {})"
+ .format(args.resume, checkpoint['epoch']))
+ else:
+ print("=> no checkpoint found at '{}'".format(args.resume))
+
+ cudnn.benchmark = True
+
+ # Data loading code
+ traindir = os.path.join(args.data, 'train')
+ valdir = os.path.join(args.data, 'val')
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+
+ train_dataset = datasets.ImageFolder(
+ traindir,
+ transforms.Compose([
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ normalize,
+ ]))
+
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.batch_size, shuffle=True,
+ num_workers=args.workers, pin_memory=True)
+
+ val_loader = torch.utils.data.DataLoader(
+ datasets.ImageFolder(valdir, transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ normalize,
+ ])),
+ batch_size=args.batch_size, shuffle=False,
+ num_workers=args.workers, pin_memory=True)
+
+ if args.evaluate:
+ validate(val_loader, model, criterion, args)
+ return
+
+ for epoch in range(args.start_epoch, args.epochs):
+ # adjust_learning_rate(optimizer, epoch, args)
+ scheduler.step()
+
+ # train for one epoch
+ train(train_loader, model, criterion, optimizer, epoch, args)
+
+ # evaluate on validation set
+ acc1 = validate(val_loader, model, criterion, args)
+
+ # remember best acc@1 and save checkpoint
+ is_best = acc1 > best_acc1
+ best_acc1 = max(acc1, best_acc1)
+
+ save_checkpoint({
+ 'epoch': epoch + 1,
+ 'arch': 'resnet18',
+ 'state_dict': model.state_dict(),
+ 'best_acc1': best_acc1,
+ 'optimizer': optimizer.state_dict(),
+ }, is_best)
+
+
+def train(train_loader, model, criterion, optimizer, epoch, args):
+ batch_time = AverageMeter('Time', ':6.3f')
+ data_time = AverageMeter('Data', ':6.3f')
+ losses = AverageMeter('Loss', ':.4e')
+ top1 = AverageMeter('Acc@1', ':6.2f')
+ top5 = AverageMeter('Acc@5', ':6.2f')
+ progress = ProgressMeter(
+ len(train_loader),
+ [batch_time, data_time, losses, top1, top5],
+ prefix="Epoch: [{}]".format(epoch))
+
+ # switch to train mode
+ model.train()
+
+ end = time.time()
+ for i, (images, target) in enumerate(train_loader):
+ # measure data loading time
+ data_time.update(time.time() - end)
+
+ if torch.cuda.is_available():
+ images = images.cuda(non_blocking=True)
+ target = target.cuda(non_blocking=True)
+
+ # compute output
+ output = model(images)
+ loss = criterion(output, target)
+
+ # measure accuracy and record loss
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+ losses.update(loss.item(), images.size(0))
+ top1.update(acc1[0], images.size(0))
+ top5.update(acc5[0], images.size(0))
+
+ # add regularizer after loss measurement and before backprop
+ loss = add_reg_sparse_to_loss(model, loss, alpha=args.rs)
+
+ # compute gradient and do SGD step
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ # measure elapsed time
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if i % args.print_freq == 0:
+ progress.display(i)
+
+
+def validate(val_loader, model, criterion, args):
+ batch_time = AverageMeter('Time', ':6.3f')
+ losses = AverageMeter('Loss', ':.4e')
+ top1 = AverageMeter('Acc@1', ':6.2f')
+ top5 = AverageMeter('Acc@5', ':6.2f')
+ progress = ProgressMeter(
+ len(val_loader),
+ [batch_time, losses, top1, top5],
+ prefix='Test: ')
+
+ # switch to evaluate mode
+ model.eval()
+
+ with torch.no_grad():
+ end = time.time()
+ for i, (images, target) in enumerate(val_loader):
+
+ if torch.cuda.is_available():
+ images = images.cuda(non_blocking=True)
+ target = target.cuda(non_blocking=True)
+
+ # compute output
+ output = model(images)
+ loss = criterion(output, target)
+
+ # measure accuracy and record loss
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+ losses.update(loss.item(), images.size(0))
+ top1.update(acc1[0], images.size(0))
+ top5.update(acc5[0], images.size(0))
+
+ # measure elapsed time
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if i % args.print_freq == 0:
+ progress.display(i)
+
+ print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
+ .format(top1=top1, top5=top5))
+
+ return top1.avg
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+ torch.save(state, filename)
+ if is_best:
+ shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self, name, fmt=':f'):
+ self.name = name
+ self.fmt = fmt
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __str__(self):
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+ return fmtstr.format(**self.__dict__)
+
+
+class ProgressMeter(object):
+ def __init__(self, num_batches, meters, prefix=""):
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+ self.meters = meters
+ self.prefix = prefix
+
+ def display(self, batch):
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
+ entries += [str(meter) for meter in self.meters]
+ print('\t'.join(entries))
+
+ def _get_batch_fmtstr(self, num_batches):
+ num_digits = len(str(num_batches // 1))
+ fmt = '{:' + str(num_digits) + 'd}'
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ with torch.no_grad():
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+if __name__ == '__main__':
+ main()
diff --git a/S3-Training/pre-trained-ckpt-evaluation/main_eval.py b/S3-Training/s3-training-neurips-2021/pre-trained-ckpt-evaluation/main_eval.py
similarity index 100%
rename from S3-Training/pre-trained-ckpt-evaluation/main_eval.py
rename to S3-Training/s3-training-neurips-2021/pre-trained-ckpt-evaluation/main_eval.py
diff --git a/S3-Training/pre-trained-ckpt-evaluation/pre-trained-ckpts.7z b/S3-Training/s3-training-neurips-2021/pre-trained-ckpt-evaluation/pre-trained-ckpts.7z
similarity index 100%
rename from S3-Training/pre-trained-ckpt-evaluation/pre-trained-ckpts.7z
rename to S3-Training/s3-training-neurips-2021/pre-trained-ckpt-evaluation/pre-trained-ckpts.7z
diff --git a/S3-Training/s3-training-neurips-2021/resnet.py b/S3-Training/s3-training-neurips-2021/resnet.py
new file mode 100644
index 00000000..f481a1f9
--- /dev/null
+++ b/S3-Training/s3-training-neurips-2021/resnet.py
@@ -0,0 +1,239 @@
+import torch.nn as nn
+import torchvision.transforms as transforms
+import math
+
+from s3conv2dshift import S3Conv2dShift3bit, add_reg_sparse_to_loss
+
+def conv3x3_fp(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+def conv3x3_s3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return S3Conv2dShift3bit(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+
+ self.conv1 = conv3x3_s3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3_s3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+
+ conv2d = S3Conv2dShift3bit
+
+ self.conv1 = conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self):
+ super(ResNet, self).__init__()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ S3Conv2dShift3bit(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
+
+
+class ResNet_imagenet(ResNet):
+ def __init__(self, num_classes=1000, depth=18):
+
+ super(ResNet_imagenet, self).__init__()
+
+ block = None
+ layers = []
+ num_classes = num_classes or 1000
+ depth = depth or 50
+ if depth == 18:
+ block = BasicBlock
+ layers = [2, 2, 2, 2]
+ if depth == 34:
+ block = BasicBlock
+ layers = [3, 4, 6, 3]
+ if depth == 50:
+ block = Bottleneck
+ layers = [3, 4, 6, 3]
+ if depth == 101:
+ block = Bottleneck
+ layers = [3, 4, 23, 3]
+ if depth == 152:
+ block = Bottleneck
+ layers = [3, 8, 36, 3]
+
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.avgpool = nn.AvgPool2d(7)
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ def init_model(self):
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+
+ if isinstance(m, S3Conv2dShift3bit):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ m.weight_val.data.normal_(0, math.sqrt(2. / n))
+ m.weight_shift.data.normal_(0, math.sqrt(2. / n))
+ m.weight_shift2.data.normal_(0, math.sqrt(2. / n))
+
+ if isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+class ResNet_cifar10(ResNet):
+
+ def __init__(self, num_classes=10,
+ block=BasicBlock, depth=18):
+ super(ResNet_cifar10, self).__init__()
+ self.inplanes = 16
+ n = int((depth - 2) / 6)
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(16)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = lambda x: x
+ self.layer1 = self._make_layer(block, 16, n)
+ self.layer2 = self._make_layer(block, 32, n, stride=2)
+ self.layer3 = self._make_layer(block, 64, n, stride=2)
+ self.layer4 = lambda x: x
+ self.avgpool = nn.AvgPool2d(8)
+ self.fc = nn.Linear(64, num_classes)
+
+ def init_model(self):
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+
+ if isinstance(m, S3Conv2dShift3bit):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ m.weight_val.data.normal_(0, math.sqrt(2. / n))
+ m.weight_shift.data.normal_(0, math.sqrt(2. / n))
+ m.weight_shift2.data.normal_(0, math.sqrt(2. / n))
+
+ if isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+
+def resnet(**kwargs):
+ num_classes, depth, dataset, pretrained = map(
+ kwargs.get, ['num_classes', 'depth', 'dataset', 'pretrained'])
+
+ if dataset == 'imagenet':
+ num_classes = num_classes or 1000
+ depth = depth or 50
+ return ResNet_imagenet(num_classes=num_classes, depth=depth)
+
+ elif dataset == 'cifar10':
+ num_classes = num_classes or 10
+ depth = depth or 18
+ return ResNet_cifar10(num_classes=num_classes, block=BasicBlock, depth=depth)
diff --git a/S3-Training/s3conv2dshift.py b/S3-Training/s3-training-neurips-2021/s3conv2dshift.py
similarity index 100%
rename from S3-Training/s3conv2dshift.py
rename to S3-Training/s3-training-neurips-2021/s3conv2dshift.py