@@ -442,7 +442,7 @@ class Transformer(Layer):
        - **blinding**: bool. Whether or not to use blinding.
        - **seed**: A Python integer to use as a random seed.
        - **supports_masking**: bool. Whether or not to support masking.
-        - **attention_type**: str, Type of attention, the value must be one of { ``'scaled_dot_product'`` , ``'additive'`` }.
+        - **attention_type**: str. The type of attention; must be one of { ``'scaled_dot_product'`` , ``'cos'`` , ``'ln'`` , ``'additive'`` }.
        - **output_type**: ``'mean'`` , ``'sum'`` or ``None``. Whether or not to use average/sum pooling for the output.

References
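
For orientation, here is a self-contained sketch (not part of this diff) of how the four score functions differ for a single head. The helper names (`attention_scores`, `simple_layer_norm`, `additive_scores`) are hypothetical, and the `ln` variant below is an assumed reading of the new branch: normalize queries and keys, then take the scaled dot product.

```python
# Illustrative only: attention score variants for one head.
# Q: (T_q, D), K: (T_k, D); b, v are learned parameters for the additive case.
import tensorflow as tf

def simple_layer_norm(x, eps=1e-8):
    # Standardize the last axis; a real LayerNormalization layer also
    # learns a gain and bias.
    mean, var = tf.nn.moments(x, axes=[-1], keepdims=True)
    return (x - mean) / tf.sqrt(var + eps)

def attention_scores(Q, K, attention_type="scaled_dot_product"):
    d = float(Q.shape[-1])
    if attention_type == "scaled_dot_product":
        return tf.matmul(Q, K, transpose_b=True) / d ** 0.5       # (T_q, T_k)
    elif attention_type == "cos":
        Qn = tf.nn.l2_normalize(Q, axis=-1)
        Kn = tf.nn.l2_normalize(K, axis=-1)
        return tf.matmul(Qn, Kn, transpose_b=True) * 20           # fixed scale
    elif attention_type == "ln":
        return tf.matmul(simple_layer_norm(Q), simple_layer_norm(K),
                         transpose_b=True) / d ** 0.5
    raise ValueError(attention_type)

def additive_scores(Q, K, b, v):
    # Bahdanau-style: score(q, k) = v^T tanh(q + k + b) for every (q, k) pair.
    pairwise = tf.expand_dims(Q, -2) + tf.expand_dims(K, -3) + b  # (T_q, T_k, D)
    return tf.tensordot(tf.tanh(pairwise), v, axes=[-1, 0])       # (T_q, T_k)
```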
@@ -490,6 +490,9 @@ def build(self, input_shape):
                                 initializer=glorot_uniform(seed=self.seed))
            self.v = self.add_weight('v', shape=[self.att_embedding_size], dtype=tf.float32,
                                 initializer=glorot_uniform(seed=self.seed))
+        elif self.attention_type == "ln":
+            self.att_ln_q = LayerNormalization()
+            self.att_ln_k = LayerNormalization()

        # if self.use_res:
        #     self.W_Res = self.add_weight(name='res', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
        #                                  initializer=TruncatedNormal(seed=self.seed))
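
The new `ln` branch builds two LayerNormalization sub-layers, one for queries and one for keys. For reference, a minimal sketch of what a layer-norm layer of this shape typically computes; the package's actual LayerNormalization class may differ in details such as the epsilon, weight names, or normalization axes.

```python
# Assumed behavior of a LayerNormalization layer: standardize the last
# axis, then apply a learned gain and bias. Sketch only.
import tensorflow as tf

class SimpleLayerNorm(tf.keras.layers.Layer):
    def __init__(self, eps=1e-8, **kwargs):
        super(SimpleLayerNorm, self).__init__(**kwargs)
        self.eps = eps

    def build(self, input_shape):
        dim = int(input_shape[-1])
        self.gamma = self.add_weight('gamma', shape=[dim], initializer='ones')
        self.beta = self.add_weight('beta', shape=[dim], initializer='zeros')
        super(SimpleLayerNorm, self).build(input_shape)

    def call(self, x):
        mean, var = tf.nn.moments(x, axes=[-1], keepdims=True)
        return self.gamma * (x - mean) / tf.sqrt(var + self.eps) + self.beta
```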
@@ -529,28 +532,42 @@ def call(self, inputs, mask=None, training=None, **kwargs):
            queries = self.query_pe(queries)
            keys = self.key_pe(queries)

-        querys = tf.tensordot(queries, self.W_Query,
-                              axes=(-1, 0))  # None T_q D*head_num
-        keys = tf.tensordot(keys, self.W_key, axes=(-1, 0))
-        values = tf.tensordot(keys, self.W_Value, axes=(-1, 0))
+        Q = tf.tensordot(queries, self.W_Query,
+                         axes=(-1, 0))  # N T_q D*h
+        K = tf.tensordot(keys, self.W_key, axes=(-1, 0))
+        V = tf.tensordot(keys, self.W_Value, axes=(-1, 0))

-        # head_num*None T_q D
-        querys = tf.concat(tf.split(querys, self.head_num, axis=2), axis=0)
-        keys = tf.concat(tf.split(keys, self.head_num, axis=2), axis=0)
-        values = tf.concat(tf.split(values, self.head_num, axis=2), axis=0)
+        # h*N T_q D
+        Q_ = tf.concat(tf.split(Q, self.head_num, axis=2), axis=0)
+        K_ = tf.concat(tf.split(K, self.head_num, axis=2), axis=0)
+        V_ = tf.concat(tf.split(V, self.head_num, axis=2), axis=0)

        if self.attention_type == "scaled_dot_product":
-            # head_num*None T_q T_k
-            outputs = tf.matmul(querys, keys, transpose_b=True)
+            # h*N T_q T_k
+            outputs = tf.matmul(Q_, K_, transpose_b=True)

-            outputs = outputs / (keys.get_shape().as_list()[-1] ** 0.5)
+            outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)
+        elif self.attention_type == "cos":
+            Q_cos = tf.nn.l2_normalize(Q_, dim=-1)
+            K_cos = tf.nn.l2_normalize(K_, dim=-1)
+
+            outputs = tf.matmul(Q_cos, K_cos, transpose_b=True)  # h*N T_q T_k
+
+            outputs = outputs * 20  # Scale
+        elif self.attention_type == "ln":
+            Q_ = self.att_ln_q(Q_)
+            K_ = self.att_ln_k(K_)
+
+            outputs = tf.matmul(Q_, K_, transpose_b=True)  # h*N T_q T_k
+            # Scale
+            outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)
        elif self.attention_type == "additive":
-            querys_reshaped = tf.expand_dims(querys, axis=-2)
-            keys_reshaped = tf.expand_dims(keys, axis=-3)
-            outputs = tf.tanh(tf.nn.bias_add(querys_reshaped + keys_reshaped, self.b))
+            Q_reshaped = tf.expand_dims(Q_, axis=-2)
+            K_reshaped = tf.expand_dims(K_, axis=-3)
+            outputs = tf.tanh(tf.nn.bias_add(Q_reshaped + K_reshaped, self.b))
            outputs = tf.squeeze(tf.tensordot(outputs, tf.expand_dims(self.v, axis=-1), axes=[-1, 0]), axis=-1)
        else:
-            raise ValueError("attention_type must be scaled_dot_product or additive")
+            raise ValueError("attention_type must be one of ['scaled_dot_product', 'cos', 'ln', 'additive']")

        key_masks = tf.tile(key_masks, [self.head_num, 1])
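
A note on the constant in the `cos` branch: cosine scores are bounded in [-1, 1], so a softmax over the raw values is nearly uniform, and the multiplier sharpens the attention distribution. A quick demonstration with made-up scores (TF 2.x eager):

```python
import tensorflow as tf

scores = tf.constant([0.9, 0.1, -0.2])     # cosine scores, bounded in [-1, 1]
print(tf.nn.softmax(scores).numpy())       # ~[0.56, 0.25, 0.19] -- nearly flat
print(tf.nn.softmax(scores * 20).numpy())  # ~[1.0, 1e-7, 3e-10] -- one key dominates
```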
@@ -583,7 +600,7 @@ def call(self, inputs, mask=None, training=None, **kwargs):
        outputs = self.dropout(outputs, training=training)
        # Weighted sum
        # (h*N, T_q, C/h)
-        result = tf.matmul(outputs, values)
+        result = tf.matmul(outputs, V_)
        result = tf.concat(tf.split(result, self.head_num, axis=0), axis=2)

        if self.use_res:
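
To make the shape bookkeeping in this hunk concrete, a small walk-through of the weighted sum and the head recombination, using made-up sizes:

```python
# Illustrative shapes: batch N=2, heads h=4, T_q=T_k=5, per-head dim D=8.
import tensorflow as tf

h, N, T_q, T_k, D = 4, 2, 5, 5, 8
outputs = tf.zeros([h * N, T_q, T_k])  # attention weights, heads stacked on batch
V_ = tf.zeros([h * N, T_k, D])         # per-head values

result = tf.matmul(outputs, V_)        # (h*N, T_q, D) -- weighted sum per head
# Undo the earlier head split: move the h chunks from the batch axis back
# onto the feature axis.
result = tf.concat(tf.split(result, h, axis=0), axis=2)
print(result.shape)                    # (2, 5, 32) == (N, T_q, h*D)
```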