# Copyright (c) 2022 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license.

"""
Helpers to train with 16-bit precision.
"""

import numpy as np
import torch as th
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

INITIAL_LOG_LOSS_SCALE = 20.0


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()

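# Note: the converters above operate on a single module, so they are typically
# applied recursively with nn.Module.apply. A minimal sketch (assuming a
# hypothetical `model` whose conv layers are the only ones to convert):
#
#     model.apply(convert_module_to_f16)   # cast conv weights/biases to fp16
#     model.apply(convert_module_to_f32)   # undo the cast
#
# In this repository, the wrapped model's own convert_to_fp16() method is
# expected to perform this application (see MixedPrecisionTrainer.__init__).

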
def make_master_params(param_groups_and_shapes):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = []
    for param_group, shape in param_groups_and_shapes:
        master_param = nn.Parameter(
            _flatten_dense_tensors(
                [param.detach().float() for (_, param) in param_group]
            ).view(shape)
        )
        master_param.requires_grad = True
        master_params.append(master_param)
    return master_params


def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    for master_param, (param_group, shape) in zip(
        master_params, param_groups_and_shapes
    ):
        master_param.grad = _flatten_dense_tensors(
            [param_grad_or_zeros(param) for (_, param) in param_group]
        ).view(shape)


def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        for (_, param), unflat_master_param in zip(
            param_group, unflatten_master_params(param_group, master_param.view(-1))
        ):
            param.detach().copy_(unflat_master_param)


def unflatten_master_params(param_group, master_param):
    """
    Split a flattened master parameter back into tensors shaped like the
    parameters in param_group.
    """
    return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])


def get_param_groups_and_shapes(named_model_params):
    """
    Split named parameters into (group, flat_shape) pairs: scalars/vectors
    (ndim <= 1) in one group, matrices and higher-rank tensors in another.
    """
    named_model_params = list(named_model_params)
    scalar_vector_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
        (-1),
    )
    matrix_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim > 1],
        (1, -1),
    )
    return [scalar_vector_named_params, matrix_named_params]

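# Example of the grouping above (illustrative shapes, not from this codebase):
# a bias of shape (64,) lands in the first group and is flattened with the
# other scalars/vectors into a single 1-D master tensor (viewed as -1), while
# a weight of shape (64, 128) lands in the second group, whose flattened
# master tensor is viewed as (1, -1). make_master_params() then builds one
# fp32 master parameter per (group, shape) pair.

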
def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    if use_fp16:
        state_dict = model.state_dict()
        for master_param, (param_group, _) in zip(
            master_params, param_groups_and_shapes
        ):
            for (name, _), unflat_master_param in zip(
                param_group, unflatten_master_params(param_group, master_param.view(-1))
            ):
                assert name in state_dict
                state_dict[name] = unflat_master_param
    else:
        state_dict = model.state_dict()
        for i, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
    return state_dict


def state_dict_to_master_params(model, state_dict, use_fp16):
    if use_fp16:
        named_model_params = [
            (name, state_dict[name]) for name, _ in model.named_parameters()
        ]
        param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
        master_params = make_master_params(param_groups_and_shapes)
    else:
        master_params = [state_dict[name] for name, _ in model.named_parameters()]
    return master_params


def zero_master_grads(master_params):
    for param in master_params:
        param.grad = None


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()


def param_grad_or_zeros(param):
    if param.grad is not None:
        return param.grad.data.detach()
    else:
        return th.zeros_like(param)

class MixedPrecisionTrainer:
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        if self.use_fp16:
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            # Dynamic loss scaling: on overflow, halve the scale (lg scale - 1),
            # drop this step's gradients, and skip the optimizer update.
            self.lg_loss_scale -= 1
            zero_master_grads(self.master_params)
            return False

        # Un-scale the gradients before the optimizer step.
        for p in self.master_params:
            p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        # Slowly grow the loss scale again after a successful step.
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        grad_norm, param_norm = self._compute_norms()
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)


def check_overflow(value):
    return (value == float("inf")) or (value == -float("inf")) or (value != value)
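

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the toy model, shapes, and
# hyperparameters below are assumptions, not part of the original training
# code). It shows the zero_grad / backward / optimize cycle the trainer
# expects, with the optimizer built over trainer.master_params. Passing
# use_fp16=True additionally requires a model exposing a convert_to_fp16()
# hook, as sketched here, and is normally run on a CUDA device.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyConvModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3, padding=1)

        def convert_to_fp16(self):
            # Hook called by MixedPrecisionTrainer when use_fp16=True.
            self.apply(convert_module_to_f16)

        def forward(self, x):
            # Match the input dtype to the (possibly fp16) conv weights.
            return self.conv(x.to(self.conv.weight.dtype))

    model = _ToyConvModel()
    trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
    opt = th.optim.Adam(trainer.master_params, lr=1e-4)

    for step in range(3):
        x = th.randn(2, 3, 16, 16)
        loss = model(x).float().mean()
        trainer.zero_grad()
        trainer.backward(loss)
        took_step = trainer.optimize(opt)  # False only if fp16 grads overflow
        print(f"step {step}: loss={loss.item():.4f}, stepped={took_step}")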