
Commit 2549ca4

carmocca, awaelchli, and rohitgr7
authored
Clean up optimizer code (Lightning-AI#3587)
* Update optimizer code
* Update CHANGELOG
* Fix tuple of one list case
* Update docs
* Fix pep issue
* Minor typo [skip-ci]
* Use minimal match

Co-authored-by: Adrian Wälchli <[email protected]>

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <[email protected]>

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 0ec4107 commit 2549ca4

File tree

6 files changed: +223 -172 lines changed


Diff for: CHANGELOG.md (+1)

@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))
 
 ### Deprecated
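For context, the improved error message added in pytorch_lightning/trainer/optimizers.py (further down in this diff) enumerates the return shapes that `configure_optimizers` accepts. A minimal sketch of those shapes, assuming a placeholder linear model and arbitrary hyperparameters:

    # Sketch of the accepted `configure_optimizers` return shapes, mirroring the
    # error message introduced in this commit. The model, learning rate, and the
    # LambdaLR schedule are placeholders, not part of the commit itself.
    from torch import nn, optim

    model = nn.Linear(4, 2)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.95 ** epoch)

    accepted_shapes = [
        optimizer,                                            # single optimizer
        [optimizer],                                          # list of optimizers
        ([optimizer], [scheduler]),                           # two lists: optimizers, schedulers
        {'optimizer': optimizer, 'lr_scheduler': scheduler},  # single dict
        [{'optimizer': optimizer, 'lr_scheduler': scheduler, 'frequency': 1}],  # list of dicts
    ]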

Diff for: docs/source/optimizers.rst (+25 -5)

@@ -101,26 +101,46 @@ Every optimizer you use can be paired with any `LearningRateScheduler <https://p
     # Adam + LR scheduler
     def configure_optimizers(self):
         optimizer = Adam(...)
-        scheduler = ReduceLROnPlateau(optimizer, ...)
+        scheduler = LambdaLR(optimizer, ...)
         return [optimizer], [scheduler]
 
+    # The ReduceLROnPlateau scheduler requires a monitor
+    def configure_optimizers(self):
+        return {
+            'optimizer': Adam(...),
+            'scheduler': ReduceLROnPlateau(optimizer, ...),
+            'monitor': 'metric_to_track'
+        }
+
     # Two optimizers each with a scheduler
     def configure_optimizers(self):
         optimizer1 = Adam(...)
         optimizer2 = SGD(...)
-        scheduler1 = ReduceLROnPlateau(optimizer1, ...)
+        scheduler1 = LambdaLR(optimizer1, ...)
         scheduler2 = LambdaLR(optimizer2, ...)
         return [optimizer1, optimizer2], [scheduler1, scheduler2]
 
+    # Alternatively
+    def configure_optimizers(self):
+        optimizer1 = Adam(...)
+        optimizer2 = SGD(...)
+        scheduler1 = ReduceLROnPlateau(optimizer1, ...)
+        scheduler2 = LambdaLR(optimizer2, ...)
+        return (
+            {'optimizer': optimizer1, 'lr_scheduler': scheduler1, 'monitor': 'metric_to_track'},
+            {'optimizer': optimizer2, 'lr_scheduler': scheduler2},
+        )
+
     # Same as above with additional params passed to the first scheduler
     def configure_optimizers(self):
         optimizers = [Adam(...), SGD(...)]
         schedulers = [
             {
                 'scheduler': ReduceLROnPlateau(optimizers[0], ...),
-                'monitor': 'val_recall',  # Default: val_loss
+                'monitor': 'metric_to_track',
                 'interval': 'epoch',
-                'frequency': 1
+                'frequency': 1,
+                'strict': True,
             },
             LambdaLR(optimizers[1], ...)
         ]

@@ -144,7 +164,7 @@ To use multiple optimizers return > 1 optimizers from :meth:`pytorch_lightning.c
 
     # Two optimizers, one scheduler for adam only
    def configure_optimizers(self):
-        return [Adam(...), SGD(...)], [ReduceLROnPlateau()]
+        return [Adam(...), SGD(...)], {'scheduler': ReduceLROnPlateau(), 'monitor': 'metric_to_track'}
 
 Lightning will call each optimizer sequentially:
 
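For reference, once this commit lands the lr scheduler dict shown above supports the full key set below; unspecified keys fall back to the defaults defined in `configure_schedulers`. A hedged sketch with a placeholder optimizer and metric name:

    # Fully-specified lr scheduler dict; the commented defaults mirror the
    # `default_config` added to `configure_schedulers` in this commit. The
    # optimizer and the metric name 'metric_to_track' are placeholders.
    from torch import nn, optim

    optimizer = optim.SGD(nn.Linear(2, 2).parameters(), lr=0.1)

    lr_scheduler_dict = {
        'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        'interval': 'epoch',           # default: step once per epoch
        'frequency': 1,                # default: on every matching interval
        'monitor': 'metric_to_track',  # required when using ReduceLROnPlateau
        'strict': True,                # default: complain if the monitored metric is missing
        # 'reduce_on_plateau' is filled in automatically from the scheduler type
    }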

Diff for: pytorch_lightning/trainer/connectors/optimizer_connector.py (+16 -38)

@@ -16,7 +16,6 @@
 
 
 class OptimizerConnector:
-
     def __init__(self, trainer):
         self.trainer = trainer
 
@@ -41,21 +40,15 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
            # Take step if call to update_learning_rates matches the interval key and
            # the current step modulo the schedulers frequency is zero
            if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
-                # If instance of ReduceLROnPlateau, we need to pass validation loss
+                # If instance of ReduceLROnPlateau, we need a monitor
+                monitor_key, monitor_val = None, None
                 if lr_scheduler['reduce_on_plateau']:
-                    try:
-                        monitor_key = lr_scheduler['monitor']
-                    except KeyError as e:
-                        m = "ReduceLROnPlateau requires returning a dict from configure_optimizers with the keyword " \
-                            "monitor=. For example:" \
-                            "return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'your_loss'}"
-                        raise MisconfigurationException(m)
-
-                    if monitor_metrics is not None:
-                        monitor_val = monitor_metrics.get(monitor_key)
-                    else:
-                        monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key)
-
+                    monitor_key = lr_scheduler['monitor']
+                    monitor_val = (
+                        monitor_metrics.get(monitor_key)
+                        if monitor_metrics is not None
+                        else self.trainer.logger_connector.callback_metrics.get(monitor_key)
+                    )
                    if monitor_val is None:
                        if lr_scheduler.get('strict', True):
                            avail_metrics = self.trainer.logger_connector.callback_metrics.keys()

@@ -71,30 +64,15 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
                                RuntimeWarning,
                            )
                            continue
-                    # update LR
-                    old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
+                # update LR
+                old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
+                if lr_scheduler['reduce_on_plateau']:
                    lr_scheduler['scheduler'].step(monitor_val)
-                    new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
-
-                    if self.trainer.dev_debugger.enabled:
-                        self.trainer.dev_debugger.track_lr_schedulers_update(
-                            self.trainer.batch_idx,
-                            interval,
-                            scheduler_idx,
-                            old_lr,
-                            new_lr,
-                            monitor_key,
-                        )
                else:
-                    # update LR
-                    old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
                    lr_scheduler['scheduler'].step()
-                    new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
+                new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
 
-                    if self.trainer.dev_debugger.enabled:
-                        self.trainer.dev_debugger.track_lr_schedulers_update(
-                            self.trainer.batch_idx,
-                            interval,
-                            scheduler_idx,
-                            old_lr, new_lr
-                        )
+                if self.trainer.dev_debugger.enabled:
+                    self.trainer.dev_debugger.track_lr_schedulers_update(
+                        self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key
+                    )
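A self-contained sketch of the stepping logic after this refactor, run outside the Trainer with a placeholder model and a fabricated metric value; it shows the single branch that now separates `ReduceLROnPlateau` from other schedulers:

    # Standalone approximation of the refactored update path: resolve the
    # monitored value only for ReduceLROnPlateau, take one step, read back the lr.
    # The model, lr, patience, and the 'val_loss' value are placeholders.
    from torch import nn, optim

    model = nn.Linear(2, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    lr_scheduler = {
        'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=0),
        'reduce_on_plateau': True,
        'monitor': 'val_loss',
    }
    monitor_metrics = {'val_loss': 1.0}  # stand-in for the trainer's callback metrics

    monitor_val = monitor_metrics.get(lr_scheduler['monitor']) if lr_scheduler['reduce_on_plateau'] else None

    old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
    if lr_scheduler['reduce_on_plateau']:
        lr_scheduler['scheduler'].step(monitor_val)
    else:
        lr_scheduler['scheduler'].step()
    new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
    print(f'lr: {old_lr} -> {new_lr}')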

Diff for: pytorch_lightning/trainer/optimizers.py (+63 -70)

@@ -21,111 +21,107 @@
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class TrainerOptimizersMixin(ABC):
-
-    def init_optimizers(
-        self,
-        model: LightningModule
-    ) -> Tuple[List, List, List]:
+    def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
        optim_conf = model.configure_optimizers()
-
        if optim_conf is None:
-            rank_zero_warn('`LightningModule.configure_optimizers` returned `None`, '
-                           'this fit will run with no optimizer', UserWarning)
+            rank_zero_warn(
+                '`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer',
+                UserWarning,
+            )
            optim_conf = _MockOptimizer()
 
+        optimizers, lr_schedulers, optimizer_frequencies = [], [], []
+        monitor = None
+
        # single output, single optimizer
        if isinstance(optim_conf, Optimizer):
-            return [optim_conf], [], []
-
+            optimizers = [optim_conf]
        # two lists, optimizer + lr schedulers
-        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
-                and isinstance(optim_conf[0], list):
-            optimizers, lr_schedulers = optim_conf
-            lr_schedulers = self.configure_schedulers(lr_schedulers)
-            return optimizers, lr_schedulers, []
-
+        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
+            opt, sch = optim_conf
+            optimizers = opt
+            lr_schedulers = sch if isinstance(sch, list) else [sch]
        # single dictionary
        elif isinstance(optim_conf, dict):
-            optimizer = optim_conf["optimizer"]
+            optimizers = [optim_conf["optimizer"]]
            monitor = optim_conf.get('monitor', None)
-            lr_scheduler = optim_conf.get("lr_scheduler", [])
-            if lr_scheduler:
-                lr_schedulers = self.configure_schedulers([lr_scheduler], monitor)
-            else:
-                lr_schedulers = []
-            return [optimizer], lr_schedulers, []
-
        # multiple dictionaries
-        elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
+            lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
+        # multiple dictionaries
+        elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
            optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
-            # take only lr wif exists and ot they are defined - not None
-            lr_schedulers = [
-                opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")
-            ]
-            # take only freq wif exists and ot they are defined - not None
+            lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict]
            optimizer_frequencies = [
-                opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency") is not None
+                opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
            ]
-
-            # clean scheduler list
-            if lr_schedulers:
-                lr_schedulers = self.configure_schedulers(lr_schedulers)
            # assert that if frequencies are present, they are given for all optimizers
            if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
                raise ValueError("A frequency must be given to each optimizer.")
-            return optimizers, lr_schedulers, optimizer_frequencies
-
        # single list or tuple, multiple optimizer
        elif isinstance(optim_conf, (list, tuple)):
-            return list(optim_conf), [], []
-
+            optimizers = list(optim_conf)
        # unknown configuration
        else:
-            raise ValueError(
+            raise MisconfigurationException(
                'Unknown configuration for model optimizers.'
-                ' Output from `model.configure_optimizers()` should either be:'
-                ' * single output, single `torch.optim.Optimizer`'
-                ' * single output, list of `torch.optim.Optimizer`'
-                ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
-                ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
-                ' * two outputs, first being a list of `torch.optim.Optimizer` second being'
-                ' a list of `torch.optim.lr_scheduler`'
-                ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')
+                ' Output from `model.configure_optimizers()` should either be:\n'
+                ' * `torch.optim.Optimizer`\n'
+                ' * [`torch.optim.Optimizer`]\n'
+                ' * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n'
+                ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
+                ' * A list of the previously described dict format, with an optional "frequency" key (int)'
+            )
+        lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
+
+        return optimizers, lr_schedulers, optimizer_frequencies
 
    def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
        # Convert each scheduler into dict structure with relevant information
        lr_schedulers = []
        default_config = {
-            'interval': 'epoch',  # default every epoch
-            'frequency': 1,  # default every epoch/batch
-            'reduce_on_plateau': False
-        }  # most often not ReduceLROnPlateau scheduler
-
-        if monitor is not None:
-            default_config['monitor'] = monitor
-
+            'scheduler': None,
+            'interval': 'epoch',  # after epoch is over
+            'frequency': 1,  # every epoch/batch
+            'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
+            'monitor': monitor,  # value to monitor for ReduceLROnPlateau
+            'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
+        }
        for scheduler in schedulers:
            if isinstance(scheduler, dict):
+                # check provided keys
+                extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
+                if extra_keys:
+                    rank_zero_warn(f'Found unsupported keys in the lr scheduler dict: {extra_keys}', RuntimeWarning)
                if 'scheduler' not in scheduler:
-                    raise ValueError('Lr scheduler should have key `scheduler`',
-                                     ' with item being a lr scheduler')
+                    raise MisconfigurationException(
+                        'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
+                    )
                scheduler['reduce_on_plateau'] = isinstance(
-                    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)
-
+                    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
+                )
+                if scheduler['reduce_on_plateau'] and scheduler.get('monitor', None) is None:
+                    raise MisconfigurationException(
+                        'The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
+                        ' For example: {"optimizer": optimizer, "lr_scheduler":'
+                        ' {"scheduler": scheduler, "monitor": "your_loss"}}'
+                    )
                lr_schedulers.append({**default_config, **scheduler})
-
            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
-                lr_schedulers.append({**default_config, 'scheduler': scheduler,
-                                      'reduce_on_plateau': True})
-
+                if monitor is None:
+                    raise MisconfigurationException(
+                        '`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
+                        ' For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
+                    )
+                lr_schedulers.append(
+                    {**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor}
+                )
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                lr_schedulers.append({**default_config, 'scheduler': scheduler})
            else:
-                raise ValueError(f'Input {scheduler} to lr schedulers '
-                                 'is a invalid input.')
+                raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
        return lr_schedulers
 
    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):

@@ -138,10 +134,7 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
                if scheduler.optimizer == optimizer:
                    # Find the mro belonging to the base lr scheduler class
                    for i, mro in enumerate(scheduler.__class__.__mro__):
-                        if (
-                            mro == optim.lr_scheduler._LRScheduler
-                            or mro == optim.lr_scheduler.ReduceLROnPlateau
-                        ):
+                        if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
                            idx = i
                            state = scheduler.state_dict()
                        else:
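To illustrate the normalization that `configure_schedulers` performs after this change, here is a hedged sketch of the dict merge: user-supplied keys override the defaults, and `reduce_on_plateau` is derived from the scheduler type. The optimizer and monitor name are placeholders:

    # Sketch of the merge in `configure_schedulers`: defaults first, user keys on top.
    # The optimizer and the monitor name are placeholders.
    from torch import nn, optim

    optimizer = optim.Adam(nn.Linear(2, 2).parameters(), lr=1e-3)

    default_config = {
        'scheduler': None,
        'interval': 'epoch',
        'frequency': 1,
        'reduce_on_plateau': False,
        'monitor': None,
        'strict': True,
    }
    user_entry = {
        'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        'monitor': 'metric_to_track',
    }
    user_entry['reduce_on_plateau'] = isinstance(
        user_entry['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
    )
    normalized = {**default_config, **user_entry}
    # interval='epoch', frequency=1, and strict=True are filled in from the defaults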

Diff for: tests/base/model_optimizers.py (-5)

@@ -81,11 +81,6 @@ def configure_optimizers__mixed_scheduling(self):
        return [optimizer1, optimizer2], \
            [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]
 
-    def configure_optimizers__reduce_lr_on_plateau(self):
-        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
-        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
-        return [optimizer], [lr_scheduler]
-
    def configure_optimizers__param_groups(self):
        param_groups = [
            {'params': list(self.parameters())[:2], 'lr': self.learning_rate * 0.1},
0 commit comments