
Commit 4ea5931

Merge pull request #437 from rwightman/agc
Adaptive Gradient Clipping (AGC) Impl
2 parents 5f9aff3 + 361fd0f commit 4ea5931

9 files changed (+106, -17 lines)

README.md

Lines changed: 8 additions & 0 deletions

@@ -2,6 +2,13 @@
 
 ## What's New
 
+### Feb 16, 2021
+* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
+  * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
+  * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0`
+  * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value`
+  * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet.
+
 ### Feb 12, 2021
 * Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
 
@@ -238,6 +245,7 @@ Several (less common) features that I often utilize in my projects are included.
 * Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
 * Blur Pooling (https://arxiv.org/abs/1904.11486)
 * Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
+* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
 
 ## Results
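For readers wiring this up outside of train.py, here is a rough sketch (not part of this commit) of what those flag combinations resolve to via the new `dispatch_clip_grad` utility added below. The toy model and loss are placeholders, used only to produce some gradients:

import torch
import torch.nn as nn
from timm.utils import dispatch_clip_grad

# placeholder model/loss, only here to produce gradients
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
model(torch.randn(4, 16)).sum().backward()

# --clip-grad .01 --clip-mode agc    -> unit-wise adaptive clipping with factor 0.01
dispatch_clip_grad(model.parameters(), value=0.01, mode='agc')
# --clip-grad 1.0                    -> previous default: global L2 norm clipped to 1.0
# dispatch_clip_grad(model.parameters(), value=1.0, mode='norm')
# --clip-grad 10. --clip-mode value  -> clamp each gradient element to [-10, 10]
# dispatch_clip_grad(model.parameters(), value=10., mode='value')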

timm/models/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@
 from .xception_aligned import *
 
 from .factory import create_model
-from .helpers import load_checkpoint, resume_checkpoint
+from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit

timm/models/helpers.py

Lines changed: 13 additions & 7 deletions

@@ -113,10 +113,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
             digits of the SHA256 hash of the contents of the file. The hash is used to
             ensure unique names and to verify the contents of the file. Default: False
     """
-    if cfg is None:
-        cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
-        _logger.warning("Pretrained model URL does not exist, using random initialization.")
+    cfg = cfg or getattr(model, 'default_cfg')
+    if cfg is None or not cfg.get('url', None):
+        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
     url = cfg['url']
 
@@ -174,9 +173,8 @@ def adapt_input_conv(in_chans, conv_weight):
 
 
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
-    if cfg is None:
-        cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
+    cfg = cfg or getattr(model, 'default_cfg')
+    if cfg is None or not cfg.get('url', None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
 
@@ -376,3 +374,11 @@ def build_model_with_cfg(
         model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
 
     return model
+
+
+def model_parameters(model, exclude_head=False):
+    if exclude_head:
+        # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
+        return [p for p in model.parameters()][:-2]
+    else:
+        return model.parameters()
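The new `model_parameters` helper is what train.py uses to optionally drop the classifier head (the last two parameter tensors, weight and bias) when AGC is active, which appears to match the paper's choice of leaving the final linear layer unclipped. A quick sketch of its behaviour, using a placeholder model rather than a timm one:

import torch.nn as nn
from timm.models import model_parameters

# placeholder model: a conv "body" followed by a linear "head" (weight + bias = last two params)
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))

all_params = list(model_parameters(model))                       # every parameter tensor
body_params = list(model_parameters(model, exclude_head=True))   # same list minus the final weight and bias
assert len(all_params) - len(body_params) == 2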

timm/utils/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,4 +1,6 @@
+from .agc import adaptive_clip_grad
 from .checkpoint_saver import CheckpointSaver
+from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
 from .distributed import distribute_bn, reduce_tensor
 from .jit import set_jit_legacy

timm/utils/agc.py

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+""" Adaptive Gradient Clipping
+
+An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
+
+@article{brock2021high,
+  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
+  title={High-Performance Large-Scale Image Recognition Without Normalization},
+  journal={arXiv preprint arXiv:2102.06171},
+  year={2021}
+}
+
+Code references:
+  * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
+  * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import torch
+
+
+def unitwise_norm(x, norm_type=2.0):
+    if x.ndim <= 1:
+        return x.norm(norm_type)
+    else:
+        # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
+        # might need special cases for other weights (possibly MHA) where this may not be true
+        return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
+
+
+def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    for p in parameters:
+        if p.grad is None:
+            continue
+        p_data = p.detach()
+        g_data = p.grad.detach()
+        max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
+        grad_norm = unitwise_norm(g_data, norm_type=norm_type)
+        clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
+        new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
+        p.grad.detach().copy_(new_grads)
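A minimal sketch (not from the repo) of calling `adaptive_clip_grad` directly in a hand-written training step, between `backward()` and `optimizer.step()`; the model, data and optimizer are placeholders:

import torch
import torch.nn as nn
from timm.utils import adaptive_clip_grad

# placeholder model, data and optimizer for illustration only
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
x, y = torch.randn(8, 16), torch.randint(0, 10, (8,))

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
# rescale each parameter's gradient so its unit-wise norm is at most
# clip_factor * max(unit-wise parameter norm, eps)
adaptive_clip_grad(model.parameters(), clip_factor=0.01)
optimizer.step()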

timm/utils/clip_grad.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+import torch
+
+from timm.utils.agc import adaptive_clip_grad
+
+
+def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
+    """ Dispatch to gradient clipping method
+
+    Args:
+        parameters (Iterable): model parameters to clip
+        value (float): clipping value/factor/norm, mode dependant
+        mode (str): clipping mode, one of 'norm', 'value', 'agc'
+        norm_type (float): p-norm, default 2.0
+    """
+    if mode == 'norm':
+        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
+    elif mode == 'value':
+        torch.nn.utils.clip_grad_value_(parameters, value)
+    elif mode == 'agc':
+        adaptive_clip_grad(parameters, value, norm_type=norm_type)
+    else:
+        assert False, f"Unknown clip mode ({mode})."
+

timm/utils/cuda.py

Lines changed: 6 additions & 4 deletions

@@ -11,15 +11,17 @@
     amp = None
     has_apex = False
 
+from .clip_grad import dispatch_clip_grad
+
 
 class ApexScaler:
     state_dict_key = "amp"
 
-    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
         if clip_grad is not None:
-            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
+            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
         optimizer.step()
 
     def state_dict(self):
@@ -37,12 +39,12 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()
 
-    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
         self._scaler.scale(loss).backward(create_graph=create_graph)
         if clip_grad is not None:
             assert parameters is not None
             self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
-            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
         self._scaler.step(optimizer)
         self._scaler.update()
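With the extra `clip_mode` argument, a custom mixed-precision loop can hand clipping off to the scaler. A rough sketch under the assumption that a CUDA device is available; the model, optimizer and batch are placeholders:

import torch
import torch.nn as nn
from timm.models import model_parameters
from timm.utils import NativeScaler

# placeholder model, optimizer and batch; requires a CUDA device for torch.cuda.amp
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scaler = NativeScaler()
x, y = torch.randn(8, 16, device='cuda'), torch.randint(0, 10, (8,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = nn.functional.cross_entropy(model(x), y)
# the scaler handles backward, unscaling, clipping (AGC here) and the optimizer step
loss_scaler(
    loss, optimizer,
    clip_grad=0.01, clip_mode='agc',
    parameters=model_parameters(model, exclude_head=True))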

timm/version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = '0.4.3'
+__version__ = '0.4.4'

train.py

Lines changed: 10 additions & 4 deletions

@@ -29,7 +29,7 @@
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
+from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
@@ -116,7 +116,8 @@
                     help='weight decay (default: 0.0001)')
 parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                     help='Clip gradient norm (default: None, no clipping)')
-
+parser.add_argument('--clip-mode', type=str, default='norm',
+                    help='Gradient clipping mode. One of ("norm", "value", "agc")')
 
 
 # Learning rate schedule parameters
@@ -637,11 +638,16 @@ def train_one_epoch(
         optimizer.zero_grad()
         if loss_scaler is not None:
             loss_scaler(
-                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
+                loss, optimizer,
+                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
+                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
+                create_graph=second_order)
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
+                dispatch_clip_grad(
+                    model_parameters(model, exclude_head='agc' in args.clip_mode),
+                    value=args.clip_grad, mode=args.clip_mode)
             optimizer.step()
 
         if model_ema is not None:
