Skip to content

Commit 1a04249

Browse files
committed
init
1 parent 654ccea commit 1a04249

File tree

12 files changed

+1714
-0
lines changed

12 files changed

+1714
-0
lines changed

.gitignore

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Byte-compiled / optimized / DLL files
2+
*.py[cod]
3+
*$py.class
4+
5+
# C extensions
6+
*.so
7+
8+
# Distribution / packaging
9+
.Python
10+
build/
11+
develop-eggs/
12+
dist/
13+
downloads/
14+
eggs/
15+
.eggs/
16+
lib/
17+
lib64/
18+
parts/
19+
sdist/
20+
var/
21+
wheels/
22+
*.egg-info/
23+
.installed.cfg
24+
*.egg
25+
MANIFEST
26+
27+
# PyInstaller
28+
# Usually these files are written by a python script from a template
29+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
30+
*.manifest
31+
*.spec
32+
33+
# Installer logs
34+
pip-log.txt
35+
pip-delete-this-directory.txt
36+
37+
# Unit test / coverage reports
38+
htmlcov/
39+
.tox/
40+
.coverage
41+
.coverage.*
42+
.cache
43+
nosetests.xml
44+
coverage.xml
45+
*.cover
46+
.hypothesis/
47+
48+
# Translations
49+
*.mo
50+
*.pot
51+
52+
# Django stuff:
53+
*.log
54+
.static_storage/
55+
.media/
56+
local_settings.py
57+
58+
# Flask stuff:
59+
instance/
60+
.webassets-cache
61+
62+
# Scrapy stuff:
63+
.scrapy
64+
65+
# Sphinx documentation
66+
docs/_build/
67+
68+
# PyBuilder
69+
target/
70+
71+
# Jupyter Notebook
72+
.ipynb_checkpoints
73+
74+
# pyenv
75+
.python-version
76+
77+
# celery beat schedule file
78+
celerybeat-schedule
79+
80+
# SageMath parsed files
81+
*.sage.py
82+
83+
# Environments
84+
.env
85+
.venv
86+
env/
87+
venv/
88+
ENV/
89+
env.bak/
90+
venv.bak/
91+
92+
# Spyder project settings
93+
.spyderproject
94+
.spyproject
95+
96+
# Rope project settings
97+
.ropeproject
98+
99+
# mkdocs documentation
100+
/site
101+
102+
# mypy
103+
.mypy_cache/
104+
105+
#
106+
.idea
107+
__pycache__
108+
analysis.ipynb
109+
dp.pkl
110+
tmp.py
111+
weights/*

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2018 Chenglong Chen
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

__init__.py

Whitespace-only changes.

config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
3+
# Location of the raw data relative to the source tree.
DATA_DIR = "../data"

# Labelled training corpus: tab-separated rows of (id, left, right, label).
TRAIN_FILE = "/".join([DATA_DIR, "atec_nlp_sim_train_all.csv"])

main.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
2+
import config
3+
import sys
4+
import numpy as np
5+
import pandas as pd
6+
import pickle as pkl
7+
8+
from keras.preprocessing.sequence import pad_sequences
9+
10+
import utils
11+
from preprocessor import DataProcessor
12+
from model import SemanticMatchingModel
13+
14+
15+
def get_model_data(dataset, params):
    """Convert a preprocessed DataFrame into the feature dict fed to the model.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must carry columns ``id``, ``label`` and the four token-id sequence
        columns ``seq_word_left``, ``seq_word_right``, ``seq_char_left``,
        ``seq_char_right`` (produced by DataProcessor — assumed lists of
        int ids; confirm against preprocessor.py).
    params : dict
        Needs ``max_sequence_length_word``/``_char`` and the
        ``pad_sequences_padding``/``_truncating`` modes.

    Returns
    -------
    dict
        Feature name -> numpy array; every sequence column is padded or
        truncated to the fixed per-level length, and a constant
        ``sequence_length_{word,char}`` vector records that length per row.
    """
    X = {
        "id": dataset["id"].values,
        "label": dataset["label"].values,
    }

    num_rows = dataset.shape[0]
    # Identical padding logic for both granularities (word/char) and both
    # sides of the sentence pair.  The original repeated this block four
    # times and redundantly assigned each sequence_length_* entry twice.
    for level in ("word", "char"):
        max_len = params["max_sequence_length_%s" % level]
        for side in ("left", "right"):
            col = "seq_%s_%s" % (level, side)
            X[col] = pad_sequences(dataset[col], maxlen=max_len,
                                   padding=params["pad_sequences_padding"],
                                   truncating=params["pad_sequences_truncating"])
        # Every row is padded to exactly max_len, so the length vector is constant.
        X["sequence_length_%s" % level] = max_len * np.ones(num_rows)

    return X
44+
45+
# Hyper-parameters consumed by SemanticMatchingModel and get_model_data;
# shared between training (train) and inference (submit).
params = {
    "offline_model_dir": "./weights/semantic_matching",  # checkpoint directory
    "batch_size": 32,
    "epoch": 5,
    "l2_lambda": 0.0001,  # L2 regularization weight

    # embedding layers
    "embedding_dropout": 0.2,
    "embedding_word_dim": 128,
    "embedding_char_dim": 128,
    "embedding_dim": 128,

    # vocabulary size caps handed to DataProcessor in train()
    "max_num_words": 10000,
    "max_num_chars": 10000,

    # decision threshold on the predicted match score
    # NOTE(review): train() hard-codes threshold=0.2 instead of this value —
    # confirm which one is intended.
    "threshold": 0.217277,

    # sequence padding/truncation applied in get_model_data
    "max_sequence_length_word": 20,
    "max_sequence_length_char": 30,
    "pad_sequences_padding": "post",
    "pad_sequences_truncating": "post",

    # optimizer and learning-rate schedule
    "optimizer_type": "nadam",
    "init_lr": 0.001,
    "beta1": 0.975,
    "beta2": 0.999,
    "decay_steps": 500,
    "decay_rate": 0.95,
    "schedule_decay": 0.004,
    "random_seed": 2018,
    "eval_every_num_update": 100,  # validation frequency (in updates)

    # encoder / attention variants (interpreted by SemanticMatchingModel)
    "encode_method": "fasttext",
    "attend_method": "attention",

    # CNN encoder settings — presumably only used when encode_method
    # selects a CNN; verify against model.py
    "cnn_num_filters": 32,
    "cnn_filter_sizes": [1, 2, 3],
    "cnn_timedistributed": False,

    # RNN encoder settings — presumably only used when encode_method
    # selects an RNN; verify against model.py
    "rnn_num_units": 20,
    "rnn_cell_type": "gru",

    # fc block
    "fc_type": "fc",
    "fc_dim": 64,
    "fc_dropout": 0,
}

# Name used for the saved model / logging.
model_name = "semantic_matching"
93+
94+
def train():
    """Train the semantic matching model and persist everything needed to predict.

    Reads the labelled pair file from ``config.TRAIN_FILE`` (tab-separated:
    id, left sentence, right sentence, label), fits the text preprocessor,
    trains on a 60/40 train/validation split, then saves the model session
    plus ``(DataProcessor, decision threshold)`` to ``dp.pkl`` for submit().
    """
    utils._makedirs("../logs")
    utils._makedirs("../output")
    logger = utils._get_logger("../logs", "tf-%s.log" % utils._timestamp())

    dfTrain = pd.read_csv(config.TRAIN_FILE, header=None, sep="\t")
    dfTrain.columns = ["id", "left", "right", "label"]
    # Drop incomplete rows before fitting the preprocessor.
    dfTrain.dropna(inplace=True)

    # Shuffle with the configured seed so the train/valid split is
    # reproducible (the original shuffled without a seed even though
    # params["random_seed"] is defined).
    dfTrain = dfTrain.sample(frac=1.0, random_state=params["random_seed"])

    dp = DataProcessor(max_num_words=params["max_num_words"],
                       max_num_chars=params["max_num_chars"])
    dfTrain = dp.fit_transform(dfTrain)

    # 60/40 split of the shuffled data.
    train_ratio = 0.6
    train_num = int(dfTrain.shape[0] * train_ratio)
    X_train = get_model_data(dfTrain[:train_num], params)
    X_valid = get_model_data(dfTrain[train_num:], params)

    # NOTE(review): threshold=0.2 here while params["threshold"] is
    # 0.217277 — confirm which is intended.
    model = SemanticMatchingModel(model_name, params, logger=logger, threshold=0.2)
    model.fit(X_train, validation_data=X_valid, shuffle=False)

    # Persist the trained weights, plus the fitted preprocessor and the
    # (possibly tuned) threshold so submit() can reproduce predictions.
    model.save_session()
    with open("dp.pkl", "wb") as f:
        # protocol=2 keeps the pickle loadable from Python 2.
        pkl.dump((dp, model.threshold), f, protocol=2)
125+
126+
127+
def submit(input_file, output_file):
    """Score a file of sentence pairs and write tab-separated (id, label) rows.

    Restores the pickled preprocessor/threshold and the saved model session,
    runs prediction over ``input_file`` (tab-separated: id, left, right),
    and writes predictions to ``output_file``.
    """
    print("read %s" % input_file)
    print("write %s" % output_file)

    # Restore the fitted preprocessor and decision threshold saved by train()...
    with open("dp.pkl", "rb") as fh:
        dp, threshold = pkl.load(fh)
    # ...then the trained network weights.
    model = SemanticMatchingModel(model_name, params, logger=None,
                                  threshold=threshold, training=False)
    model.restore_session()

    # The test file carries no label column; add a dummy one so the
    # shared preprocessing path works unchanged.
    df_test = pd.read_csv(input_file, header=None, sep="\t")
    df_test.columns = ["id", "left", "right"]
    df_test["label"] = np.zeros(df_test.shape[0])

    df_test = dp.transform(df_test)
    features = get_model_data(df_test, params)
    df_test["label"] = model.predict(features)

    df_test[["id", "label"]].to_csv(output_file, header=False,
                                    index=False, sep="\t")
148+
149+
150+
if __name__ == "__main__":
    # With two file arguments run inference; otherwise (re)train.
    cli_args = sys.argv[1:]
    if len(cli_args) >= 2:
        submit(cli_args[0], cli_args[1])
    else:
        train()

0 commit comments

Comments
 (0)