diff --git a/keras_transformer/position.py b/keras_transformer/position.py index 5d3d347..eda0604 100644 --- a/keras_transformer/position.py +++ b/keras_transformer/position.py @@ -5,33 +5,6 @@ from keras.utils import get_custom_objects -def positional_signal(hidden_size: int, length: int, - min_timescale: float = 1.0, max_timescale: float = 1e4): - """ - Helper function, constructing basic positional encoding. - The code is partially based on implementation from Tensor2Tensor library - https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py - """ - - if hidden_size % 2 != 0: - raise ValueError( - f"The hidden dimension of the model must be divisible by 2." - f"Currently it is {hidden_size}") - position = K.arange(0, length, dtype=K.floatx()) - num_timescales = hidden_size // 2 - log_timescale_increment = K.constant( - (np.log(float(max_timescale) / float(min_timescale)) / - (num_timescales - 1)), - dtype=K.floatx()) - inv_timescales = ( - min_timescale * - K.exp(K.arange(num_timescales, dtype=K.floatx()) * - -log_timescale_increment)) - scaled_time = K.expand_dims(position, 1) * K.expand_dims(inv_timescales, 0) - signal = K.concatenate([K.sin(scaled_time), K.cos(scaled_time)], axis=1) - return K.expand_dims(signal, axis=0) - - class AddPositionalEncoding(Layer): """ Injects positional encoding signal described in section 3.5 of the original @@ -53,13 +26,40 @@ def get_config(self): return config def build(self, input_shape): - _, length, hidden_size = input_shape - self.signal = positional_signal( - hidden_size, length, self.min_timescale, self.max_timescale) + num_dims = len(input_shape) - 2 + channels = input_shape[-1] + num_timescales = channels // (num_dims * 2) + log_timescale_increment = K.constant( + np.log(float(self.max_timescale) / float(self.min_timescale)) / + (num_timescales - 1), + dtype=K.floatx()) + inv_timescales = self.min_timescale * K.exp( + K.arange(num_timescales, dtype=K.floatx()) * -log_timescale_increment) + self.signals = [] + for dim in range(num_dims): + length = input_shape[dim + 1] + position = K.arange(length, dtype=K.floatx()) + scaled_time = K.expand_dims(position, 1) * K.expand_dims(inv_timescales, 0) + signal = K.concatenate([K.sin(scaled_time), K.cos(scaled_time)], axis=1) + prepad = dim * 2 * num_timescales + postpad = channels - (dim + 1) * 2 * num_timescales + padded = [signal] + if prepad: + padded.insert(0, K.zeros((length, prepad))) + if postpad: + padded.append(K.zeros((length, postpad))) + signal = K.concatenate(padded, 1) + for _ in range(1 + dim): + signal = K.expand_dims(signal, 0) + for _ in range(num_dims - 1 - dim): + signal = K.expand_dims(signal, -2) + self.signals.append(signal) return super().build(input_shape) def call(self, inputs, **kwargs): - return inputs + self.signal + for signal in self.signals: + inputs = inputs + signal + return inputs class AddCoordinateEncoding(AddPositionalEncoding):