|
3 | 3 |
|
4 | 4 | Authors:
|
5 | 5 |
|
6 |
| - Harshit Pande |
| 6 | + Harshit Pande, |
| 7 | + |
7 | 8 |
|
8 | 9 | """
|
9 | 10 |
|
|
26 | 27 |
|
27 | 28 | from .activation import activation_layer
|
28 | 29 | from .utils import concat_func, reduce_sum, softmax, reduce_mean
|
| 30 | +from .core import DNN |
29 | 31 |
|
30 | 32 |
|
31 | 33 | class AFMLayer(Layer):
|
@@ -1489,3 +1491,69 @@ def get_config(self):
|
1489 | 1491 | 'regularizer': self.regularizer,
|
1490 | 1492 | })
|
1491 | 1493 | return config
|
| 1494 | + |
| 1495 | + |
class BridgeModule(Layer):
    """Bridge Module used in EDCN.

      Combines the hidden states of two parallel sub-networks (e.g. the cross
      network and the deep network) into a single interaction tensor.

      Input shape
        - A list of two 2D tensors with shape: ``(batch_size, units)``.

      Output shape
        - 2D tensor with shape: ``(batch_size, units)``.

      Arguments
        - **bridge_type**: The type of bridge interaction, one of
          'pointwise_addition', 'hadamard_product', 'concatenation',
          'attention_pooling'.

        - **activation**: Activation function to use.

      References
        - [Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf)

    """

    def __init__(self, bridge_type='hadamard_product', activation='relu', **kwargs):
        # Fail fast on an unknown bridge_type: previously an unsupported value
        # was accepted here and call() silently returned None.
        if bridge_type not in ('pointwise_addition', 'hadamard_product',
                               'concatenation', 'attention_pooling'):
            raise ValueError(
                "bridge_type must be one of 'pointwise_addition', "
                "'hadamard_product', 'concatenation', 'attention_pooling', "
                "got %r" % (bridge_type,))
        self.bridge_type = bridge_type
        self.activation = activation

        super(BridgeModule, self).__init__(**kwargs)

    def build(self, input_shape):
        # call() unpacks exactly two tensors (x, h = inputs), so require
        # exactly two inputs here instead of "at least two".
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError(
                'A `BridgeModule` layer should be called '
                'on a list of 2 inputs')

        # Both inputs are assumed to share the same last dimension — TODO
        # confirm against callers; only input_shape[0] is inspected here.
        self.dnn_dim = int(input_shape[0][-1])
        if self.bridge_type == "concatenation":
            # Projects the concatenated (2 * dnn_dim) tensor back to dnn_dim.
            self.dense = Dense(self.dnn_dim, self.activation)
        elif self.bridge_type == "attention_pooling":
            # One softmax-gated DNN per branch produces the attention weights.
            self.dense_x = DNN([self.dnn_dim, self.dnn_dim], self.activation, output_activation='softmax')
            self.dense_h = DNN([self.dnn_dim, self.dnn_dim], self.activation, output_activation='softmax')

        super(BridgeModule, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs, **kwargs):
        x, h = inputs
        if self.bridge_type == "pointwise_addition":
            return x + h
        elif self.bridge_type == "hadamard_product":
            return x * h
        elif self.bridge_type == "concatenation":
            return self.dense(tf.concat([x, h], axis=-1))
        elif self.bridge_type == "attention_pooling":
            # Element-wise attention: each branch is weighted by its own gate.
            a_x = self.dense_x(x)
            a_h = self.dense_h(h)
            return a_x * x + a_h * h

    def compute_output_shape(self, input_shape):
        return (None, self.dnn_dim)

    def get_config(self):
        base_config = super(BridgeModule, self).get_config().copy()
        config = {
            'bridge_type': self.bridge_type,
            'activation': self.activation
        }
        config.update(base_config)
        return config
0 commit comments