diff --git a/uetasr/layers/normalization.py b/uetasr/layers/normalization.py
index 7ba679e..799242a 100644
--- a/uetasr/layers/normalization.py
+++ b/uetasr/layers/normalization.py
@@ -47,3 +47,253 @@ def call(self, x):
         if self.bias:
             return norm_x + self.offset
         return norm_x
+
+
+@tf.keras.utils.register_keras_serializable()
+class AccumBatchNormalization(tf.keras.layers.Layer):
+    """Custom Batch Normalization layer with gradient accumulation support.
+    Code from: https://github.com/andreped/GradientAccumulator
+    """
+    def __init__(
+        self,
+        accum_steps: int = 1,
+        momentum: float = 0.99,
+        epsilon: float = 1e-3,
+        trainable: bool = True,
+        **kwargs
+    ):
+        """Construct the AccumBatchNormalization layer.
+
+        Args:
+            accum_steps (int): Number of accumulation steps between statistics updates.
+            momentum (float): Momentum used in the moving-average update.
+            epsilon (float): Small value to aid numerical stability.
+            trainable (bool): Whether the layer's variables should be updated during
+                training. Distinct from the training/inference mode flag.
+        """
+        self.accum_steps = accum_steps
+        self.accum_steps_tf = tf.constant(accum_steps,
+                                          dtype=tf.int32,
+                                          name="accum_steps")
+        self.momentum = momentum
+        self.epsilon = epsilon
+        self.trainable = trainable
+        self.accum_step_counter = tf.Variable(
+            0, dtype=tf.int32, trainable=False, name="accum_counter",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+        )
+        super().__init__(**kwargs)
+
+    def build(self, input_shape):
+        """Builds layer and variables.
+
+        Args:
+            input_shape: input feature map size.
+        """
+        self.param_shape = input_shape[-1]
+
+        self.beta = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="zeros",
+            trainable=True,
+            name="beta",
+            experimental_autocast=False,
+        )
+
+        self.gamma = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="ones",
+            trainable=True,
+            name="gamma",
+            experimental_autocast=False,
+        )
+
+        self.moving_mean = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="zeros",
+            trainable=False,
+            name="moving_mean",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+        self.moving_variance = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="ones",
+            trainable=False,
+            name="moving_variance",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+        self.accum_mean = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="zeros",
+            trainable=False,
+            name="accum_mean",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+        self.accum_variance = self.add_weight(
+            shape=(self.param_shape),
+            dtype=self.dtype,
+            initializer="zeros",  # this should be "zeros" as we use it for accumulation
+            trainable=False,
+            name="accum_variance",
+            synchronization=tf.VariableSynchronization.ON_READ,
+            aggregation=tf.VariableAggregation.MEAN,
+            experimental_autocast=False,
+        )
+
+    def get_moving_average(self, statistic, new_value):
+        """Returns the moving average given a statistic and current estimate.
+
+        Args:
+            statistic: summary statistic, e.g. the average of a single feature over multiple samples.
+            new_value: statistic of a single feature for a single forward step.
+        Returns:
+            Updated statistic.
+        """
+        decay = tf.convert_to_tensor(1.0 - self.momentum, name="decay")
+        if decay.dtype != statistic.dtype.base_dtype:
+            decay = tf.cast(decay, statistic.dtype.base_dtype)
+        delta = (statistic - tf.cast(new_value, statistic.dtype)) * decay
+        return statistic.assign_sub(delta)
+
+    def update_variables(self, mean, var):
+        """Updates the batch normalization variables.
+
+        Args:
+            mean: mean of a single feature.
+            var: variance of a single feature.
+        """
+        self.moving_mean.assign(self.get_moving_average(self.moving_mean, mean))
+        self.moving_variance.assign(self.get_moving_average(self.moving_variance, var))
+
+        self.reset_accum()
+
+    def reset_accum(self):
+        """Resets accumulator slots."""
+        self.accum_mean.assign(tf.zeros_like(self.accum_mean))
+        self.accum_variance.assign(tf.zeros_like(self.accum_variance))
+
+        self.accum_step_counter.assign(0)
+
+    def call(self, inputs, training=None, mask=None):
+        """Performs the batch normalization step.
+
+        Args:
+            inputs: input feature map to apply batch normalization across.
+            training: whether the layer should be in training mode or not.
+            mask: whether to calculate statistics within the masked region of the feature map.
+        Returns:
+            Normalized feature map.
+        """
+        self.inputs_dtype = inputs.dtype.base_dtype
+        if self.inputs_dtype in (tf.float16, tf.bfloat16):
+            # Do all math in float32 if given 16-bit inputs for numeric
+            # stability. In particular, it's very easy for variance to overflow
+            # in float16 and for safety we also choose to cast bfloat16 to
+            # float32.
+            inputs = tf.cast(inputs, self.dtype)
+
+        if training:
+            assert len(inputs.shape) in (2, 4)
+            if len(inputs.shape) > 2:
+                axes = [0, 1, 2]
+            else:
+                axes = [0]
+
+            # increment accumulation step counter
+            self.accum_step_counter.assign_add(1)
+
+            # get batch norm statistics
+            mean, var = tf.nn.moments(inputs, axes=axes, keepdims=False)
+
+            # scale mean and variance so the accumulated sums become means
+            mean_scaled = mean / tf.cast(self.accum_steps_tf, mean.dtype)
+            var_scaled = var / tf.cast(self.accum_steps_tf, var.dtype)
+
+            # accumulate statistics
+            self.accum_mean.assign_add(mean_scaled)
+            self.accum_variance.assign_add(var_scaled)
+
+            # only update variables after accum_steps accumulation steps
+            tf.cond(
+                tf.equal(self.accum_step_counter, self.accum_steps_tf),
+                true_fn=lambda: self.update_variables(self.accum_mean, self.accum_variance),
+                false_fn=lambda: None
+            )
+        else:
+            mean, var = self.moving_mean, self.moving_variance
+
+        scale = self.gamma
+        offset = self.beta
+
+        inv = tf.math.rsqrt(var + self.epsilon)
+        if scale is not None:
+            inv *= scale
+
+        outputs = inputs * tf.cast(inv, inputs.dtype) + \
+            tf.cast(offset - mean * inv if offset is not None else -mean * inv, inputs.dtype)
+
+        # need to convert back to float16 after applying batch norm
+        if self.inputs_dtype in (tf.float16, tf.bfloat16):
+            outputs = tf.cast(outputs, self.dtype)
+
+        return outputs
+
+    @property
+    def trainable(self):
+        """Returns whether the layer is trainable.
+
+        Returns:
+            trainable boolean state.
+        """
+        return self._trainable
+
+    @trainable.setter
+    def trainable(self, value: bool):
+        """Sets the trainable state.
+
+        Args:
+            value: boolean state to set the variable to.
+        """
+        self._trainable = value
+
+    @property
+    def _param_dtype(self):
+        """Raises parameters of fp16 batch norm to fp32.
+
+        Returns:
+            dtype of params.
+        """
+        if self.dtype == tf.float16 or self.dtype == tf.bfloat16:
+            return tf.float32
+        else:
+            return self.dtype or tf.float32
+
+    def get_config(self):
+        """Returns configuration as dict.
+
+        Returns:
+            Configuration dictionary.
+        """
+        config = {
+            'accum_steps': self.accum_steps,
+            'momentum': self.momentum,
+            'epsilon': self.epsilon,
+            'trainable': self.trainable,
+        }
+        base_config = super().get_config()
+        return dict(list(base_config.items()) + list(config.items()))
diff --git a/uetasr/models/accumulators.py b/uetasr/models/accumulators.py
index a0b6f26..ef34d21 100644
--- a/uetasr/models/accumulators.py
+++ b/uetasr/models/accumulators.py
@@ -22,22 +22,22 @@ def __init__(self,
                                        dtype=tf.int32,
                                        name="accum_steps")
         self.accum_step_counter = tf.Variable(
-            0,
-            dtype=tf.int32,
-            trainable=False,
-            name="accum_counter",
+            0, dtype=tf.int32, trainable=False, name="accum_counter",
             synchronization=tf.VariableSynchronization.ON_READ,
             aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
         )
         self.first_call = True
-        self.gradient_accumulation = None
-        self.reinit_grad_accum()
         self.mixed_precision = mixed_precision
         self.use_agc = use_agc
         self.clip_factor = clip_factor
         self.eps = eps
+        # TODO: Does this change dynamically based on mixed precision / variable dtypes?
+        self.dtype_value = self.dtype
+        self.gradient_accumulation = None
+        self.reinit_grad_accum()
 
     def train_step(self, data):
+        """Performs a single train step."""
         # need to reinit accumulator for models subclassed from tf.keras.Model
         if self.first_call:
             self.reinit_grad_accum()
@@ -67,37 +67,36 @@ def train_step(self, data):
                 sample_weight=sample_weight,
                 regularization_losses=self.losses,
             )
-            loss = loss / tf.cast(
-                self.accum_steps,
-                tf.float32)  # MEAN reduction here IMPORTANT! Don't use SUM!
+            # MEAN reduction here IMPORTANT! Don't use SUM!
+            loss = loss / tf.cast(self.accum_steps, loss.dtype)
 
             # scale loss if mixed precision is enabled
             if self.mixed_precision:
                 loss = self.optimizer.get_scaled_loss(loss)
 
-        # Calculate batch gradients -> these are scaled gradients if
-        # mixed precision is enabled
+        # Calculate batch gradients -> these are scaled gradients if mixed precision is enabled
         gradients = tape.gradient(
             loss,
             self.trainable_variables,
-            unconnected_gradients=tf.UnconnectedGradients.ZERO)
+            unconnected_gradients=tf.UnconnectedGradients.ZERO
+        )
 
         # scale gradients if mixed precision is enabled
         if self.mixed_precision:
             gradients = self.optimizer.get_unscaled_gradients(gradients)
 
-        # apply adaptive gradient clipping
-        # -> should be AFTER unscaling gradients
+        # apply adaptive gradient clipping -> should be AFTER unscaling gradients
        if self.use_agc:
-            gradients = adaptive_clip_grad(self.trainable_variables,
-                                           gradients,
-                                           clip_factor=self.clip_factor,
-                                           eps=self.eps)
+            gradients = adaptive_clip_grad(
+                self.trainable_variables,
+                gradients,
+                clip_factor=self.clip_factor,
+                eps=self.eps
+            )
 
         # Accumulate batch gradients
         for i in range(len(self.gradient_accumulation)):
-            self.gradient_accumulation[i].assign_add(gradients[i],
-                                                     read_value=False)
+            self.gradient_accumulation[i].assign_add(gradients[i], read_value=False)
 
         # If accum_step_counter reach the accum_steps
         # then we apply accumulated gradients to update the variables
@@ -107,12 +106,11 @@
                 false_fn=lambda: None)
 
         # update metrics
-        self.compiled_metrics.update_state(y,
-                                           y_pred,
-                                           sample_weight=sample_weight)
+        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)
         return {m.name: m.result() for m in self.metrics}
 
     def apply_accu_gradients(self):
+        """Performs gradient update and resets slots afterwards."""
         # apply accumulated gradients
         self.optimizer.apply_gradients(
             zip(self.gradient_accumulation, self.trainable_variables))
@@ -120,18 +118,19 @@ def apply_accu_gradients(self):
         # reset
         self.accum_step_counter.assign(0)
         for i in range(len(self.gradient_accumulation)):
-            self.gradient_accumulation[i].assign(tf.zeros_like(
-                self.trainable_variables[i], dtype=tf.float32),
-                read_value=False)
+            self.gradient_accumulation[i].assign(
+                tf.zeros_like(self.trainable_variables[i],
+                              dtype=self.dtype_value),
+                read_value=False
+            )
 
     def reinit_grad_accum(self):
-        # reinitialize gradient accumulator
+        """Reinitializes gradient accumulator slots."""
         self.gradient_accumulation = [
-            tf.Variable(
-                tf.zeros_like(v, dtype=tf.float32),
-                trainable=False,
-                name="accum_" + str(i),
-                synchronization=tf.VariableSynchronization.ON_READ,
-                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
-            ) for i, v in enumerate(self.trainable_variables)
+            tf.Variable(tf.zeros_like(v, dtype=self.dtype_value),
+                        trainable=False,
+                        name="accum_" + str(i),
+                        synchronization=tf.VariableSynchronization.ON_READ,
+                        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+                        ) for i, v in enumerate(self.trainable_variables)
         ]
diff --git a/uetasr/utils/adaptive_gradient_clip.py b/uetasr/utils/adaptive_gradient_clip.py
index a4f54a9..3bf586b 100644
--- a/uetasr/utils/adaptive_gradient_clip.py
+++ b/uetasr/utils/adaptive_gradient_clip.py
@@ -3,10 +3,31 @@
 
 # implementation from: https://github.com/sayakpaul/Adaptive-Gradient-Clipping
 def compute_norm(x, axis, keepdims):
+    """
+    Computes the Euclidean norm of a tensor :math:`x`.
+
+    Args:
+        x: input tensor.
+        axis: which axis to compute the norm across.
+        keepdims: whether to keep the dimension after reducing along the axis.
+
+    Returns:
+        Euclidean norm.
+    """
     return tf.math.reduce_sum(x**2, axis=axis, keepdims=keepdims)**0.5
 
 
 def unitwise_norm(x):
+    """
+    Wrapper function which dynamically sets `axis` and `keepdims` given an
+    input `x` for calculating the Euclidean norm.
+
+    Args:
+        x: input tensor.
+
+    Returns:
+        Euclidean norm.
+    """
     if len(x.get_shape()) <= 1:  # Scalars and vectors
         axis = None
         keepdims = False
@@ -25,7 +46,22 @@ def unitwise_norm(x):
     return compute_norm(x, axis, keepdims)
 
 
-def adaptive_clip_grad(parameters, gradients, clip_factor=0.01, eps=1e-3):
+def adaptive_clip_grad(parameters,
+                       gradients,
+                       clip_factor: float = 0.01,
+                       eps: float = 1e-3):
+    """
+    Performs adaptive gradient clipping on a given set of parameters and gradients.
+
+    Args:
+        parameters: Which parameters to apply the method on.
+        gradients: Which gradients to apply clipping on.
+        clip_factor: Sets the upper limit for gradient clipping.
+        eps: Epsilon - small number in :math:`max()` to avoid zero norms and preserve numerical stability.
+
+    Returns:
+        Updated gradients after gradient clipping.
+    """
     new_grads = []
     for (params, grads) in zip(parameters, gradients):
         p_norm = unitwise_norm(params)
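
For reference, a minimal usage sketch of the pieces added above (AccumBatchNormalization and adaptive_clip_grad). This is illustrative only: the toy Keras model, the ACCUM_STEPS value, and the random data are assumptions, not part of the patch, and the gradient-accumulating model wrapper touched in uetasr/models/accumulators.py is not exercised here.

import tensorflow as tf

from uetasr.layers.normalization import AccumBatchNormalization
from uetasr.utils.adaptive_gradient_clip import adaptive_clip_grad

ACCUM_STEPS = 4  # hypothetical accumulation horizon

# Drop-in replacement for tf.keras.layers.BatchNormalization; its moving
# statistics are only refreshed once every ACCUM_STEPS training-mode calls.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, input_shape=(16,)),
    AccumBatchNormalization(accum_steps=ACCUM_STEPS),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Dense(1),
])

# Standalone adaptive gradient clipping on one toy batch.
x = tf.random.normal((8, 16))
y = tf.random.normal((8, 1))
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
grads = tape.gradient(loss, model.trainable_variables)
grads = adaptive_clip_grad(model.trainable_variables, grads,
                           clip_factor=0.01, eps=1e-3)

Inside uetasr itself the clipping runs within the accumulating model's train_step when use_agc=True, so calling adaptive_clip_grad manually as above is only needed outside that wrapper.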