Add new method D3 #90

Open · wants to merge 1 commit into `master`
122 changes: 122 additions & 0 deletions research/huawei-noah/D3/model/D3.py
@@ -0,0 +1,122 @@
import mindspore
import mindspore.nn as nn
from mindspore import ops
from mindspore import dtype as mstype
from mindspore.common.initializer import initializer, Normal
import numpy as np
from utils import MLP, DANet
from tqdm import tqdm

class DSFS(nn.Cell):
    def __init__(self, feature_num, embed_size, hid_dim, output_dims):
        super(DSFS, self).__init__(auto_prefix=True)
        self.output_dims = output_dims
        # one small MLP per feature field
        self.linear_layers = nn.CellList([MLP(embed_size, False, [hid_dim]) for _ in range(feature_num)])
        # shared (domain-agnostic) weight and bias; the original did not specify an init scheme, so Normal is assumed
        self.shared_weight = mindspore.Parameter(initializer(Normal(), (output_dims[0], output_dims[1]), mstype.float32))
        self.shared_bias = mindspore.Parameter(ops.zeros(output_dims[1], mstype.float32))
        self.trans = MLP(feature_num * hid_dim, False, [feature_num * hid_dim], 0)
        # generate a per-sample weight matrix and bias vector
        self.trans_weight = MLP(feature_num * hid_dim, False, [output_dims[0] * output_dims[1]], 0)
        self.trans_bias = MLP(feature_num * hid_dim, False, [output_dims[1]], 0)

    def construct(self, x):
        b, f, e = x.shape
        trans_features = []
        for i in range(f):
            # slicing yields a new tensor, so no explicit copy is needed
            feature = self.linear_layers[i](x[:, i, :])
            trans_features.append(feature)
        trans_features = ops.stack(trans_features, axis=1)
        residual_output = self.trans(trans_features.reshape(b, -1)) + trans_features.reshape(b, -1)
        weight = self.trans_weight(residual_output).reshape(b, self.output_dims[0], self.output_dims[1])
        bias = self.trans_bias(residual_output).reshape(b, self.output_dims[1])
        return weight, bias

class D3(nn.Cell):
    def __init__(self, feature_dims, dense_cols, sparse_cols, embed_dim=8, selected_ID_features=None, hid_dim1=64, hid_dim2=32, mlp_dims=None):
        super(D3, self).__init__(auto_prefix=True)
        if selected_ID_features is None:
            selected_ID_features = list(range(23))
        if mlp_dims is None:
            mlp_dims = [32, 32]
        # a single embedding table over all sparse fields; offsets map each field into its own id range
        self.embedding = nn.Embedding(sum(feature_dims), embedding_dim=embed_dim)
        self.embed_dim = embed_dim
        self.dense_cols = dense_cols
        self.sparse_cols = sparse_cols
        self.offsets = np.array((0, *np.cumsum(feature_dims)[:-1]))
        self.offsets = mindspore.Tensor(self.offsets, dtype=mstype.int32)
        self.selected_ID_features = selected_ID_features
        self.total_fea_num = len(dense_cols) + len(sparse_cols)
        self.DSFS_tr = DSFS(len(self.selected_ID_features), embed_dim, hid_dim1, [self.total_fea_num * self.embed_dim, hid_dim2])
        self.output_mlp = MLP(hid_dim2, True, mlp_dims, 0)
        self.add_attention = DANet(self.embed_dim)
        # two-way gate balancing the per-sample and shared DSFS weights
        self.gate = MLP(self.total_fea_num * self.embed_dim, False, [64, 16, 2], 0)
        self.desired_std = 0.5
        self.desired_mean = 0.5
        self.alpha_1 = 1.0
        self.alpha_2 = 1.0
        self.lambda_1 = 0.5
        self.lambda_2 = 0.5

    def construct(self, sparse):
        b, f = sparse.shape
        # column 18 of the raw input holds the slot (domain) id; it is only consumed in test_
        sparse = sparse + self.offsets.unsqueeze(0)
        sparse = self.embedding(sparse)
        DSFS_input, attn_weights = self.add_attention(sparse)
        attn_weights = ops.mean(attn_weights, axis=1)  # b, f
        # entropy of the attention distribution serves as a per-sample difficulty signal
        log_attn_weights = ops.log(attn_weights + 1e-9)
        entropy = -ops.sum(attn_weights * log_attn_weights, dim=-1).reshape(b)
        # the loss weight is a detached statistic, so the numpy round-trip is intentional
        np_lis = entropy.asnumpy()
        arr = (np_lis - np_lis.mean()) / (np_lis.std() + 1e-9)
        arr = arr * self.desired_std + self.desired_mean
        arr = np.clip(arr, 0.1, 1)
        loss_weight = mindspore.Tensor(arr, dtype=mstype.float32).reshape(b, 1)
        gate = ops.softmax(self.gate(sparse.reshape(b, -1)), axis=1)  # b, 2
        tr_weight, tr_bias = self.DSFS_tr(DSFS_input)  # b, f*e, hid[1] ; b, hid[1]
        # mix the per-sample weight with the shared weight through the softmax gate
        tr_weight = ops.multiply(tr_weight, gate[:, 0].reshape(b, 1, 1)) + ops.multiply(self.DSFS_tr.shared_weight.unsqueeze(0), gate[:, 1].reshape(b, 1, 1))
        tr_output = ops.matmul(sparse.reshape(b, 1, -1), tr_weight).reshape(b, -1) + tr_bias  # b,1,f*e @ b,f*e,hid[1] => b,hid[1]
        output = ops.sigmoid(self.output_mlp(tr_output))  # b, 1
        return output, self.alpha_1 * loss_weight + self.lambda_1, self.alpha_2 * gate[:, 0].reshape(b, 1) + self.lambda_2

    def train_(self, args, train_dataloader, valid_dataloader, optimizer, criterion, epoch=0):
        def forward_fn(data, labels):
            output, loss_weight, gate_weight = self(data)
            loss = criterion(output, labels.float())  # BCELoss with reduction='none' keeps per-sample losses
            # batch_idx is captured from the enclosing training loop
            if args.start_weight_loss_step and batch_idx > args.start_weight_loss_step:
                loss = ops.multiply(loss, loss_weight)
            loss = ops.multiply(loss, gate_weight)
            loss = ops.mean(loss)
            return loss, output

        grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
        self.set_train(True)
        total_loss = 0.0
        total_length = 0
        tk0 = tqdm(train_dataloader, desc=f"train epoch {epoch}", leave=True, colour='blue')
        for batch_idx, data in enumerate(tk0):
            sparse, target = data['features'], data['labels']
            (loss, _), grads = grad_fn(sparse, target)
            optimizer(grads)
            total_loss += loss.asnumpy() * len(target)
            total_length += len(target)
            if (batch_idx + 1) % 100 == 0:  # log every 100 batches
                tk0.set_postfix(train_loss=total_loss / total_length)
                total_loss = 0.0  # reset both running sums after logging
                total_length = 0

    def test_(self, args, test_dataloader, auc, log_loss):
        self.set_train(False)
        y_true, y_pred, slot_id = [], [], []
        tk0 = tqdm(test_dataloader, desc="test", leave=True, colour='blue')
        for batch_idx, data in enumerate(tk0):
            sparse, target = data['features'], data['labels']
            scene_id = sparse[:, 18]  # slot (domain) id column
            output, _, _ = self(sparse)
            y_true.extend(target.tolist())
            y_pred.extend(output.tolist())
            slot_id.extend(scene_id.tolist())
        y_true, y_pred, slot_id = np.array(y_true).flatten(), np.array(y_pred).flatten(), np.array(slot_id)
        # Ali-CCP has three slots (domains), indexed 1..3
        for i in range(1, 4):
            tqdm.write(f"Slot {i} AUC: {auc(y_true[slot_id == i], y_pred[slot_id == i])}")
            tqdm.write(f"Slot {i} Log Loss: {log_loss(y_true[slot_id == i], y_pred[slot_id == i])}")
39 changes: 39 additions & 0 deletions research/huawei-noah/D3/readme.md
@@ -0,0 +1,39 @@
# D3

This repository contains the source code for our paper: "D3: A Methodological Exploration of Domain Division, Modeling, and Balance in Multi-Domain Recommendations".

## Environment

```
mindspore
scikit-learn
pandas
tqdm
```
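
These can typically be installed with pip; the package names below are assumed to match their PyPI distributions:

```bash
pip install mindspore scikit-learn pandas tqdm
```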

## BibTex

```
@inproceedings{jia2024d3,
title={D3: A Methodological Exploration of Domain Division, Modeling, and Balance in Multi-Domain Recommendations},
author={Jia, Pengyue and Wang, Yichao and Lin, Shanru and Li, Xiaopeng and Zhao, Xiangyu and Guo, Huifeng and Tang, Ruiming},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={38},
number={8},
pages={8553--8561},
year={2024}
}
```

## Usage

This code targets the Ali-CCP dataset. For other datasets, modify the corresponding slot_id index in `run.py` and `D3.py`, as sketched below.
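
For reference, the slot id in this PR is read as column 18 of the sparse input (see `test_` in `D3.py`). A minimal sketch of the change for a dataset whose domain id lives in a different column, where `SLOT_ID_COL` is a hypothetical constant you would define:

```python
SLOT_ID_COL = 18  # Ali-CCP default in this PR; set to your dataset's slot-id column
scene_id = sparse[:, SLOT_ID_COL]
```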

Run the `run.py` file by navigating to the project directory and executing the following command in the terminal:

```bash
python run.py
```

Ensure that all dependencies are installed and that the dataset is placed under `./data/` before running the project.
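
Based on the loader in `run.py` (`get_ali_ccp_data_dict_pd`), the following files are expected:

```
./data/ali_ccp_train.csv
./data/ali_ccp_val.csv
./data/ali_ccp_test.csv
```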

75 changes: 75 additions & 0 deletions research/huawei-noah/D3/run.py
@@ -0,0 +1,75 @@
import numpy as np
import pandas as pd
import mindspore
import mindspore.dataset as ds
from mindspore import dtype as mstype
from mindspore import context
from sklearn.metrics import roc_auc_score, log_loss
from model.D3 import D3

def get_ali_ccp_data_dict_pd(args, data_path='./data/'):
data_type = {'click':np.int8, 'purchase': np.int8, '101':np.int32, '121':np.uint8, '122':np.uint8, '124':np.uint8, '125':np.uint8, '126':np.uint8, '127':np.uint8, '128':np.uint8, '129':np.uint8, '205':np.int32, '206':np.int16, '207':np.int32, '210':np.int32, '216':np.int32, '508':np.int16, '509':np.int32, '702':np.int32, '853':np.int32, '301':np.int8, '109_14':np.int16, '110_14':np.int32, '127_14':np.int32, '150_14':np.int32, 'D109_14': np.float16, 'D110_14': np.float16, 'D127_14': np.float16, 'D150_14': np.float16, 'D508': np.float16, 'D509': np.float16, 'D702': np.float16, 'D853': np.float16}

df_train = pd.read_csv(data_path + '/ali_ccp_train.csv', dtype=data_type)
df_val = pd.read_csv(data_path + '/ali_ccp_val.csv', dtype=data_type)
df_test = pd.read_csv(data_path + '/ali_ccp_test.csv', dtype=data_type)

print("train : val : test = %d %d %d" % (df_train.shape[0], df_val.shape[0], df_test.shape[0]))
lengths = [df_train.shape[0], df_val.shape[0], df_test.shape[0]]
train_idx = lengths[0]
data = pd.concat([df_train, df_val, df_test], axis=0)
print(data.head(5))
del df_train, df_val, df_test
col_names = data.columns
dense_cols = ['D109_14', 'D110_14', 'D127_14', 'D150_14', 'D508', 'D509', 'D702', 'D853']
sparse_cols = [col for col in col_names if col not in dense_cols and col not in ['click', 'purchase']]
print('dense cols:', dense_cols, 'sparse cols:', sparse_cols)
print("sparse cols:%d dense cols:%d" % (len(sparse_cols), len(dense_cols)))
y = data.loc[:,"click"]
sparse_x = data[sparse_cols]
    # vocabulary size of each sparse feature, precomputed for Ali-CCP
    sparse_x_unique = [238635, 98, 14, 3, 8, 4, 4, 3, 5, 467298, 6929, 263942, 80232, 106399, 5888, 104830, 51878, 37148, 4, 5853, 105622, 53843, 31858]
x_sparse_train, x_sparse_test = sparse_x[:train_idx], sparse_x[train_idx:]
y_train, y_test = y[:train_idx], y[train_idx:]
    # keep arrays in numpy; the dataset pipeline converts them to tensors per batch
    x_sparse_train, y_train = x_sparse_train.values.astype(np.int32), y_train.values.astype(np.float32).reshape(-1, 1)
    x_sparse_test, y_test = x_sparse_test.values.astype(np.int32), y_test.values.astype(np.float32).reshape(-1, 1)
    sampler = ds.SequentialSampler()
    train_dataset = ds.NumpySlicesDataset({'features': x_sparse_train, 'labels': y_train}, sampler=sampler)
    test_dataset = ds.NumpySlicesDataset({'features': x_sparse_test, 'labels': y_test}, sampler=sampler)
    train_dataset = train_dataset.batch(batch_size=args.batch_size)
    # the training loop indexes batches by key ('features'/'labels'), so dict iterators are required
    train_dataloader = train_dataset.create_dict_iterator()
    test_dataset = test_dataset.batch(batch_size=args.batch_size)
    test_dataloader = test_dataset.create_dict_iterator()
return train_dataloader, None, test_dataloader, sparse_x_unique, dense_cols, sparse_cols, lengths

def main(args):
context.set_context(device_target=args.device)
train_dataloader, valid_dataloader, test_dataloader, feature_dims, dense_cols, sparse_cols, lengths = get_ali_ccp_data_dict_pd(args, args.data_path)
model = D3(feature_dims, [], sparse_cols, args.embed_size, selected_ID_features = list(range(len(sparse_cols))), hid_dim1=args.d3_hid_dim1, hid_dim2=args.d3_hid_dim2, mlp_dims=args.d3_mlp_dims)
optimizer = mindspore.nn.AdamWeightDecay(model.parameters(), learning_rate=args.lr, weight_decay=1e-5)
criterion = mindspore.nn.BCELoss(reduction='none')
for i in range(args.epochs):
model.train_(args, train_dataloader, valid_dataloader, optimizer, criterion, i)
model.test_(args, test_dataloader, roc_auc_score, log_loss)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='./data/')
    parser.add_argument('--batch_size', type=int, default=2048)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--embed_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=1)
    # MindSpore device targets are 'Ascend', 'GPU', or 'CPU'
    parser.add_argument('--device', default='GPU')
    parser.add_argument('--d3_hid_dim1', type=int, default=64)
    parser.add_argument('--d3_hid_dim2', type=int, default=128)
    # type=list would split the argument string into characters; use nargs instead
    parser.add_argument('--d3_mlp_dims', type=int, nargs='+', default=[16, 16])
    parser.add_argument('--selected_ID_features', type=int, nargs='+', default=list(range(23)))
    parser.add_argument('--start_weight_loss_step', type=int, default=15000)
    args = parser.parse_args()

main(args)
56 changes: 56 additions & 0 deletions research/huawei-noah/D3/utils.py
@@ -0,0 +1,56 @@
import mindspore
import mindspore.nn as nn

class MLP(nn.Cell):
    """Stack of Dense -> BatchNorm -> LeakyReLU -> Dropout blocks, with an optional scalar output layer."""
    def __init__(self, input_dim, output_layer=True, dims=None, dropout=0):
        super().__init__()
        if dims is None:
            dims = []
        layers = []
        for i_dim in dims:
            layers.append(nn.Dense(input_dim, i_dim))
            layers.append(nn.BatchNorm1d(i_dim))
            layers.append(nn.LeakyReLU())
            layers.append(nn.Dropout(p=dropout))
            input_dim = i_dim
        if output_layer:
            layers.append(nn.Dense(input_dim, 1))
        self.mlp = nn.SequentialCell(*layers)

    def construct(self, x):
        return self.mlp(x)

class channel_attn(nn.Cell):
    """Single-head self-attention over the feature (field) axis."""
    def __init__(self, emb) -> None:
        super().__init__()
        self.q_linear = nn.Dense(emb, emb)
        self.k_linear = nn.Dense(emb, emb)
        self.v_linear = nn.Dense(emb, emb)

    def construct(self, x):
        query = self.q_linear(x)
        key = self.k_linear(x)
        value = self.v_linear(x)
        key = key.permute(0, 2, 1)
        attn_weight = mindspore.ops.matmul(query, key)  # b, f, f
        attn_weight = mindspore.ops.softmax(attn_weight, axis=-1)
        out = mindspore.ops.matmul(attn_weight, value) + value  # residual connection; b, f, e
        return out, attn_weight

class DANet(nn.Cell):
    def __init__(self, emb):
        super().__init__()
        self.channel_attn = channel_attn(emb)

    def construct(self, x):
        out, channel_attn_weight = self.channel_attn(x)
        return out, channel_attn_weight