
Commit 3e3fd53

Added new folder for new method synthetic_info_bottleneck
1 parent e34e8b7 commit 3e3fd53

19 files changed, +2324 -0 lines changed

synthetic_info_bottleneck/README.md

+69
@@ -0,0 +1,69 @@
# \[ICLR 2020\] Synthetic information bottleneck for transductive meta-learning

This repo contains the implementation of the *synthetic information bottleneck* algorithm for few-shot classification on Mini-ImageNet,
which is used in our ICLR 2020 paper
[Empirical Bayes Transductive Meta-Learning with Synthetic Gradients](https://openreview.net/forum?id=Hkg-xgrYvH).

If our code is helpful for your research, please consider citing:
```
@inproceedings{
Hu2020Empirical,
title={Empirical Bayes Transductive Meta-Learning with Synthetic Gradients},
author={Shell Xu Hu and Pablo Garcia Moreno and Yang Xiao and Xi Shen and Guillaume Obozinski and Neil Lawrence and Andreas Damianou},
booktitle={International Conference on Learning Representations (ICLR)},
year={2020},
url={https://openreview.net/forum?id=Hkg-xgrYvH}
}
```

## Authors of the code
[Shell Xu Hu](http://hushell.github.io/), [Xi Shen](https://xishen0220.github.io/) and [Yang Xiao](https://youngxiao13.github.io/)


## Dependencies
The code is tested under a **PyTorch > 1.0 + Python 3.6** environment with extra packages:
``` Bash
pip install -r requirements.txt
```


## How to use the code on Mini-ImageNet?
### **Step 0**: Download the Mini-ImageNet dataset

``` Bash
cd data
bash download_miniimagenet.sh
cd ..
```

### **Step 1** (optional): Train a WRN-28-10 feature network (aka backbone)
The weights of the feature network are downloaded in Step 0, but you may also train from scratch by running

``` Bash
python main_feat.py --outDir miniImageNet_WRN_60Epoch --cuda --dataset miniImageNet --nbEpoch 60
```

### **Step 2**: Meta-training on Mini-ImageNet, e.g., 5-way-1-shot:

``` Bash
python main.py --config config/miniImageNet_1shot.yaml --seed 100 --gpu 0
```

### **Step 3**: Meta-testing on Mini-ImageNet with a checkpoint:

``` Bash
python main.py --config config/miniImageNet_1shot.yaml --seed 100 --gpu 0 --ckpt cache/miniImageNet_1shot_K3_seed100/outputs_xx.xxx/netSIBBestxx.xxx.pth
```

## Mini-ImageNet Results (LAST ckpt)

| Setup | 5-way-1-shot | 5-way-5-shot |
| ------------- | -------------:| ------------:|
| SIB (K=3) | 70.700% ± 0.585% | 80.045% ± 0.363% |
| SIB (K=5) | 70.494% ± 0.619% | 80.192% ± 0.372% |

## CIFAR-FS Results (LAST ckpt)

| Setup | 5-way-1-shot | 5-way-5-shot |
| ------------- | -------------:| ------------:|
| SIB (K=3) | 79.763% ± 0.577% | 85.721% ± 0.369% |
| SIB (K=5) | 79.627% ± 0.593% | 85.590% ± 0.375% |
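
The ± values above are 95% confidence intervals over test episodes (2000 episodes per run in the provided configs). As a rough illustration, here is a minimal sketch of such an interval under the usual normal approximation; the repo's `getCi` helper in `utils/outils.py` may differ in detail:

``` Python
# Sketch only: mean and 95% confidence half-width over per-episode accuracies,
# assuming a normal approximation (1.96 * standard error); getCi may differ.
import numpy as np

def ci95(episode_acc):
    acc = np.asarray(episode_acc, dtype=np.float64)
    mean = acc.mean()
    half_width = 1.96 * acc.std(ddof=1) / np.sqrt(len(acc))
    return mean, half_width

accs = np.random.uniform(60, 80, size=2000)  # placeholder per-episode accuracies (%)
mean, hw = ci95(accs)
print('{:.3f}% ± {:.3f}%'.format(mean, hw))
```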
+268
@@ -0,0 +1,268 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================

import os
import itertools
import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from utils.outils import progress_bar, AverageMeter, accuracy, getCi
from utils.utils import to_device

class Algorithm:
    """
    Algorithm logic is implemented here with training and validation functions etc.

    :param args: experimental configurations
    :type args: EasyDict
    :param logger: logger
    :param netFeat: feature network
    :type netFeat: class `WideResNet` or `ConvNet_4_64`
    :param netSIB: classifier/decoder
    :type netSIB: class `ClassifierSIB`
    :param optimizer: optimizer
    :type optimizer: torch.optim.SGD
    :param criterion: loss
    :type criterion: nn.CrossEntropyLoss
    """
    def __init__(self, args, logger, netFeat, netSIB, optimizer, criterion):
        self.netFeat = netFeat
        self.netSIB = netSIB
        self.optimizer = optimizer
        self.criterion = criterion

        self.nbIter = args.nbIter
        self.nStep = args.nStep
        self.outDir = args.outDir
        self.nFeat = args.nFeat
        self.batchSize = args.batchSize
        self.nEpisode = args.nEpisode
        self.momentum = args.momentum
        self.weightDecay = args.weightDecay

        self.logger = logger
        self.device = torch.device('cuda' if args.cuda else 'cpu')

        # Load pretrained feature network
        if args.resumeFeatPth:
            if args.cuda:
                param = torch.load(args.resumeFeatPth)
            else:
                param = torch.load(args.resumeFeatPth, map_location='cpu')
            self.netFeat.load_state_dict(param)
            msg = '\nLoading netFeat from {}'.format(args.resumeFeatPth)
            self.logger.info(msg)

        if args.test:
            self.load_ckpt(args.ckptPth)


    def load_ckpt(self, ckptPth):
        """
        Load checkpoint from ckptPth.

        :param ckptPth: the path to the ckpt
        :type ckptPth: string
        """
        param = torch.load(ckptPth)
        self.netFeat.load_state_dict(param['netFeat'])
        self.netSIB.load_state_dict(param['SIB'])
        lr = param['lr']
        self.optimizer = torch.optim.SGD(itertools.chain(*[self.netSIB.parameters(),]),
                                         lr,
                                         momentum=self.momentum,
                                         weight_decay=self.weightDecay,
                                         nesterov=True)
        msg = '\nLoading networks from {}'.format(ckptPth)
        self.logger.info(msg)


    def compute_grad_loss(self, clsScore, QueryLabel):
        """
        Compute the loss between true gradients and synthetic gradients.
        """
        # register a hook to capture the gradient of the non-leaf tensor clsScore
        def require_nonleaf_grad(v):
            def hook(g):
                v.grad_nonleaf = g
            h = v.register_hook(hook)
            return h
        handle = require_nonleaf_grad(clsScore)

        loss = self.criterion(clsScore, QueryLabel)
        loss.backward(retain_graph=True) # need to backward again

        # remove hook
        handle.remove()

        gradLogit = self.netSIB.dni(clsScore) # B * n x nKnovel
        gradLoss = F.mse_loss(gradLogit, clsScore.grad_nonleaf.detach())

        return loss, gradLoss


    def validate(self, valLoader, lr=None, mode='val'):
        """
        Run one epoch on val-set.

        :param valLoader: the dataloader of val-set
        :type valLoader: class `ValLoader`
        :param float lr: learning rate for synthetic GD
        :param string mode: 'val' or 'test'
        """
        if mode == 'test':
            nEpisode = self.nEpisode
            self.logger.info('\n\nTest mode: randomly sample {:d} episodes...'.format(nEpisode))
        elif mode == 'val':
            nEpisode = len(valLoader)
            self.logger.info('\n\nValidation mode: pre-defined {:d} episodes...'.format(nEpisode))
            valLoader = iter(valLoader)
        else:
            raise ValueError('mode is wrong!')

        episodeAccLog = []
        top1 = AverageMeter()

        self.netFeat.eval()
        #self.netSIB.eval() # set train mode, since updating bn helps to estimate better gradient

        if lr is None:
            lr = self.optimizer.param_groups[0]['lr']

        #for batchIdx, data in enumerate(valLoader):
        for batchIdx in range(nEpisode):
            data = valLoader.getEpisode() if mode == 'test' else next(valLoader)
            data = to_device(data, self.device)

            SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                    data['SupportTensor'].squeeze(0), data['SupportLabel'].squeeze(0), \
                    data['QueryTensor'].squeeze(0), data['QueryLabel'].squeeze(0)

            with torch.no_grad():
                SupportFeat, QueryFeat = self.netFeat(SupportTensor), self.netFeat(QueryTensor)
                SupportFeat, QueryFeat, SupportLabel = \
                        SupportFeat.unsqueeze(0), QueryFeat.unsqueeze(0), SupportLabel.unsqueeze(0)

            clsScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat, lr)
            clsScore = clsScore.view(QueryFeat.shape[0] * QueryFeat.shape[1], -1)
            QueryLabel = QueryLabel.view(-1)
            acc1 = accuracy(clsScore, QueryLabel, topk=(1,))
            top1.update(acc1[0].item(), clsScore.shape[0])

            msg = 'Top1: {:.3f}%'.format(top1.avg)
            progress_bar(batchIdx, nEpisode, msg)
            episodeAccLog.append(acc1[0].item())

        mean, ci95 = getCi(episodeAccLog)
        self.logger.info('Final Perf with 95% confidence intervals: {:.3f}%, {:.3f}%'.format(mean, ci95))
        return mean, ci95


    def train(self, trainLoader, valLoader, lr=None, coeffGrad=0.0):
        """
        Run meta-training iterations on the train-set.

        :param trainLoader: the dataloader of train-set
        :type trainLoader: class `TrainLoader`
        :param valLoader: the dataloader of val-set
        :type valLoader: class `ValLoader`
        :param float lr: learning rate for synthetic GD
        :param float coeffGrad: coefficient of the synthetic-gradient loss (deprecated)
        """
        bestAcc, ci = self.validate(valLoader, lr)
        self.logger.info('Acc improved over validation set from 0% ---> {:.3f} +- {:.3f}%'.format(bestAcc, ci))

        self.netSIB.train()
        self.netFeat.eval()

        losses = AverageMeter()
        top1 = AverageMeter()
        history = {'trainLoss': [], 'trainAcc': [], 'valAcc': []}

        for episode in range(self.nbIter):
            data = trainLoader.getBatch()
            data = to_device(data, self.device)

            with torch.no_grad():
                SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                        data['SupportTensor'], data['SupportLabel'], data['QueryTensor'], data['QueryLabel']
                nC, nH, nW = SupportTensor.shape[2:]

                SupportFeat = self.netFeat(SupportTensor.reshape(-1, nC, nH, nW))
                SupportFeat = SupportFeat.view(self.batchSize, -1, self.nFeat)

                QueryFeat = self.netFeat(QueryTensor.reshape(-1, nC, nH, nW))
                QueryFeat = QueryFeat.view(self.batchSize, -1, self.nFeat)

            if lr is None:
                lr = self.optimizer.param_groups[0]['lr']

            self.optimizer.zero_grad()

            clsScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat, lr)
            clsScore = clsScore.view(QueryFeat.shape[0] * QueryFeat.shape[1], -1)
            QueryLabel = QueryLabel.view(-1)

            if coeffGrad > 0:
                loss, gradLoss = self.compute_grad_loss(clsScore, QueryLabel)
                loss = loss + gradLoss * coeffGrad
            else:
                loss = self.criterion(clsScore, QueryLabel)

            loss.backward()
            self.optimizer.step()

            acc1 = accuracy(clsScore, QueryLabel, topk=(1,))
            top1.update(acc1[0].item(), clsScore.shape[0])
            losses.update(loss.item(), QueryFeat.shape[1])
            msg = 'Loss: {:.3f} | Top1: {:.3f}% '.format(losses.avg, top1.avg)
            if coeffGrad > 0:
                msg = msg + '| gradLoss: {:.3f}'.format(gradLoss.item())
            progress_bar(episode, self.nbIter, msg)

            if episode % 1000 == 999:
                acc, _ = self.validate(valLoader, lr)

                if acc > bestAcc:
                    msg = 'Acc improved over validation set from {:.3f}% ---> {:.3f}%'.format(bestAcc, acc)
                    self.logger.info(msg)

                    bestAcc = acc
                    self.logger.info('Saving Best')
                    torch.save({
                        'lr': lr,
                        'netFeat': self.netFeat.state_dict(),
                        'SIB': self.netSIB.state_dict(),
                        'nbStep': self.nStep,
                        }, os.path.join(self.outDir, 'netSIBBest.pth'))

                self.logger.info('Saving Last')
                torch.save({
                    'lr': lr,
                    'netFeat': self.netFeat.state_dict(),
                    'SIB': self.netSIB.state_dict(),
                    'nbStep': self.nStep,
                    }, os.path.join(self.outDir, 'netSIBLast.pth'))

                msg = 'Iter {:d}, Train Loss {:.3f}, Train Acc {:.3f}%, Val Acc {:.3f}%'.format(
                        episode, losses.avg, top1.avg, acc)
                self.logger.info(msg)
                history['trainLoss'].append(losses.avg)
                history['trainAcc'].append(top1.avg)
                history['valAcc'].append(acc)

                losses = AverageMeter()
                top1 = AverageMeter()

        return bestAcc, acc, history
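
To make the gradient-matching step in `compute_grad_loss` above concrete, here is a self-contained toy sketch of the same idea: a small predictor is scored on how well it reproduces the true gradient of the loss with respect to the classifier scores. The two `nn.Linear` modules and shapes below are illustrative placeholders, not the repo's `ClassifierSIB` or `netSIB.dni`; the repo captures the true gradient with a tensor hook rather than `retain_grad`, but the MSE matching objective is the same.

``` Python
# Toy sketch of the synthetic-gradient loss used in compute_grad_loss.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
nQuery, nClass, nFeat = 8, 5, 16
feat = torch.randn(nQuery, nFeat)
labels = torch.randint(0, nClass, (nQuery,))

classifier = nn.Linear(nFeat, nClass)  # stand-in for the SIB classifier head
dni = nn.Linear(nClass, nClass)        # stand-in for netSIB.dni

clsScore = classifier(feat)
clsScore.retain_grad()                 # keep the gradient of this non-leaf tensor

loss = F.cross_entropy(clsScore, labels)
loss.backward(retain_graph=True)       # populates clsScore.grad (the "true" gradient)

gradLogit = dni(clsScore)              # synthetic gradient prediction
gradLoss = F.mse_loss(gradLogit, clsScore.grad.detach())
print(loss.item(), gradLoss.item())
```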
@@ -0,0 +1,26 @@
# Few-shot dataset
nClsEpisode: 5 # number of categories in each episode
nSupport: 1 # number of samples per category in the support set
nQuery: 15 # number of samples per category in the query set
dataset: 'Cifar' # choices = ['miniImageNet', 'Cifar']

# Network
nStep: 3 # number of synthetic gradient steps
architecture: 'WRN_28_10' # choices = ['WRN_28_10', 'Conv64_4']
batchSize: 1 # number of episodes in each batch

# Optimizer
lr: 0.001 # lr is fixed
weightDecay: 0.0005
momentum: 0.9

# Training details
expName: cifar-fs
nbIter: 50000 # number of training iterations
seed: 100 # can be reset with --seed
gpu: '1' # can be reset with --gpu
resumeFeatPth: './ckpts/CIFAR-FS/netFeatBest62.561.pth' # feat ckpt
coeffGrad: 0 # grad loss coeff

# Testing
nEpisode: 2000 # number of episodes for testing
@@ -0,0 +1,26 @@
# Few-shot dataset
nClsEpisode: 5 # number of categories in each episode
nSupport: 5 # number of samples per category in the support set
nQuery: 15 # number of samples per category in the query set
dataset: 'Cifar' # choices = ['miniImageNet', 'Cifar']

# Network
nStep: 3 # number of synthetic gradient steps
architecture: 'WRN_28_10' # choices = ['WRN_28_10', 'Conv64_4']
batchSize: 1 # number of episodes in each batch

# Optimizer
lr: 0.001 # lr is fixed
weightDecay: 0.0005
momentum: 0.9

# Training details
expName: cifar-fs
nbIter: 50000 # number of training iterations
seed: 100 # can be reset with --seed
gpu: '1' # can be reset with --gpu
resumeFeatPth: './ckpts/CIFAR-FS/netFeatBest62.561.pth' # feat ckpt
coeffGrad: 0 # grad loss coeff

# Testing
nEpisode: 2000 # number of episodes for testing
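
For reference, the `Algorithm` docstring above expects these settings as an `EasyDict`. A hedged sketch of turning a config file like this into that object follows; the actual loading and command-line override logic in `main.py` may differ, and the path below is taken from the README example rather than from this file.

``` Python
# Sketch only: load a YAML config into an EasyDict, as the Algorithm docstring suggests.
import yaml
from easydict import EasyDict

with open('config/miniImageNet_1shot.yaml') as f:  # illustrative path from the README
    args = EasyDict(yaml.safe_load(f))

print(args.nClsEpisode, args.nSupport, args.nQuery, args.nStep)
```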
