
Commit 21d072c

[EDIT] minor changes on gradient accum
1 parent 1dcc490 commit 21d072c

File tree

3 files changed: +320 additions, −35 deletions


uetasr/layers/normalization.py

Lines changed: 250 additions & 0 deletions
@@ -47,3 +47,253 @@ def call(self, x):
         if self.bias:
             return norm_x + self.offset
         return norm_x
+
+
+@tf.keras.utils.register_keras_serializable()
+class AccumBatchNormalization(tf.keras.layers.Layer):
+    """Custom Batch Normalization layer with gradient accumulation support.
+    Code from: https://github.com/andreped/GradientAccumulator
+    """
+    def __init__(
+        self,
+        accum_steps: int = 1,
+        momentum: float = 0.99,
+        epsilon: float = 1e-3,
+        trainable: bool = True,
+        **kwargs
+    ):
+        """Construct the AccumBatchNormalization layer.
+
+        Args:
+            accum_steps (int): Update moving statistics every `accum_steps` steps.
+            momentum (float): Momentum used in variable update.
+            epsilon (float): Small value to aid numerical stability.
+            trainable (bool): Whether layer should be updated during training.
+                Different from training/inference mode.
+        """
+        self.accum_steps = accum_steps
+        self.accum_steps_tf = tf.constant(accum_steps,
+                                          dtype=tf.int32,
+                                          name="accum_steps")
+        self.momentum = momentum
+        self.epsilon = epsilon
+        self.trainable = trainable
+        self.accum_step_counter = tf.Variable(
+            0, dtype=tf.int32, trainable=False, name="accum_counter",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+        )
+        super().__init__(**kwargs)
+
+    def build(self, input_shape):
+        """Builds layer and variables.
+
+        Args:
+            input_shape: input feature map size.
+        """
+        self.param_shape = input_shape[-1]
+
+        self.beta = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="zeros",
+            trainable=True,
+            name="beta",
+            experimental_autocast=False,
+        )
+
+        self.gamma = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="ones",
+            trainable=True,
+            name="gamma",
+            experimental_autocast=False,
+        )
+
+        self.moving_mean = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="zeros",
+            trainable=False,
+            name="moving_mean",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+        self.moving_variance = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="ones",
+            trainable=False,
+            name="moving_variance",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+        self.accum_mean = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="zeros",
+            trainable=False,
+            name="accum_mean",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+        self.accum_variance = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="zeros",  # this should be "zeros" as we use it for accumulation
+            trainable=False,
+            name="accum_variance",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+    def get_moving_average(self, statistic, new_value):
+        """Returns the moving average given a statistic and current estimate.
+
+        Args:
+            statistic: summary statistic, e.g. average of a single feature over multiple samples.
+            new_value: statistic of a single feature for a single forward step.
+        Returns:
+            Updated statistic.
+        """
+        decay = tf.convert_to_tensor(1.0 - self.momentum, name="decay")
+        if decay.dtype != statistic.dtype.base_dtype:
+            decay = tf.cast(decay, statistic.dtype.base_dtype)
+        delta = (statistic - tf.cast(new_value, statistic.dtype)) * decay
+        return statistic.assign_sub(delta)
+
+    def update_variables(self, mean, var):
+        """Updates the batch normalization variables.
+
+        Args:
+            mean: average for single feature
+            var: variance for single feature
+        """
+        self.moving_mean.assign(self.get_moving_average(self.moving_mean, mean))
+        self.moving_variance.assign(self.get_moving_average(self.moving_variance, var))
+
+        self.reset_accum()
+
+    def reset_accum(self):
+        """Resets accumulator slots."""
+        self.accum_mean.assign(tf.zeros_like(self.accum_mean))
+        self.accum_variance.assign(tf.zeros_like(self.accum_variance))
+
+        self.accum_step_counter.assign(0)
+
+    def call(self, inputs, training=None, mask=None):
+        """Performs the batch normalization step.
+
+        Args:
+            inputs: input feature map to apply batch normalization across.
+            training: whether layer should be in training mode or not.
+            mask: whether to calculate statistics within masked region of feature map.
+        Returns:
+            Normalized feature map.
+        """
+        self.inputs_dtype = inputs.dtype.base_dtype
+        if self.inputs_dtype in (tf.float16, tf.bfloat16):
+            # Do all math in float32 if given 16-bit inputs for numeric
+            # stability. In particular, it's very easy for variance to overflow
+            # in float16 and for safety we also choose to cast bfloat16 to
+            # float32.
+            inputs = tf.cast(inputs, self.dtype)
+
+        if training:
+            assert len(inputs.shape) in (2, 4)
+            if len(inputs.shape) > 2:
+                axes = [0, 1, 2]
+            else:
+                axes = [0]
+
+            # step accum count
+            self.accum_step_counter.assign_add(1)
+
+            # get batch norm statistics
+            mean, var = tf.nn.moments(inputs, axes=axes, keepdims=False)
+
+            # scale mean and variance so the accumulated values average out over accum_steps
+            mean_scaled = mean / tf.cast(self.accum_steps_tf, mean.dtype)
+            var_scaled = var / tf.cast(self.accum_steps_tf, var.dtype)
+
+            # accumulate statistics
+            self.accum_mean.assign_add(mean_scaled)
+            self.accum_variance.assign_add(var_scaled)
+
+            # only update variables after n accumulation steps
+            tf.cond(
+                tf.equal(self.accum_step_counter, self.accum_steps_tf),
+                true_fn=lambda: self.update_variables(self.accum_mean, self.accum_variance),
+                false_fn=lambda: None
+            )
+        else:
+            mean, var = self.moving_mean, self.moving_variance
+
+        scale = self.gamma
+        offset = self.beta
+
+        inv = tf.math.rsqrt(var + self.epsilon)
+        if scale is not None:
+            inv *= scale
+
+        outputs = inputs * tf.cast(inv, inputs.dtype) + \
+            tf.cast(offset - mean * inv if offset is not None else -mean * inv, inputs.dtype)
+
+        # need to convert back to float16 after applying batch norm
+        if self.inputs_dtype in (tf.float16, tf.bfloat16):
+            outputs = tf.cast(outputs, self.dtype)
+
+        return outputs
+
+    @property
+    def trainable(self):
+        """Returns whether layer is trainable.
+
+        Returns:
+            trainable boolean state.
+        """
+        return self._trainable
+
+    @trainable.setter
+    def trainable(self, value: bool):
+        """Sets trainable variable.
+
+        Args:
+            value: which boolean state to change variable to.
+        """
+        self._trainable = value
+
+    @property
+    def _param_dtype(self):
+        """Raise parameters of fp16 batch norm to fp32.
+
+        Returns:
+            dtype of params.
+        """
+        if self.dtype == tf.float16 or self.dtype == tf.bfloat16:
+            return tf.float32
+        else:
+            return self.dtype or tf.float32
+
+    def get_config(self):
+        """Returns configuration as dict.
+
+        Returns:
+            Configuration dict.
+        """
+        config = {
+            'accum_steps': self.accum_steps,
+            'momentum': self.momentum,
+            'epsilon': self.epsilon,
+            'trainable': self.trainable,
+        }
+        base_config = super().get_config()
+        return dict(list(base_config.items()) + list(config.items()))
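
For context, a minimal usage sketch (not part of the commit): the layer normalizes each micro-batch with its own statistics, accumulates mean/variance scaled by 1/accum_steps, and only folds them into the moving statistics once every accum_steps calls, so the running statistics track the effective (accumulated) batch. The import path is inferred from the file location above; the model, shapes, and accum_steps value are illustrative placeholders.

import tensorflow as tf
from uetasr.layers.normalization import AccumBatchNormalization

accum_steps = 4  # should match the gradient-accumulation wrapper's accum_steps

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, input_shape=(64,)),
    AccumBatchNormalization(accum_steps=accum_steps),  # drop-in for BatchNormalization
    tf.keras.layers.Activation("relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])

x = tf.random.normal((8, 64))
y = model(x, training=True)  # accumulates statistics; moving stats update on every 4th call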

uetasr/models/accumulators.py

Lines changed: 33 additions & 34 deletions
@@ -22,22 +22,22 @@ def __init__(self,
                                        dtype=tf.int32,
                                        name="accum_steps")
         self.accum_step_counter = tf.Variable(
-            0,
-            dtype=tf.int32,
-            trainable=False,
-            name="accum_counter",
+            0, dtype=tf.int32, trainable=False, name="accum_counter",
             synchronization=tf.VariableSynchronization.ON_READ,
             aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
         )
         self.first_call = True
-        self.gradient_accumulation = None
-        self.reinit_grad_accum()
         self.mixed_precision = mixed_precision
         self.use_agc = use_agc
         self.clip_factor = clip_factor
         self.eps = eps
+        # TODO: Does this dynamically change based on mixed precision/variables dtype?
+        self.dtype_value = self.dtype
+        self.gradient_accumulation = None
+        self.reinit_grad_accum()
 
     def train_step(self, data):
+        """Performs a single train step."""
         # need to reinit accumulator for models subclassed from tf.keras.Model
         if self.first_call:
             self.reinit_grad_accum()
@@ -67,37 +67,36 @@ def train_step(self, data):
                 sample_weight=sample_weight,
                 regularization_losses=self.losses,
             )
-            loss = loss / tf.cast(
-                self.accum_steps,
-                tf.float32)  # MEAN reduction here IMPORTANT! Don't use SUM!
+            # MEAN reduction here IMPORTANT! Don't use SUM!
+            loss = loss / tf.cast(self.accum_steps, loss.dtype)
 
             # scale loss if mixed precision is enabled
             if self.mixed_precision:
                 loss = self.optimizer.get_scaled_loss(loss)
 
-        # Calculate batch gradients -> these are scaled gradients if
-        # mixed precision is enabled
+        # Calculate batch gradients -> these are scaled gradients if mixed precision is enabled
         gradients = tape.gradient(
             loss,
             self.trainable_variables,
-            unconnected_gradients=tf.UnconnectedGradients.ZERO)
+            unconnected_gradients=tf.UnconnectedGradients.ZERO
+        )
 
         # scale gradients if mixed precision is enabled
         if self.mixed_precision:
             gradients = self.optimizer.get_unscaled_gradients(gradients)
 
-        # apply adaptive gradient clipping
-        # -> should be AFTER unscaling gradients
+        # apply adaptive gradient clipping -> should be AFTER unscaling gradients
         if self.use_agc:
-            gradients = adaptive_clip_grad(self.trainable_variables,
-                                           gradients,
-                                           clip_factor=self.clip_factor,
-                                           eps=self.eps)
+            gradients = adaptive_clip_grad(
+                self.trainable_variables,
+                gradients,
+                clip_factor=self.clip_factor,
+                eps=self.eps
+            )
 
         # Accumulate batch gradients
         for i in range(len(self.gradient_accumulation)):
-            self.gradient_accumulation[i].assign_add(gradients[i],
-                                                     read_value=False)
+            self.gradient_accumulation[i].assign_add(gradients[i], read_value=False)
 
         # If accum_step_counter reach the accum_steps
         # then we apply accumulated gradients to update the variables
@@ -107,31 +106,31 @@ def train_step(self, data):
                 false_fn=lambda: None)
 
         # update metrics
-        self.compiled_metrics.update_state(y,
-                                           y_pred,
-                                           sample_weight=sample_weight)
+        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)
         return {m.name: m.result() for m in self.metrics}
 
     def apply_accu_gradients(self):
+        """Performs gradient update and resets slots afterwards."""
         # apply accumulated gradients
         self.optimizer.apply_gradients(
             zip(self.gradient_accumulation, self.trainable_variables))
 
         # reset
         self.accum_step_counter.assign(0)
         for i in range(len(self.gradient_accumulation)):
-            self.gradient_accumulation[i].assign(tf.zeros_like(
-                self.trainable_variables[i], dtype=tf.float32),
-                read_value=False)
+            self.gradient_accumulation[i].assign(
+                tf.zeros_like(self.trainable_variables[i],
+                              dtype=self.dtype_value),
+                read_value=False
+            )
 
     def reinit_grad_accum(self):
-        # reinitialize gradient accumulator
+        """Reinitializes gradient accumulator slots."""
        self.gradient_accumulation = [
-            tf.Variable(
-                tf.zeros_like(v, dtype=tf.float32),
-                trainable=False,
-                name="accum_" + str(i),
-                synchronization=tf.VariableSynchronization.ON_READ,
-                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
-            ) for i, v in enumerate(self.trainable_variables)
+            tf.Variable(tf.zeros_like(v, dtype=self.dtype_value),
+                        trainable=False,
+                        name="accum_" + str(i),
+                        synchronization=tf.VariableSynchronization.ON_READ,
+                        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+            ) for i, v in enumerate(self.trainable_variables)
         ]
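
The train_step changes keep the same accumulate-then-apply pattern, but initialize the accumulator slots in the variables' dtype (self.dtype_value) rather than hard-coded float32, and divide the loss by accum_steps in the loss dtype. For reference, a condensed, self-contained sketch of that pattern (not the project's API; model, optimizer, loss_fn, and accum_steps are placeholder names, and mixed precision plus adaptive gradient clipping are omitted for brevity):

import tensorflow as tf

def make_accum_train_step(model, optimizer, loss_fn, accum_steps):
    # Persistent, non-trainable slots, one per trainable variable.
    slots = [tf.Variable(tf.zeros_like(v), trainable=False)
             for v in model.trainable_variables]
    counter = tf.Variable(0, dtype=tf.int32, trainable=False)

    def train_step(x, y):
        counter.assign_add(1)
        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)
            # MEAN reduction over accumulation steps: divide, don't sum.
            loss = loss_fn(y, y_pred) / tf.cast(accum_steps, tf.float32)
        grads = tape.gradient(loss, model.trainable_variables,
                              unconnected_gradients=tf.UnconnectedGradients.ZERO)

        # Accumulate micro-batch gradients into the slots.
        for slot, g in zip(slots, grads):
            slot.assign_add(g)

        def apply_and_reset():
            # One optimizer step per `accum_steps` micro-batches, then reset.
            optimizer.apply_gradients(zip(slots, model.trainable_variables))
            for slot in slots:
                slot.assign(tf.zeros_like(slot))
            counter.assign(0)

        tf.cond(tf.equal(counter, accum_steps), apply_and_reset, lambda: None)
        return loss

    return train_step

Using the same accum_steps value here and in AccumBatchNormalization keeps the normalization statistics and the weight updates aligned on the same effective batch size.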
