Skip to content

Commit

Permalink
Refactor input module of linear part
Browse files Browse the repository at this point in the history
  • Loading branch information
浅梦 authored Jul 21, 2019
1 parent 1404f0d commit 8182ea3
Show file tree
Hide file tree
Showing 26 changed files with 209 additions and 129 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Steps to reproduce the behavior:
**Operating environment(运行环境):**
- python version [e.g. 3.4, 3.6]
- tensorflow version [e.g. 1.4.0, 1.12.0]
- deepctr version [e.g. 0.2.3,]
- deepctr version [e.g. 0.5.2,]

**Additional context**
Add any other context about the problem here.
3 changes: 2 additions & 1 deletion .github/ISSUE_TEMPLATE/question.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ labels: question
assignees: ''

---
Please refer to the [FAQ](https://deepctr-doc.readthedocs.io/en/latest/FAQ.html) in doc and search for the [related issues](https://github.com/shenweichen/DeepCTR/issues) before you ask the question.

**Describe the question(问题描述)**
A clear and concise description of what the question is.
Expand All @@ -16,4 +17,4 @@ Add any other context about the problem here.
**Operating environment(运行环境):**
- python version [e.g. 3.6]
- tensorflow version [e.g. 1.4.0,]
- deepctr version [e.g. 0.3.2,]
- deepctr version [e.g. 0.5.2,]
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ env:
#Not Support- TF_VERSION=1.7.1
#Not Support- TF_VERSION=1.8.0
#- TF_VERSION=1.8.0
- TF_VERSION=1.10.0 #- TF_VERSION=1.10.1
#- TF_VERSION=1.10.0 >50 mins limit #- TF_VERSION=1.10.1
# - TF_VERSION=1.11.0
#- TF_VERSION=1.5.1 #- TF_VERSION=1.5.0
- TF_VERSION=1.6.0
Expand Down
2 changes: 1 addition & 1 deletion deepctr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from . import models
from .utils import check_version

__version__ = '0.5.1'
__version__ = '0.5.2'
check_version(__version__)
47 changes: 24 additions & 23 deletions deepctr/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from itertools import chain

from tensorflow.python.keras.initializers import RandomNormal
from tensorflow.python.keras.layers import Concatenate, Dense, Embedding, Input, add,Flatten
from tensorflow.python.keras.layers import Embedding, Input, Flatten
from tensorflow.python.keras.regularizers import l2

from .layers.sequence import SequencePoolingLayer
from .layers.utils import Hash,concat_fun
from .layers.utils import Hash,concat_fun,Linear


class SparseFeat(namedtuple('SparseFeat', ['name', 'dimension', 'use_hash', 'dtype','embedding_name','embedding'])):
Expand Down Expand Up @@ -45,12 +45,14 @@ def __new__(cls, name, dimension, maxlen, combiner="mean", use_hash=False, dtype

def get_fixlen_feature_names(feature_columns):
features = build_input_features(feature_columns, include_varlen=False,include_fixlen=True)
return features.keys()
return list(features.keys())

def get_varlen_feature_names(feature_columns):
features = build_input_features(feature_columns, include_varlen=True,include_fixlen=False)
return features.keys()
return list(features.keys())

def get_inputs_list(inputs):
    """Flatten the values of several input mappings into a single list.

    :param inputs: iterable whose items are either ``None`` or objects
        exposing ``.values()`` (e.g. the ``OrderedDict`` objects created by
        ``build_input_features``).
    :return: list with the values of every non-``None`` item, in the order
        the items (and their values) are iterated.
    """
    present = (feature_map.values() for feature_map in inputs if feature_map is not None)
    return list(chain.from_iterable(present))

def build_input_features(feature_columns, include_varlen=True, mask_zero=True, prefix='',include_fixlen=True):
input_features = OrderedDict()
Expand All @@ -61,7 +63,7 @@ def build_input_features(feature_columns, include_varlen=True, mask_zero=True, p
shape=(1,), name=prefix+fc.name, dtype=fc.dtype)
elif isinstance(fc,DenseFeat):
input_features[fc.name] = Input(
shape=(1,), name=prefix + fc.name, dtype=fc.dtype)
shape=(fc.dimension,), name=prefix + fc.name, dtype=fc.dtype)
if include_varlen:
for fc in feature_columns:
if isinstance(fc,VarLenSparseFeat):
Expand Down Expand Up @@ -138,8 +140,7 @@ def get_embedding_vec_list(embedding_dict, input_dict, sparse_feature_columns, r
return embedding_vec_list


def get_inputs_list(inputs):
    """Flatten the values of several input mappings into a single list.

    :param inputs: iterable whose items are either ``None`` or objects
        exposing ``.values()``.
    :return: list with the values of every non-``None`` item, in iteration
        order.
    """
    present = (feature_map.values() for feature_map in inputs if feature_map is not None)
    return list(chain.from_iterable(present))


def create_embedding_matrix(feature_columns,l2_reg,init_std,seed,embedding_size, prefix="",seq_mask_zero=True):
sparse_feature_columns = list(
Expand All @@ -155,24 +156,24 @@ def get_linear_logit(features, feature_columns, units=1, l2_reg=0, init_std=0.00
linear_emb_list = [input_from_feature_columns(features,feature_columns,1,l2_reg,init_std,seed,prefix=prefix+str(i))[0] for i in range(units)]
_, dense_input_list = input_from_feature_columns(features,feature_columns,1,l2_reg,init_std,seed,prefix=prefix)

if len(linear_emb_list[0]) > 1:
linear_term = concat_fun([add(linear_emb) for linear_emb in linear_emb_list])
elif len(linear_emb_list[0]) == 1:
linear_term = concat_fun([linear_emb[0] for linear_emb in linear_emb_list])
else:
linear_term = None

if len(dense_input_list) > 0:
dense_input__ = dense_input_list[0] if len(
dense_input_list) == 1 else Concatenate()(dense_input_list)
linear_dense_logit = Dense(
units, activation=None, use_bias=False, kernel_regularizer=l2(l2_reg))(dense_input__)
if linear_term is not None:
linear_term = add([linear_dense_logit, linear_term])
linear_logit_list = []
for i in range(units):

if len(linear_emb_list[0])>0 and len(dense_input_list) >0:
sparse_input = concat_fun(linear_emb_list[i])
dense_input = concat_fun(dense_input_list)
linear_logit = Linear(l2_reg,mode=2)([sparse_input,dense_input])
elif len(linear_emb_list[0])>0:
sparse_input = concat_fun(linear_emb_list[i])
linear_logit = Linear(l2_reg,mode=0)(sparse_input)
elif len(dense_input_list) >0:
dense_input = concat_fun(dense_input_list)
linear_logit = Linear(l2_reg,mode=1)(dense_input)
else:
linear_term = linear_dense_logit
raise NotImplementedError
linear_logit_list.append(linear_logit)

return linear_term
return concat_fun(linear_logit_list)

def embedding_lookup(sparse_embedding_dict,sparse_input_dict,sparse_feature_columns,return_feat_list=(), mask_feat_list=()):
embedding_vec_list = []
Expand Down
3 changes: 2 additions & 1 deletion deepctr/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
KMaxPooling, SequencePoolingLayer,
Transformer, DynamicGRU)
from .utils import NoMask, Hash
from .utils import NoMask, Hash,Linear

custom_objects = {'tf': tf,
'InnerProductLayer': InnerProductLayer,
Expand All @@ -34,6 +34,7 @@
'KMaxPooling': KMaxPooling,
'FGCNNLayer': FGCNNLayer,
'Hash': Hash,
'Linear':Linear,
'DynamicGRU': DynamicGRU,
'SENETLayer':SENETLayer,
'BilinearInteraction':BilinearInteraction,
Expand Down
51 changes: 49 additions & 2 deletions deepctr/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def compute_mask(self, inputs, mask):
return None



class Hash(tf.keras.layers.Layer):
"""
hash the input to [0,num_buckets)
Expand All @@ -43,7 +42,7 @@ def call(self, x, mask=None, **kwargs):
if x.dtype != tf.string:
x = tf.as_string(x, )
hash_x = tf.string_to_hash_bucket_fast(x, self.num_buckets if not self.mask_zero else self.num_buckets - 1,
name=None)#weak hash
name=None) # weak hash
if self.mask_zero:
mask_1 = tf.cast(tf.not_equal(x, "0"), 'int64')
mask_2 = tf.cast(tf.not_equal(x, "0.0"), 'int64')
Expand All @@ -60,6 +59,54 @@ def get_config(self, ):
return dict(list(base_config.items()) + list(config.items()))


class Linear(tf.keras.layers.Layer):
    """
    Linear logit layer: a sum over the last axis and/or a one-unit dense
    projection, followed by a trainable scalar bias.

      Arguments
        - **l2_reg**: L2 strength given to the kernel regularizer of the
          internal ``Dense`` projection.
        - **mode**: selects what ``call`` expects as ``inputs``.
          ``0``: one tensor, summed over its last axis (axis kept).
          ``1``: one tensor, projected to a single unit by ``Dense``.
          anything else: a ``[sparse_input, dense_input]`` pair; the sparse
          part is summed over its last axis (axis dropped) and added to the
          ``Dense`` projection of the dense part.
    """

    def __init__(self, l2_reg=0.0, mode=0, **kwargs):
        # Plain configuration values; both are echoed back by get_config().
        self.l2_reg = l2_reg
        self.mode = mode
        super(Linear, self).__init__(**kwargs)

    def build(self, input_shape):
        # Scalar bias, zero-initialised, added to the logit in every mode.
        self.bias = self.add_weight(name='linear_bias',
                                    shape=(1,),
                                    initializer=tf.keras.initializers.Zeros(),
                                    trainable=True)

        # Bias-free, activation-free single-unit projection; only invoked by
        # the modes that receive a dense input.
        kernel_regularizer = tf.keras.regularizers.l2(self.l2_reg)
        self.dense = tf.keras.layers.Dense(units=1, activation=None, use_bias=False,
                                           kernel_regularizer=kernel_regularizer)

        super(Linear, self).build(input_shape)

    def call(self, inputs, **kwargs):
        if self.mode == 0:
            # Sparse-only: sum over the last axis, keeping that axis.
            logit = tf.reduce_sum(inputs, axis=-1, keep_dims=True)
        elif self.mode == 1:
            # Dense-only: one-unit projection.
            logit = self.dense(inputs)
        else:
            sparse_part, dense_part = inputs
            # NOTE(review): the reduced axis is dropped here but kept in
            # mode 0; presumably the sparse input is rank-3 so this lines up
            # with the Dense output -- confirm against the callers.
            sparse_logit = tf.reduce_sum(sparse_part, axis=-1, keep_dims=False)
            logit = sparse_logit + self.dense(dense_part)

        return logit + self.bias

    def compute_output_shape(self, input_shape):
        # Reported shape is fixed and does not depend on input_shape or mode.
        return None, 1

    def get_config(self, ):
        # Start from the base layer config, then add this layer's own
        # constructor arguments (own keys win on a clash).
        merged = dict(super(Linear, self).get_config())
        merged.update({'mode': self.mode, 'l2_reg': self.l2_reg})
        return merged


def concat_fun(inputs, axis=-1):
if len(inputs) == 1:
return inputs[0]
Expand Down
2 changes: 0 additions & 2 deletions docs/source/Examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,6 @@ There are 2 additional steps to use DeepCTR with sequence feature input.
- embedding : default `True`. If `False`, the feature will not be embedded into a dense vector.


Now multi-value input is available for `AFM,AutoInt,DCN,DeepFM,FNN,NFM,PNN,xDeepFM,CCPM,FGCNN`. For `DIN,DIEN,DSIN`, please read the examples in [run_din.py](https://github.com/shenweichen/DeepCTR/blob/master/examples/run_din.py), [run_dien.py](https://github.com/shenweichen/DeepCTR/blob/master/examples/run_dien.py) and [run_dsin.py](https://github.com/shenweichen/DeepCTR/blob/master/examples/run_dsin.py).

This example shows how to use ``DeepFM`` with sequence(multi-value) feature. You can get the demo data
[movielens_sample.txt](https://github.com/shenweichen/DeepCTR/tree/master/examples/movielens_sample.txt) and run the following codes.

Expand Down
66 changes: 50 additions & 16 deletions docs/source/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from tensorflow.python.keras.models import save_model,load_model
model = DeepFM()
save_model(model, 'DeepFM.h5')# save_model, same as before

from deepctr.utils import custom_objects
from deepctr.layers import custom_objects
model = load_model('DeepFM.h5',custom_objects)# load_model,just add a parameter
```
## 2. Set learning rate and use earlystopping
Expand All @@ -30,7 +30,7 @@ import deepctr
from tensorflow.python.keras.optimizers import Adam,Adagrad
from tensorflow.python.keras.callbacks import EarlyStopping

model = deepctr.models.DeepFM({"sparse": sparse_feature_dict, "dense": dense_feature_list})
model = deepctr.models.DeepFM(linear_feature_columns,dnn_feature_columns)
model.compile(Adagrad('0.0808'),'binary_crossentropy',metrics=['binary_crossentropy'])

es = EarlyStopping(monitor='val_binary_crossentropy')
Expand All @@ -47,36 +47,70 @@ Then,use the following code,the `attentional_weights[:,i,0]` is the `feature_int
```python
import itertools
import deepctr
from deepctr.models import AFM
from deepctr.inputs import get_fixlen_feature_names,get_varlen_feature_names
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Lambda

feature_dim_dict = {"sparse": sparse_feature_dict, "dense": dense_feature_list}
model = deepctr.models.AFM(feature_dim_dict)
model = AFM(linear_feature_columns,dnn_feature_columns)
model.fit(model_input,target)

afmlayer = model.layers[-3]
afm_weight_model = Model(model.input,outputs=Lambda(lambda x:afmlayer.normalized_att_score)(model.input))
attentional_weights = afm_weight_model.predict(model_input,batch_size=4096)
feature_interactions = list(itertools.combinations(list(feature_dim_dict['sparse'].keys()) + feature_dim_dict['dense'] ,2))

fixlen_names = get_fixlen_feature_names( dnn_feature_columns)
varlen_names = get_varlen_feature_names(dnn_feature_columns)
feature_interactions = list(itertools.combinations(fixlen_names+varlen_names ,2))
```
## 4. How to extract the embedding vectors in deepfm?
```python
feature_columns = [SparseFeat('user_id',120,),SparseFeat('item_id',60,),SparseFeat('cate_id',60,)]

def get_embedding_weights(dnn_feature_columns, model):
    """Collect the embedding weight arrays of a model, keyed by embedding name.

    Columns without an ``embedding_name`` attribute are skipped. For the
    others the key is ``embedding_name`` when it is not ``None`` and the
    column ``name`` otherwise; the weights are read from the layer called
    ``"sparse_emb_" + key`` (first entry of ``get_weights()``).
    """
    weights_by_name = {}
    for column in dnn_feature_columns:
        if not hasattr(column, 'embedding_name'):
            continue
        key = column.name if column.embedding_name is None else column.embedding_name
        weights_by_name[key] = model.get_layer("sparse_emb_" + key).get_weights()[0]
    return weights_by_name

embedding_dict = get_embedding_weights(feature_columns,model)

user_id_emb = embedding_dict['user_id']
item_id_emb = embedding_dict['item_id']
```

## 4. Do the models support multi-value input?
---------------------------------------------------
Now multi-value input is available for `AFM,AutoInt,DCN,DeepFM,FNN,NFM,PNN,xDeepFM`; you can read the example [here](./Examples.html#multi-value-input-movielens).
## 5. How to add a long dense feature vector as an input to the model?
```python
from deepctr.models import DeepFM
from deepctr.inputs import DenseFeat,SparseFeat,get_fixlen_feature_names
import numpy as np

For `DIN` please read the code example in [run_din.py](https://github.com/shenweichen/DeepCTR/blob/master/examples/run_din.py).
feature_columns = [SparseFeat('user_id',120,),SparseFeat('item_id',60,),DenseFeat("pic_vec",5)]
fixlen_feature_names = get_fixlen_feature_names(feature_columns)

For `DIEN` please read the code example in [run_dien.py](https://github.com/shenweichen/DeepCTR/blob/master/examples/run_dien.py).
user_id = np.array([[1],[0],[1]])
item_id = np.array([[30],[20],[10]])
pic_vec = np.array([[0.1,0.5,0.4,0.3,0.2],[0.1,0.5,0.4,0.3,0.2],[0.1,0.5,0.4,0.3,0.2]])
label = np.array([1,0,1])

You can also use layers in [sequence](./deepctr.layers.sequence.html) to build your own models!
input_dict = {'user_id':user_id,'item_id':item_id,'pic_vec':pic_vec}
model_input = [input_dict[name] for name in fixlen_feature_names]

## 5. How to add a long feature vector as a feature to the model?
Please refer to [this](https://github.com/shenweichen/DeepCTR/issues/42).
model = DeepFM(feature_columns,feature_columns[:-1])
model.compile('adagrad','binary_crossentropy')
model.fit(model_input,label)
```

## 6. How to run the demo with GPU ?
Please refer to [this](https://github.com/shenweichen/DeepCTR/issues/40).
just install deepctr with
```bash
$ pip install deepctr[gpu]
```

## 7. Could not find a version that satisfies the requirement deepctr (from versions)
please install with `pip3 install` instead of `pip install`
1 change: 1 addition & 0 deletions docs/source/History.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# History
- 07/21/2019 : [v0.5.2](https://github.com/shenweichen/DeepCTR/releases/tag/v0.5.2) released.Refactor `Linear` Layer.
- 07/10/2019 : [v0.5.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.5.1) released.Add [FiBiNET](./Features.html#fibinet-feature-importance-and-bilinear-feature-interaction-network).
- 06/30/2019 : [v0.5.0](https://github.com/shenweichen/DeepCTR/releases/tag/v0.5.0) released.Refactor inputs module.
- 05/19/2019 : [v0.4.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.4.1) released.Add [DSIN](./Features.html#dsin-deep-session-interest-network).
Expand Down
6 changes: 4 additions & 2 deletions docs/source/Models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ DeepCTR Models API
DCN<deepctr.models.dcn>
DIN<deepctr.models.din>
DIEN<deepctr.models.dien>
DSIN<deepctr.models.dsin>
xDeepFM<deepctr.models.xdeepfm>
AutoInt<deepctr.models.autoint>
ONN<deepctr.models.onn>
NFFM<deepctr.models.nffm>
FGCNN<deepctr.models.fgcnn>
DSIN<deepctr.models.dsin>
FiBiNET<deepctr.models.fibinet>

2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '0.5.1'
release = '0.5.2'


# -- General configuration ---------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR

News
-----
07/21/2019 : Refactor Linear Layer. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.5.2>`_

07/10/2019 : Add `FiBiNET <./Features.html#fibinet-feature-importance-and-bilinear-feature-interaction-network>`_ . `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.5.1>`_

06/30/2019 : Refactor inputs module. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.5.0>`_

05/19/2019 : Add `DSIN <./Features.html#dsin-deep-session-interest-network>`_ . `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.4.1>`_

.. toctree::
:maxdepth: 2
:caption: Home:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setuptools.setup(
name="deepctr",
version="0.5.1",
version="0.5.2",
author="Weichen Shen",
author_email="[email protected]",
description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with tensorflow.",
Expand Down
Loading

0 comments on commit 8182ea3

Please sign in to comment.