@@ -47,3 +47,253 @@ def call(self, x):
         if self.bias:
             return norm_x + self.offset
         return norm_x
+
+
+@tf.keras.utils.register_keras_serializable()
+class AccumBatchNormalization(tf.keras.layers.Layer):
+    """Custom Batch Normalization layer with gradient accumulation support.
+
+    Code from: https://github.com/andreped/GradientAccumulator
+    """
+    def __init__(
+        self,
+        accum_steps: int = 1,
+        momentum: float = 0.99,
+        epsilon: float = 1e-3,
+        trainable: bool = True,
+        **kwargs
+    ):
+        """Construct the AccumBatchNormalization layer.
+
+        Args:
+            accum_steps (int): Number of mini-batches over which batch
+                statistics are accumulated before the moving averages
+                are updated.
+            momentum (float): Momentum used in the moving-average update.
+            epsilon (float): Small value to aid numerical stability.
+            trainable (bool): Whether the layer should be updated during
+                training. Different from training/inference mode.
+        """
+        self.accum_steps = accum_steps
+        self.accum_steps_tf = tf.constant(accum_steps,
+                                          dtype=tf.int32,
+                                          name="accum_steps")
+        self.momentum = momentum
+        self.epsilon = epsilon
+        self.trainable = trainable
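+        # Counts mini-batches seen since the last moving-average update;
+        # ONLY_FIRST_REPLICA aggregation keeps a single authoritative count
+        # under distribution strategies.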
+        self.accum_step_counter = tf.Variable(
+            0, dtype=tf.int32, trainable=False, name="accum_counter",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+        )
+        super().__init__(**kwargs)
+
+    def build(self, input_shape):
+        """Builds layer and variables.
+
+        Args:
+            input_shape: input feature map size.
+        """
+        self.param_shape = input_shape[-1]
+
+        self.beta = self.add_weight(
+            shape=(self.param_shape,),
+            dtype=self.dtype,
+            initializer="zeros",
+            trainable=True,
+            name="beta",
+            experimental_autocast=False,
+        )
+
+        self.gamma = self.add_weight(
+            shape=(self.param_shape,),
+            dtype=self.dtype,
+            initializer="ones",
+            trainable=True,
+            name="gamma",
+            experimental_autocast=False,
+        )
+
+        self.moving_mean = self.add_weight(
+            shape=(self.param_shape,),
+            dtype=self.dtype,
+            initializer="zeros",
+            trainable=False,
+            name="moving_mean",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+        self.moving_variance = self.add_weight(
+            shape=(self.param_shape,),
+            dtype=self.dtype,
+            initializer="ones",
+            trainable=False,
+            name="moving_variance",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
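+        # Accumulation slots: gather scaled per-step batch statistics over
+        # accum_steps mini-batches before they are folded into the moving
+        # averages.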
+        self.accum_mean = self.add_weight(
+            shape=(self.param_shape,),
+            dtype=self.dtype,
+            initializer="zeros",
+            trainable=False,
+            name="accum_mean",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+        self.accum_variance = self.add_weight(
+            shape=(self.param_shape,),
+            dtype=self.dtype,
+            initializer="zeros",  # "zeros" on purpose, as this slot is used for accumulation
+            trainable=False,
+            name="accum_variance",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+    def get_moving_average(self, statistic, new_value):
+        """Returns the moving average given a statistic and current estimate.
+
+        Args:
+            statistic: summary statistic, e.g. the average of a single
+                feature over multiple samples.
+            new_value: statistic of a single feature for a single forward step.
+        Returns:
+            Updated statistic.
+        """
+        decay = tf.convert_to_tensor(1.0 - self.momentum, name="decay")
+        if decay.dtype != statistic.dtype.base_dtype:
+            decay = tf.cast(decay, statistic.dtype.base_dtype)
+        delta = (statistic - tf.cast(new_value, statistic.dtype)) * decay
+        return statistic.assign_sub(delta)
+
+    def update_variables(self, mean, var):
+        """Updates the batch normalization variables.
+
+        Args:
+            mean: average for a single feature
+            var: variance for a single feature
+        """
+        self.moving_mean.assign(self.get_moving_average(self.moving_mean, mean))
+        self.moving_variance.assign(self.get_moving_average(self.moving_variance, var))
+
+        self.reset_accum()
+
+    def reset_accum(self):
+        """Resets accumulator slots."""
+        self.accum_mean.assign(tf.zeros_like(self.accum_mean))
+        self.accum_variance.assign(tf.zeros_like(self.accum_variance))
+
+        self.accum_step_counter.assign(0)
+
+    def call(self, inputs, training=None, mask=None):
+        """Performs the batch normalization step.
+
+        Args:
+            inputs: input feature map to apply batch normalization across.
+            training: whether the layer should be in training mode or not.
+            mask: whether to calculate statistics within the masked region
+                of the feature map.
+        Returns:
+            Normalized feature map.
+        """
+        self.inputs_dtype = inputs.dtype.base_dtype
+        if self.inputs_dtype in (tf.float16, tf.bfloat16):
+            # Do all math in float32 if given 16-bit inputs for numeric
+            # stability. In particular, it's very easy for variance to
+            # overflow in float16, and for safety we also choose to cast
+            # bfloat16 to float32.
+            inputs = tf.cast(inputs, self.dtype)
+
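+        # In training mode, statistics are reduced over every axis except
+        # channels: axis 0 for dense (N, C) inputs, axes (0, 1, 2) for
+        # (N, H, W, C) feature maps.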
+        if training:
+            assert len(inputs.shape) in (2, 4)
+            if len(inputs.shape) > 2:
+                axes = [0, 1, 2]
+            else:
+                axes = [0]
+
+            # step the accumulation counter
+            self.accum_step_counter.assign_add(1)
+
+            # get batch norm statistics
+            mean, var = tf.nn.moments(inputs, axes=axes, keepdims=False)
+
+            # scale by 1/accum_steps so that, after accum_steps additions,
+            # the accumulators hold the mean of the per-step statistics
+            mean_scaled = mean / tf.cast(self.accum_steps_tf, mean.dtype)
+            var_scaled = var / tf.cast(self.accum_steps_tf, var.dtype)
+
+            # accumulate statistics
+            self.accum_mean.assign_add(mean_scaled)
+            self.accum_variance.assign_add(var_scaled)
+
+            # only update the moving averages after accum_steps steps
+            tf.cond(
+                tf.equal(self.accum_step_counter, self.accum_steps_tf),
+                true_fn=lambda: self.update_variables(self.accum_mean, self.accum_variance),
+                false_fn=lambda: None,
+            )
+        else:
+            mean, var = self.moving_mean, self.moving_variance
+
+        scale = self.gamma
+        offset = self.beta
+
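+        # Folded affine form of batch norm,
+        #   gamma * (x - mean) / sqrt(var + eps) + beta,
+        # computed as x * inv + (offset - mean * inv),
+        # where inv = gamma / sqrt(var + eps).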
+        inv = tf.math.rsqrt(var + self.epsilon)
+        if scale is not None:
+            inv *= scale
+
+        outputs = inputs * tf.cast(inv, inputs.dtype) + \
+            tf.cast(offset - mean * inv if offset is not None else -mean * inv, inputs.dtype)
+
+        # convert back to the 16-bit input dtype after applying batch norm
+        if self.inputs_dtype in (tf.float16, tf.bfloat16):
+            outputs = tf.cast(outputs, self.inputs_dtype)
+
+        return outputs
+
+    @property
+    def trainable(self):
+        """Returns whether the layer is trainable.
+
+        Returns:
+            trainable boolean state.
+        """
+        return self._trainable
+
+    @trainable.setter
+    def trainable(self, value: bool):
+        """Sets the trainable variable.
+
+        Args:
+            value: which boolean state to change the variable to.
+        """
+        self._trainable = value
+
+    @property
+    def _param_dtype(self):
+        """Raises the parameters of fp16 batch norm to fp32.
+
+        Returns:
+            dtype of the parameters.
+        """
+        if self.dtype == tf.float16 or self.dtype == tf.bfloat16:
+            return tf.float32
+        else:
+            return self.dtype or tf.float32
+
+    def get_config(self):
+        """Returns the configuration as a dict.
+
+        Returns:
+            Configuration dictionary.
+        """
+        config = {
+            "accum_steps": self.accum_steps,
+            "momentum": self.momentum,
+            "epsilon": self.epsilon,
+            "trainable": self.trainable,
+        }
+        base_config = super().get_config()
+        return dict(list(base_config.items()) + list(config.items()))
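
For context, a minimal usage sketch, not part of this diff: the layer is meant as a drop-in replacement for tf.keras.layers.BatchNormalization when training with gradient accumulation. The toy model, synthetic data, and accum_steps=4 below are illustrative assumptions; accum_steps should match the number of accumulation steps used by the surrounding training setup (e.g. a gradient-accumulation model or optimizer wrapper).

import tensorflow as tf

# Assumes AccumBatchNormalization from this diff is in scope.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, input_shape=(16,)),
    AccumBatchNormalization(accum_steps=4),  # illustrative; match training-side accumulation
    tf.keras.layers.Activation("relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Synthetic data, for illustration only.
x = tf.random.normal([256, 16])
y = tf.random.uniform([256], maxval=10, dtype=tf.int32)
model.fit(x, y, batch_size=16, epochs=1)

Because the class is registered with @tf.keras.utils.register_keras_serializable and implements get_config, models using it should save and reload without passing custom_objects, provided the defining module has been imported.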