add DCN-M and DCN-Mix (#316)
- Add DCN-M and DCN-Mix
- Update test files for DCN and DCNMix
- Set `h5py==2.10.0` for pytest
- Simplify `interaction.py`
- Modify BilinearInteraction in FiBiNET
zanshuxun authored Jan 3, 2021
1 parent ab595ec commit f937aea
Showing 8 changed files with 314 additions and 37 deletions.
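
For context, a hedged usage sketch of the newly exported `DCNMix` model. The `deepctr.feature_column` import, the `SparseFeat` signature, and the `low_rank`/`num_experts` keyword names are assumptions based on the surrounding library; they are not shown in this commit:

```python
import numpy as np
from deepctr.feature_column import SparseFeat  # assumed helper module
from deepctr.models import DCNMix

# Two toy categorical fields (names and sizes are illustrative).
feature_columns = [SparseFeat('user_id', vocabulary_size=100, embedding_dim=8),
                   SparseFeat('item_id', vocabulary_size=200, embedding_dim=8)]

# low_rank and num_experts are assumed to be forwarded to CrossNetMix.
model = DCNMix(linear_feature_columns=feature_columns,
               dnn_feature_columns=feature_columns,
               low_rank=32, num_experts=4)
model.compile('adam', 'binary_crossentropy')

data = {'user_id': np.random.randint(0, 100, 256),
        'item_id': np.random.randint(0, 200, 256)}
labels = np.random.randint(0, 2, 256)
model.fit(data, labels, batch_size=64, epochs=1)
```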
3 changes: 2 additions & 1 deletion deepctr/layers/__init__.py
@@ -2,7 +2,7 @@

from .activation import Dice
from .core import DNN, LocalActivationUnit, PredictionLayer
from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet,
from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet, CrossNetMix,
InnerProductLayer, InteractingLayer,
OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction,
FieldWiseBiInteraction, FwFMLayer)
@@ -20,6 +20,7 @@
'FM': FM,
'AFMLayer': AFMLayer,
'CrossNet': CrossNet,
'CrossNetMix': CrossNetMix,
'BiInteractionPooling': BiInteractionPooling,
'LocalActivationUnit': LocalActivationUnit,
'Dice': Dice,
210 changes: 184 additions & 26 deletions deepctr/layers/interaction.py
@@ -142,7 +142,8 @@ def get_config(self, ):
config = {'attention_factor': self.attention_factor,
'l2_reg_w': self.l2_reg_w, 'dropout_rate': self.dropout_rate, 'seed': self.seed}
base_config = super(AFMLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.update(config)
return base_config
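
The `get_config` changes in this file all replace the list-concatenation merge with an in-place `dict.update`. A minimal standalone sketch (toy values, not tied to any layer) showing the two idioms produce the same dict:

```python
# Toy stand-ins for a layer's base and subclass config dicts.
base_config = {'name': 'afm_layer', 'trainable': True}
config = {'attention_factor': 4, 'seed': 1024}

old_style = dict(list(base_config.items()) + list(config.items()))  # removed idiom
base_config.update(config)                                          # new idiom
assert old_style == base_config
```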


class BiInteractionPooling(Layer):
@@ -322,7 +323,8 @@ def get_config(self, ):
config = {'layer_size': self.layer_size, 'split_half': self.split_half, 'activation': self.activation,
'seed': self.seed}
base_config = super(CIN, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.update(config)
return base_config


class CrossNet(Layer):
@@ -340,16 +342,20 @@ class CrossNet(Layer):
- **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix
- **parameterization**: string, ``"vector"`` or ``"matrix"``, way to parameterize the cross network.
- **seed**: A Python integer to use as random seed.
References
- [Wang R, Fu B, Fu G, et al. Deep & cross network for ad click predictions[C]//Proceedings of the ADKDD'17. ACM, 2017: 12.](https://arxiv.org/abs/1708.05123)
"""

def __init__(self, layer_num=2, l2_reg=0, seed=1024, **kwargs):
def __init__(self, layer_num=2, parameterization='vector', l2_reg=0, seed=1024, **kwargs):
self.layer_num = layer_num
self.parameterization = parameterization
self.l2_reg = l2_reg
self.seed = seed
print('CrossNet parameterization:', self.parameterization)
super(CrossNet, self).__init__(**kwargs)

def build(self, input_shape):
@@ -359,12 +365,22 @@ def build(self, input_shape):
"Unexpected inputs dimensions %d, expect to be 2 dimensions" % (len(input_shape),))

dim = int(input_shape[-1])
self.kernels = [self.add_weight(name='kernel' + str(i),
shape=(dim, 1),
initializer=glorot_normal(
seed=self.seed),
regularizer=l2(self.l2_reg),
trainable=True) for i in range(self.layer_num)]
if self.parameterization == 'vector':
self.kernels = [self.add_weight(name='kernel' + str(i),
shape=(dim, 1),
initializer=glorot_normal(
seed=self.seed),
regularizer=l2(self.l2_reg),
trainable=True) for i in range(self.layer_num)]
elif self.parameterization == 'matrix':
self.kernels = [self.add_weight(name='kernel' + str(i),
shape=(dim, dim),
initializer=glorot_normal(
seed=self.seed),
regularizer=l2(self.l2_reg),
trainable=True) for i in range(self.layer_num)]
else: # error
raise ValueError("parameterization should be 'vector' or 'matrix'")
self.bias = [self.add_weight(name='bias' + str(i),
shape=(dim, 1),
initializer=Zeros(),
@@ -380,18 +396,152 @@ def call(self, inputs, **kwargs):
x_0 = tf.expand_dims(inputs, axis=2)
x_l = x_0
for i in range(self.layer_num):
xl_w = tf.tensordot(x_l, self.kernels[i], axes=(1, 0))
dot_ = tf.matmul(x_0, xl_w)
x_l = dot_ + self.bias[i] + x_l
if self.parameterization == 'vector':
xl_w = tf.tensordot(x_l, self.kernels[i], axes=(1, 0))
dot_ = tf.matmul(x_0, xl_w)
x_l = dot_ + self.bias[i]
elif self.parameterization == 'matrix':
dot_ = tf.einsum('ij,bjk->bik', self.kernels[i], x_l) # W * xi (bs, dim, 1)
dot_ = dot_ + self.bias[i] # W * xi + b
dot_ = x_0 * dot_ # x0 · (W * xi + b) Hadamard-product
else: # error
raise ValueError("parameterization should be 'vector' or 'matrix'")
x_l = dot_ + x_l
x_l = tf.squeeze(x_l, axis=2)
return x_l

def get_config(self, ):

config = {'layer_num': self.layer_num,
config = {'layer_num': self.layer_num, 'parameterization': self.parameterization,
'l2_reg': self.l2_reg, 'seed': self.seed}
base_config = super(CrossNet, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.update(config)
return base_config

def compute_output_shape(self, input_shape):
return input_shape
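
For reference, a minimal sketch (not part of the commit) of a single cross layer under each parameterization, written in plain TensorFlow with illustrative tensor names:

```python
import tensorflow as tf

batch, dim = 4, 8
x_0 = tf.random.normal((batch, dim, 1))  # original input, expanded to (bs, dim, 1)
x_l = x_0                                # input to the current cross layer
bias = tf.zeros((dim, 1))

# 'vector' parameterization (original DCN): w is (dim, 1)
w = tf.random.normal((dim, 1))
xl_w = tf.tensordot(x_l, w, axes=(1, 0))    # (bs, 1, 1), one scalar per sample
x_next = tf.matmul(x_0, xl_w) + bias + x_l  # x_0 (x_l^T w) + b + x_l

# 'matrix' parameterization (DCN-M): W is (dim, dim)
W = tf.random.normal((dim, dim))
dot_ = tf.einsum('ij,bjk->bik', W, x_l) + bias  # W x_l + b, (bs, dim, 1)
x_next = x_0 * dot_ + x_l                       # x_0 ⊙ (W x_l + b) + x_l
```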


class CrossNetMix(Layer):
"""The Cross Network part of DCN-Mix model, which improves DCN-M by:
1. adds MoE to learn feature interactions in different subspaces;
2. adds nonlinear transformations in low-dimensional space.
Input shape
- 2D tensor with shape: ``(batch_size, units)``.
Output shape
- 2D tensor with shape: ``(batch_size, units)``.
Arguments
- **low_rank** : Positive integer, dimensionality of low-rank space.
- **num_experts** : Positive integer, number of experts.
- **layer_num**: Positive integer, the cross layer number
- **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix
- **seed**: A Python integer to use as random seed.
References
- [Wang R, Shivanna R, Cheng D Z, et al. DCN-M: Improved Deep & Cross Network for Feature Cross Learning in Web-scale Learning to Rank Systems[J]. 2020.](https://arxiv.org/abs/2008.13535)
"""

def __init__(self, low_rank=32, num_experts=4, layer_num=2, l2_reg=0, seed=1024, **kwargs):
self.low_rank = low_rank
self.num_experts = num_experts
self.layer_num = layer_num
self.l2_reg = l2_reg
self.seed = seed
super(CrossNetMix, self).__init__(**kwargs)

def build(self, input_shape):

if len(input_shape) != 2:
raise ValueError(
"Unexpected inputs dimensions %d, expect to be 2 dimensions" % (len(input_shape),))

dim = int(input_shape[-1])

# U: (dim, low_rank)
self.U_list = [self.add_weight(name='U_list' + str(i),
shape=(self.num_experts, dim, self.low_rank),
initializer=glorot_normal(
seed=self.seed),
regularizer=l2(self.l2_reg),
trainable=True) for i in range(self.layer_num)]
# V: (dim, low_rank)
self.V_list = [self.add_weight(name='V_list' + str(i),
shape=(self.num_experts, dim, self.low_rank),
initializer=glorot_normal(
seed=self.seed),
regularizer=l2(self.l2_reg),
trainable=True) for i in range(self.layer_num)]
# C: (low_rank, low_rank)
self.C_list = [self.add_weight(name='C_list' + str(i),
shape=(self.num_experts, self.low_rank, self.low_rank),
initializer=glorot_normal(
seed=self.seed),
regularizer=l2(self.l2_reg),
trainable=True) for i in range(self.layer_num)]

self.gating = [tf.keras.layers.Dense(1, use_bias=False) for i in range(self.num_experts)]

self.bias = [self.add_weight(name='bias' + str(i),
shape=(dim, 1),
initializer=Zeros(),
trainable=True) for i in range(self.layer_num)]
# Be sure to call this somewhere!
super(CrossNetMix, self).build(input_shape)

def call(self, inputs, **kwargs):
if K.ndim(inputs) != 2:
raise ValueError(
"Unexpected inputs dimensions %d, expect to be 2 dimensions" % (K.ndim(inputs)))

x_0 = tf.expand_dims(inputs, axis=2)
x_l = x_0
for i in range(self.layer_num):
output_of_experts = []
gating_score_of_experts = []
for expert_id in range(self.num_experts):
# (1) G(x_l)
# compute the gating score by x_l
gating_score_of_experts.append(self.gating[expert_id](tf.squeeze(x_l, axis=2)))

# (2) E(x_l)
# project the input x_l to $\mathbb{R}^{r}$
v_x = tf.einsum('ij,bjk->bik', tf.transpose(self.V_list[i][expert_id]), x_l) # (bs, low_rank, 1)

# nonlinear activation in low rank space
v_x = tf.nn.tanh(v_x)
v_x = tf.einsum('ij,bjk->bik', self.C_list[i][expert_id], v_x) # (bs, low_rank, 1)
v_x = tf.nn.tanh(v_x)

# project back to $\mathbb{R}^{d}$
uv_x = tf.einsum('ij,bjk->bik', self.U_list[i][expert_id], v_x) # (bs, dim, 1)

dot_ = uv_x + self.bias[i]
dot_ = x_0 * dot_ # Hadamard-product

output_of_experts.append(tf.squeeze(dot_, axis=2))

# (3) mixture of low-rank experts
output_of_experts = tf.stack(output_of_experts, 2) # (bs, dim, num_experts)
gating_score_of_experts = tf.stack(gating_score_of_experts, 1) # (bs, num_experts, 1)
moe_out = tf.matmul(output_of_experts, tf.nn.softmax(gating_score_of_experts, 1))
x_l = moe_out + x_l # (bs, dim, 1)
x_l = tf.squeeze(x_l, axis=2)
return x_l

def get_config(self, ):

config = {'low_rank': self.low_rank, 'num_experts': self.num_experts, 'layer_num': self.layer_num,
'l2_reg': self.l2_reg, 'seed': self.seed}
base_config = super(CrossNetMix, self).get_config()
base_config.update(config)
return base_config

def compute_output_shape(self, input_shape):
return input_shape
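
Again for reference, a numerical sketch (not part of the commit) of one CrossNetMix layer with random weights; the gating score here is a random stand-in for the layer's learned `Dense` gate:

```python
import tensorflow as tf

batch, dim, low_rank, num_experts = 4, 8, 3, 2
x_0 = tf.random.normal((batch, dim, 1))
x_l = x_0
bias = tf.zeros((dim, 1))

expert_outputs, gate_scores = [], []
for _ in range(num_experts):
    U = tf.random.normal((dim, low_rank))
    V = tf.random.normal((dim, low_rank))
    C = tf.random.normal((low_rank, low_rank))

    # Project to the low-rank space, apply tanh around the C projection,
    # then project back to the full dimension.
    v_x = tf.einsum('ij,bjk->bik', tf.transpose(V), x_l)           # (bs, low_rank, 1)
    v_x = tf.nn.tanh(tf.einsum('ij,bjk->bik', C, tf.nn.tanh(v_x)))
    uv_x = tf.einsum('ij,bjk->bik', U, v_x)                        # (bs, dim, 1)

    expert_outputs.append(tf.squeeze(x_0 * (uv_x + bias), axis=2))  # (bs, dim)
    gate_scores.append(tf.random.normal((batch, 1)))  # stand-in for the Dense gate

experts = tf.stack(expert_outputs, 2)                    # (bs, dim, num_experts)
gates = tf.nn.softmax(tf.stack(gate_scores, 1), axis=1)  # (bs, num_experts, 1)
x_next = tf.matmul(experts, gates) + x_l                 # (bs, dim, 1)
```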
@@ -527,7 +677,8 @@ def compute_output_shape(self, input_shape):
def get_config(self, ):
config = {'reduce_sum': self.reduce_sum, }
base_config = super(InnerProductLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.update(config)
return base_config


class InteractingLayer(Layer):
@@ -619,7 +770,8 @@ def get_config(self, ):
config = {'att_embedding_size': self.att_embedding_size, 'head_num': self.head_num, 'use_res': self.use_res,
'seed': self.seed}
base_config = super(InteractingLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.update(config)
return base_config


class OutterProductLayer(Layer):
@@ -762,7 +914,8 @@ def compute_output_shape(self, input_shape):
def get_config(self, ):
config = {'kernel_type': self.kernel_type, 'seed': self.seed}
base_config = super(OutterProductLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.update(config)
return base_config


class FGCNNLayer(Layer):
@@ -866,7 +1019,8 @@ def get_config(self, ):
config = {'kernel_width': self.kernel_width, 'filters': self.filters, 'new_maps': self.new_maps,
'pooling_width': self.pooling_width}
base_config = super(FGCNNLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.update(config)
return base_config

def _conv_output_shape(self, input_shape, kernel_size):
# channels_last
@@ -965,20 +1119,21 @@ def compute_mask(self, inputs, mask=None):
def get_config(self, ):
config = {'reduction_ratio': self.reduction_ratio, 'seed': self.seed}
base_config = super(SENETLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.update(config)
return base_config


class BilinearInteraction(Layer):
"""BilinearInteraction Layer used in FiBiNET.
Input shape
- A list of 3D tensor with shape: ``(batch_size,1,embedding_size)``.
- A list of 3D tensor with shape: ``(batch_size,1,embedding_size)``. Its length is ``filed_size``.
Output shape
- 3D tensor with shape: ``(batch_size,1,embedding_size)``.
- 3D tensor with shape: ``(batch_size,filed_size*(filed_size-1)/2,embedding_size)``.
Arguments
- **str** : String, types of bilinear functions used in this layer.
- **bilinear_type** : String, types of bilinear functions used in this layer.
- **seed** : A Python integer to use as random seed.
@@ -1034,18 +1189,20 @@ def call(self, inputs, **kwargs):
for v, w in zip(itertools.combinations(inputs, 2), self.W_list)]
else:
raise NotImplementedError
return concat_func(p)
output = concat_func(p, axis=1)
return output

def compute_output_shape(self, input_shape):
filed_size = len(input_shape)
embedding_size = input_shape[0][-1]

return (None, 1, filed_size * (filed_size - 1) // 2 * embedding_size)
return (None, filed_size * (filed_size - 1) // 2, embedding_size)

def get_config(self, ):
config = {'bilinear_type': self.bilinear_type, 'seed': self.seed}
base_config = super(BilinearInteraction, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.update(config)
return base_config
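
A small sketch (not part of the commit) of the corrected shape arithmetic, using plain Hadamard products as stand-ins for the layer's learned bilinear products:

```python
import itertools
import tensorflow as tf

filed_size, embedding_size = 4, 8
inputs = [tf.random.normal((2, 1, embedding_size)) for _ in range(filed_size)]

# One tensor per unordered field pair, concatenated along axis 1.
p = [v * w for v, w in itertools.combinations(inputs, 2)]
out = tf.concat(p, axis=1)

assert len(p) == filed_size * (filed_size - 1) // 2  # 6 pairs
assert tuple(out.shape) == (2, len(p), embedding_size)
```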


class FieldWiseBiInteraction(Layer):
@@ -1171,7 +1328,8 @@ def compute_output_shape(self, input_shape):
def get_config(self, ):
config = {'use_bias': self.use_bias, 'seed': self.seed}
base_config = super(FieldWiseBiInteraction, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.update(config)
return base_config


class FwFMLayer(Layer):
3 changes: 2 additions & 1 deletion deepctr/models/__init__.py
@@ -2,6 +2,7 @@
from .autoint import AutoInt
from .ccpm import CCPM
from .dcn import DCN
from .dcnmix import DCNMix
from .deepfm import DeepFM
from .dien import DIEN
from .din import DIN
@@ -19,5 +20,5 @@
from .flen import FLEN
from .fwfm import FwFM

__all__ = ["AFM", "CCPM","DCN", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", "PNN",
__all__ = ["AFM", "CCPM", "DCN", "DCNMix", "MLR", "DeepFM", "NFM", "DIN", "DIEN", "FNN", "PNN",
"WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM"]