 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class TrainerOptimizersMixin(ABC):
-
-    def init_optimizers(
-            self,
-            model: LightningModule
-    ) -> Tuple[List, List, List]:
+    def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
         optim_conf = model.configure_optimizers()
-
         if optim_conf is None:
-            rank_zero_warn('`LightningModule.configure_optimizers` returned `None`, '
-                           'this fit will run with no optimizer', UserWarning)
+            rank_zero_warn(
+                '`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer',
+                UserWarning,
+            )
             optim_conf = _MockOptimizer()

+        optimizers, lr_schedulers, optimizer_frequencies = [], [], []
+        monitor = None
+
         # single output, single optimizer
         if isinstance(optim_conf, Optimizer):
-            return [optim_conf], [], []
-
+            optimizers = [optim_conf]
         # two lists, optimizer + lr schedulers
-        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
-                and isinstance(optim_conf[0], list):
-            optimizers, lr_schedulers = optim_conf
-            lr_schedulers = self.configure_schedulers(lr_schedulers)
-            return optimizers, lr_schedulers, []
-
+        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
+            opt, sch = optim_conf
+            optimizers = opt
+            lr_schedulers = sch if isinstance(sch, list) else [sch]
         # single dictionary
         elif isinstance(optim_conf, dict):
-            optimizer = optim_conf["optimizer"]
+            optimizers = [optim_conf["optimizer"]]
             monitor = optim_conf.get('monitor', None)
-            lr_scheduler = optim_conf.get("lr_scheduler", [])
-            if lr_scheduler:
-                lr_schedulers = self.configure_schedulers([lr_scheduler], monitor)
-            else:
-                lr_schedulers = []
-            return [optimizer], lr_schedulers, []
-
+            lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
         # multiple dictionaries
-        elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
+        elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
             optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
-            # take only lr wif exists and ot they are defined - not None
-            lr_schedulers = [
-                opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")
-            ]
-            # take only freq wif exists and ot they are defined - not None
+            lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict]
             optimizer_frequencies = [
-                opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency") is not None
+                opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
             ]
-
-            # clean scheduler list
-            if lr_schedulers:
-                lr_schedulers = self.configure_schedulers(lr_schedulers)
             # assert that if frequencies are present, they are given for all optimizers
             if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
                 raise ValueError("A frequency must be given to each optimizer.")
-            return optimizers, lr_schedulers, optimizer_frequencies
-
         # single list or tuple, multiple optimizer
         elif isinstance(optim_conf, (list, tuple)):
-            return list(optim_conf), [], []
-
+            optimizers = list(optim_conf)
         # unknown configuration
         else:
-            raise ValueError(
+            raise MisconfigurationException(
                 'Unknown configuration for model optimizers.'
-                ' Output from `model.configure_optimizers()` should either be:'
-                ' * single output, single `torch.optim.Optimizer`'
-                ' * single output, list of `torch.optim.Optimizer`'
-                ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
-                ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
-                ' * two outputs, first being a list of `torch.optim.Optimizer` second being'
-                ' a list of `torch.optim.lr_scheduler`'
-                ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')
+                ' Output from `model.configure_optimizers()` should either be:\n'
+                ' * `torch.optim.Optimizer`\n'
+                ' * [`torch.optim.Optimizer`]\n'
+                ' * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n'
+                ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
+                ' * A list of the previously described dict format, with an optional "frequency" key (int)'
+            )
+        lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
+
+        return optimizers, lr_schedulers, optimizer_frequencies

     def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
         # Convert each scheduler into dict structure with relevant information
         lr_schedulers = []
         default_config = {
-            'interval': 'epoch',  # default every epoch
-            'frequency': 1,  # default every epoch/batch
-            'reduce_on_plateau': False
-        }  # most often not ReduceLROnPlateau scheduler
-
-        if monitor is not None:
-            default_config['monitor'] = monitor
-
+            'scheduler': None,
+            'interval': 'epoch',  # after epoch is over
+            'frequency': 1,  # every epoch/batch
+            'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
+            'monitor': monitor,  # value to monitor for ReduceLROnPlateau
+            'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
+        }
         for scheduler in schedulers:
             if isinstance(scheduler, dict):
+                # check provided keys
+                extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
+                if extra_keys:
+                    rank_zero_warn(f'Found unsupported keys in the lr scheduler dict: {extra_keys}', RuntimeWarning)
                 if 'scheduler' not in scheduler:
-                    raise ValueError('Lr scheduler should have key `scheduler`',
-                                     ' with item being a lr scheduler')
+                    raise MisconfigurationException(
+                        'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
+                    )
                 scheduler['reduce_on_plateau'] = isinstance(
-                    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)
-
+                    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
+                )
+                if scheduler['reduce_on_plateau'] and scheduler.get('monitor', None) is None:
+                    raise MisconfigurationException(
+                        'The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
+                        ' For example: {"optimizer": optimizer, "lr_scheduler":'
+                        ' {"scheduler": scheduler, "monitor": "your_loss"}}'
+                    )
                 lr_schedulers.append({**default_config, **scheduler})
-
             elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
-                lr_schedulers.append({**default_config, 'scheduler': scheduler,
-                                      'reduce_on_plateau': True})
-
+                if monitor is None:
+                    raise MisconfigurationException(
+                        '`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
+                        ' For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
+                    )
+                lr_schedulers.append(
+                    {**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor}
+                )
             elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                 lr_schedulers.append({**default_config, 'scheduler': scheduler})
             else:
-                raise ValueError(f'Input {scheduler} to lr schedulers '
-                                 'is a invalid input.')
+                raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
         return lr_schedulers

     def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
@@ -138,10 +134,7 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
                 if scheduler.optimizer == optimizer:
                     # Find the mro belonging to the base lr scheduler class
                     for i, mro in enumerate(scheduler.__class__.__mro__):
-                        if (
-                            mro == optim.lr_scheduler._LRScheduler
-                            or mro == optim.lr_scheduler.ReduceLROnPlateau
-                        ):
+                        if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
                             idx = i
                             state = scheduler.state_dict()
                         else:
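
For context on the API this diff targets, here is a minimal sketch of a LightningModule whose `configure_optimizers` returns the dict format that the reworked `init_optimizers` accepts. It is an illustration only, not part of the change: the `LitModel` class, its single linear layer, the Adam settings, and the `'val_loss'` monitor name are assumed placeholders. The point it demonstrates is that, with this diff, a `ReduceLROnPlateau` scheduler must be paired with a `monitor` entry, or `configure_schedulers` raises `MisconfigurationException`.

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.core.lightning import LightningModule


class LitModel(LightningModule):  # hypothetical example model, not part of the PR
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)  # placeholder architecture

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)  # assumed settings
        scheduler = ReduceLROnPlateau(optimizer)
        # With this change, a ReduceLROnPlateau scheduler needs a monitor, either at
        # the top level (as here) or inside an "lr_scheduler" dict with a "scheduler"
        # key; otherwise configure_schedulers raises MisconfigurationException.
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss',  # assumed metric name; the model must log it
        }

The other return formats listed in the new error message (a bare optimizer, a list of optimizers, an `([optimizers], [schedulers])` tuple, or a list of dicts like the one above with an optional integer `frequency`) are handled by the earlier branches of `init_optimizers`.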