
Commit f937aea: add DCN-M and DCN-Mix (#316)

- add DCN-M and DCN-Mix
- update test files for DCN and DCNMix
- set `h5py==2.10.0` for pytest
- simplify `interaction.py`
- modify BilinearInteraction in FiBiNET

1 parent ab595ec, commit f937aea

File tree: 8 files changed, +314 −37 lines

deepctr/layers/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -2,7 +2,7 @@
 
 from .activation import Dice
 from .core import DNN, LocalActivationUnit, PredictionLayer
-from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet,
+from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet, CrossNetMix,
                           InnerProductLayer, InteractingLayer,
                           OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction,
                           FieldWiseBiInteraction, FwFMLayer)
@@ -20,6 +20,7 @@
                   'FM': FM,
                   'AFMLayer': AFMLayer,
                   'CrossNet': CrossNet,
+                  'CrossNetMix': CrossNetMix,
                   'BiInteractionPooling': BiInteractionPooling,
                   'LocalActivationUnit': LocalActivationUnit,
                   'Dice': Dice,
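Registering `'CrossNetMix'` in this `custom_objects` dict is what lets a saved DCN-Mix model be deserialized; DeepCTR's documented loading pattern passes the dict straight to Keras. A sketch (the .h5 path is a hypothetical placeholder):

import tensorflow as tf
from deepctr.layers import custom_objects

# Load a previously saved DCN-Mix model; custom_objects maps layer names
# like 'CrossNetMix' back to their classes during deserialization.
model = tf.keras.models.load_model('DCNMix.h5', custom_objects)  # path is a placeholder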

deepctr/layers/interaction.py
Lines changed: 184 additions & 26 deletions

@@ -142,7 +142,8 @@ def get_config(self, ):
         config = {'attention_factor': self.attention_factor,
                   'l2_reg_w': self.l2_reg_w, 'dropout_rate': self.dropout_rate, 'seed': self.seed}
         base_config = super(AFMLayer, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        base_config.update(config)
+        return base_config
 
 
 class BiInteractionPooling(Layer):
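The `get_config` changes throughout this commit all follow the pattern above: the dict-concatenation idiom is replaced by an in-place update. A minimal sketch of why the two forms are equivalent (in both, values from `config` win on key collisions):

# Toy stand-ins for a layer's base config and its own config.
base_config = {'name': 'afm_layer', 'trainable': True}
config = {'dropout_rate': 0.0, 'seed': 1024}

merged_old = dict(list(base_config.items()) + list(config.items()))  # old idiom
base_config.update(config)                                           # new idiom
assert merged_old == base_config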
@@ -322,7 +323,8 @@ def get_config(self, ):
         config = {'layer_size': self.layer_size, 'split_half': self.split_half, 'activation': self.activation,
                   'seed': self.seed}
         base_config = super(CIN, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        base_config.update(config)
+        return base_config
 
 
 class CrossNet(Layer):
@@ -340,16 +342,20 @@ class CrossNet(Layer):
 
         - **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix
 
+        - **parameterization**: string, ``"vector"`` or ``"matrix"``, the way to parameterize the cross network.
+
         - **seed**: A Python integer to use as random seed.
 
       References
         - [Wang R, Fu B, Fu G, et al. Deep & cross network for ad click predictions[C]//Proceedings of the ADKDD'17. ACM, 2017: 12.](https://arxiv.org/abs/1708.05123)
     """
 
-    def __init__(self, layer_num=2, l2_reg=0, seed=1024, **kwargs):
+    def __init__(self, layer_num=2, parameterization='vector', l2_reg=0, seed=1024, **kwargs):
         self.layer_num = layer_num
+        self.parameterization = parameterization
         self.l2_reg = l2_reg
         self.seed = seed
+        print('CrossNet parameterization:', self.parameterization)
         super(CrossNet, self).__init__(**kwargs)
 
     def build(self, input_shape):
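A short usage sketch of the new argument (the input width of 8 is invented for illustration; `CrossNet` is imported exactly as exposed by `deepctr.layers`). Both parameterizations preserve the 2D input shape:

import tensorflow as tf
from deepctr.layers import CrossNet

inputs = tf.keras.Input(shape=(8,))                                  # (batch_size, units)
dcn_out = CrossNet(layer_num=2, parameterization='vector')(inputs)   # DCN: kernels of shape (dim, 1)
dcnm_out = CrossNet(layer_num=2, parameterization='matrix')(inputs)  # DCN-M: kernels of shape (dim, dim)
model = tf.keras.Model(inputs, [dcn_out, dcnm_out])                  # both outputs: (None, 8)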
@@ -359,12 +365,22 @@ def build(self, input_shape):
                 "Unexpected inputs dimensions %d, expect to be 2 dimensions" % (len(input_shape),))
 
         dim = int(input_shape[-1])
-        self.kernels = [self.add_weight(name='kernel' + str(i),
-                                        shape=(dim, 1),
-                                        initializer=glorot_normal(seed=self.seed),
-                                        regularizer=l2(self.l2_reg),
-                                        trainable=True) for i in range(self.layer_num)]
+        if self.parameterization == 'vector':
+            self.kernels = [self.add_weight(name='kernel' + str(i),
+                                            shape=(dim, 1),
+                                            initializer=glorot_normal(seed=self.seed),
+                                            regularizer=l2(self.l2_reg),
+                                            trainable=True) for i in range(self.layer_num)]
+        elif self.parameterization == 'matrix':
+            self.kernels = [self.add_weight(name='kernel' + str(i),
+                                            shape=(dim, dim),
+                                            initializer=glorot_normal(seed=self.seed),
+                                            regularizer=l2(self.l2_reg),
+                                            trainable=True) for i in range(self.layer_num)]
+        else:  # error
+            raise ValueError("parameterization should be 'vector' or 'matrix'")
         self.bias = [self.add_weight(name='bias' + str(i),
                                      shape=(dim, 1),
                                      initializer=Zeros(),
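The only difference between the two branches is the kernel shape, so the per-layer parameter count grows from linear to quadratic in the input width. A rough count implied by the shapes above (`dim = 128` is an arbitrary example):

dim = 128                          # flattened embedding width
vector_params = dim * 1 + dim      # kernel (dim, 1) + bias (dim, 1)   ->   256
matrix_params = dim * dim + dim    # kernel (dim, dim) + bias (dim, 1) -> 16512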
@@ -380,18 +396,152 @@ def call(self, inputs, **kwargs):
         x_0 = tf.expand_dims(inputs, axis=2)
         x_l = x_0
         for i in range(self.layer_num):
-            xl_w = tf.tensordot(x_l, self.kernels[i], axes=(1, 0))
-            dot_ = tf.matmul(x_0, xl_w)
-            x_l = dot_ + self.bias[i] + x_l
+            if self.parameterization == 'vector':
+                xl_w = tf.tensordot(x_l, self.kernels[i], axes=(1, 0))
+                dot_ = tf.matmul(x_0, xl_w)
+                dot_ = dot_ + self.bias[i]  # x0 * (xl^T w) + b; the residual x_l is added below
+            elif self.parameterization == 'matrix':
+                dot_ = tf.einsum('ij,bjk->bik', self.kernels[i], x_l)  # W * xi  (bs, dim, 1)
+                dot_ = dot_ + self.bias[i]  # W * xi + b
+                dot_ = x_0 * dot_  # x0 . (W * xi + b)  Hadamard-product
+            else:  # defensive; build() already rejects other values
+                raise ValueError("parameterization should be 'vector' or 'matrix'")
+            x_l = dot_ + x_l
         x_l = tf.squeeze(x_l, axis=2)
         return x_l
 
     def get_config(self, ):
 
-        config = {'layer_num': self.layer_num,
+        config = {'layer_num': self.layer_num, 'parameterization': self.parameterization,
                   'l2_reg': self.l2_reg, 'seed': self.seed}
         base_config = super(CrossNet, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        base_config.update(config)
+        return base_config
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
+class CrossNetMix(Layer):
+    """The Cross Network part of the DCN-Mix model, which improves DCN-M by:
+      1. adding MoE to learn feature interactions in different subspaces
+      2. adding nonlinear transformations in the low-dimensional space
+
+      Input shape
+        - 2D tensor with shape: ``(batch_size, units)``.
+
+      Output shape
+        - 2D tensor with shape: ``(batch_size, units)``.
+
+      Arguments
+        - **low_rank** : Positive integer, dimensionality of the low-rank space.
+
+        - **num_experts** : Positive integer, number of experts.
+
+        - **layer_num**: Positive integer, the cross layer number
+
+        - **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix
+
+        - **seed**: A Python integer to use as random seed.
+
+      References
+        - [Wang R, Shivanna R, Cheng D Z, et al. DCN-M: Improved Deep & Cross Network for Feature Cross Learning in Web-scale Learning to Rank Systems[J]. 2020.](https://arxiv.org/abs/2008.13535)
+    """
+
+    def __init__(self, low_rank=32, num_experts=4, layer_num=2, l2_reg=0, seed=1024, **kwargs):
+        self.low_rank = low_rank
+        self.num_experts = num_experts
+        self.layer_num = layer_num
+        self.l2_reg = l2_reg
+        self.seed = seed
+        super(CrossNetMix, self).__init__(**kwargs)
+
+    def build(self, input_shape):
+
+        if len(input_shape) != 2:
+            raise ValueError(
+                "Unexpected inputs dimensions %d, expect to be 2 dimensions" % (len(input_shape),))
+
+        dim = int(input_shape[-1])
+
+        # U: (dim, low_rank)
+        self.U_list = [self.add_weight(name='U_list' + str(i),
+                                       shape=(self.num_experts, dim, self.low_rank),
+                                       initializer=glorot_normal(seed=self.seed),
+                                       regularizer=l2(self.l2_reg),
+                                       trainable=True) for i in range(self.layer_num)]
+        # V: (dim, low_rank)
+        self.V_list = [self.add_weight(name='V_list' + str(i),
+                                       shape=(self.num_experts, dim, self.low_rank),
+                                       initializer=glorot_normal(seed=self.seed),
+                                       regularizer=l2(self.l2_reg),
+                                       trainable=True) for i in range(self.layer_num)]
+        # C: (low_rank, low_rank)
+        self.C_list = [self.add_weight(name='C_list' + str(i),
+                                       shape=(self.num_experts, self.low_rank, self.low_rank),
+                                       initializer=glorot_normal(seed=self.seed),
+                                       regularizer=l2(self.l2_reg),
+                                       trainable=True) for i in range(self.layer_num)]
+
+        self.gating = [tf.keras.layers.Dense(1, use_bias=False) for i in range(self.num_experts)]
+
+        self.bias = [self.add_weight(name='bias' + str(i),
+                                     shape=(dim, 1),
+                                     initializer=Zeros(),
+                                     trainable=True) for i in range(self.layer_num)]
+        # Be sure to call this somewhere!
+        super(CrossNetMix, self).build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        if K.ndim(inputs) != 2:
+            raise ValueError(
+                "Unexpected inputs dimensions %d, expect to be 2 dimensions" % (K.ndim(inputs)))
+
+        x_0 = tf.expand_dims(inputs, axis=2)
+        x_l = x_0
+        for i in range(self.layer_num):
+            output_of_experts = []
+            gating_score_of_experts = []
+            for expert_id in range(self.num_experts):
+                # (1) G(x_l): compute the gating score by x_l
+                gating_score_of_experts.append(self.gating[expert_id](tf.squeeze(x_l, axis=2)))
+
+                # (2) E(x_l): project the input x_l to $\mathbb{R}^{r}$
+                v_x = tf.einsum('ij,bjk->bik', tf.transpose(self.V_list[i][expert_id]), x_l)  # (bs, low_rank, 1)
+
+                # nonlinear activation in low rank space
+                v_x = tf.nn.tanh(v_x)
+                v_x = tf.einsum('ij,bjk->bik', self.C_list[i][expert_id], v_x)  # (bs, low_rank, 1)
+                v_x = tf.nn.tanh(v_x)
+
+                # project back to $\mathbb{R}^{d}$
+                uv_x = tf.einsum('ij,bjk->bik', self.U_list[i][expert_id], v_x)  # (bs, dim, 1)
+
+                dot_ = uv_x + self.bias[i]
+                dot_ = x_0 * dot_  # Hadamard-product
+
+                output_of_experts.append(tf.squeeze(dot_, axis=2))
+
+            # (3) mixture of low-rank experts
+            output_of_experts = tf.stack(output_of_experts, 2)  # (bs, dim, num_experts)
+            gating_score_of_experts = tf.stack(gating_score_of_experts, 1)  # (bs, num_experts, 1)
+            moe_out = tf.matmul(output_of_experts, tf.nn.softmax(gating_score_of_experts, 1))
+            x_l = moe_out + x_l  # (bs, dim, 1)
+        x_l = tf.squeeze(x_l, axis=2)
+        return x_l
+
+    def get_config(self, ):
+
+        config = {'low_rank': self.low_rank, 'num_experts': self.num_experts, 'layer_num': self.layer_num,
+                  'l2_reg': self.l2_reg, 'seed': self.seed}
+        base_config = super(CrossNetMix, self).get_config()
+        base_config.update(config)
+        return base_config
 
     def compute_output_shape(self, input_shape):
         return input_shape
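In each layer, every expert projects x_l into the rank-``low_rank`` subspace (V), mixes it there (C) under tanh nonlinearities, projects back (U), and takes the Hadamard product with x_0; a softmax over the gating scores then mixes the experts before the residual add. A usage sketch with toy sizes (all numbers invented):

import tensorflow as tf
from deepctr.layers import CrossNetMix

inputs = tf.keras.Input(shape=(16,))
outputs = CrossNetMix(low_rank=4, num_experts=2, layer_num=2)(inputs)  # shape preserved: (None, 16)
model = tf.keras.Model(inputs, outputs)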
@@ -527,7 +677,8 @@ def compute_output_shape(self, input_shape):
     def get_config(self, ):
         config = {'reduce_sum': self.reduce_sum, }
         base_config = super(InnerProductLayer, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        base_config.update(config)
+        return base_config
 
 
 class InteractingLayer(Layer):
@@ -619,7 +770,8 @@ def get_config(self, ):
         config = {'att_embedding_size': self.att_embedding_size, 'head_num': self.head_num, 'use_res': self.use_res,
                   'seed': self.seed}
         base_config = super(InteractingLayer, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        base_config.update(config)
+        return base_config
 
 
 class OutterProductLayer(Layer):
@@ -762,7 +914,8 @@ def compute_output_shape(self, input_shape):
     def get_config(self, ):
         config = {'kernel_type': self.kernel_type, 'seed': self.seed}
         base_config = super(OutterProductLayer, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        base_config.update(config)
+        return base_config
 
 
 class FGCNNLayer(Layer):
@@ -866,7 +1019,8 @@ def get_config(self, ):
         config = {'kernel_width': self.kernel_width, 'filters': self.filters, 'new_maps': self.new_maps,
                   'pooling_width': self.pooling_width}
         base_config = super(FGCNNLayer, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        base_config.update(config)
+        return base_config
 
     def _conv_output_shape(self, input_shape, kernel_size):
         # channels_last
@@ -965,20 +1119,21 @@ def compute_mask(self, inputs, mask=None):
     def get_config(self, ):
         config = {'reduction_ratio': self.reduction_ratio, 'seed': self.seed}
         base_config = super(SENETLayer, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        base_config.update(config)
+        return base_config
 
 
 class BilinearInteraction(Layer):
     """BilinearInteraction Layer used in FiBiNET.
 
       Input shape
-        - A list of 3D tensor with shape: ``(batch_size,1,embedding_size)``.
+        - A list of 3D tensor with shape: ``(batch_size,1,embedding_size)``. Its length is ``filed_size``.
 
       Output shape
-        - 3D tensor with shape: ``(batch_size,1,embedding_size)``.
+        - 3D tensor with shape: ``(batch_size,filed_size*(filed_size-1)/2,embedding_size)``.
 
       Arguments
-        - **str** : String, types of bilinear functions used in this layer.
+        - **bilinear_type** : String, types of bilinear functions used in this layer.
 
         - **seed** : A Python integer to use as random seed.
 
@@ -1034,18 +1189,20 @@ def call(self, inputs, **kwargs):
                  for v, w in zip(itertools.combinations(inputs, 2), self.W_list)]
         else:
             raise NotImplementedError
-        return concat_func(p)
+        output = concat_func(p, axis=1)
+        return output
 
     def compute_output_shape(self, input_shape):
         filed_size = len(input_shape)
         embedding_size = input_shape[0][-1]
 
-        return (None, 1, filed_size * (filed_size - 1) // 2 * embedding_size)
+        return (None, filed_size * (filed_size - 1) // 2, embedding_size)
 
     def get_config(self, ):
         config = {'bilinear_type': self.bilinear_type, 'seed': self.seed}
         base_config = super(BilinearInteraction, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        base_config.update(config)
+        return base_config
 
 
 class FieldWiseBiInteraction(Layer):
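With the fixed shape, the pairwise interactions stay separated along the field axis instead of being flattened into one length-1 row: four input fields give 4*3/2 = 6 pairs, hence (batch_size, 6, embedding_size). A sketch (sizes invented; 'interaction' is one of the bilinear types the layer accepts):

import tensorflow as tf
from deepctr.layers import BilinearInteraction

embeds = [tf.keras.Input(shape=(1, 8)) for _ in range(4)]         # filed_size = 4
pairs = BilinearInteraction(bilinear_type='interaction')(embeds)  # (None, 6, 8)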
@@ -1171,7 +1328,8 @@ def compute_output_shape(self, input_shape):
     def get_config(self, ):
         config = {'use_bias': self.use_bias, 'seed': self.seed}
         base_config = super(FieldWiseBiInteraction, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        base_config.update(config)
+        return base_config
 
 
 class FwFMLayer(Layer):

deepctr/models/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -2,6 +2,7 @@
 from .autoint import AutoInt
 from .ccpm import CCPM
 from .dcn import DCN
+from .dcnmix import DCNMix
 from .deepfm import DeepFM
 from .dien import DIEN
 from .din import DIN
@@ -19,5 +20,5 @@
 from .flen import FLEN
 from .fwfm import FwFM
 
-__all__ = ["AFM", "CCPM","DCN", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", "PNN",
+__all__ = ["AFM", "CCPM", "DCN", "DCNMix", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", "PNN",
            "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM"]
