Skip to content

Commit 2da48ed

Browse files
committed
improve web-interface and parameter-helper
1 parent 01ad157 commit 2da48ed

File tree

5 files changed

+56
-27
lines changed

5 files changed

+56
-27
lines changed

Parsing/parser_utils.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,9 @@ def buildDataset(path_file: str, verbose=True) -> EntityHandler:
140140
sentences, list_of_labels = [], []
141141
set_entities = set() # set of unique entity found (incrementally updated)
142142

143-
for field in read_conll(path_file): # generator
143+
for fields in read_conll(path_file): # generator
144144

145-
tokens, labels = field[0], field[1]
145+
tokens, labels = fields[0], fields[1]
146146

147147
sentences.append(" ".join(tokens))
148148
list_of_labels.append(" ".join(labels))
@@ -206,22 +206,33 @@ def random_chars(y):
206206

207207

208208
def parse_args():
209-
p = argparse.ArgumentParser(description='Model configuration.', add_help=False)
209+
p = argparse.ArgumentParser(description='Model configuration.', add_help=True)
210210

211-
p.add_argument('--datasets', type=str, nargs='+', help='Path to the datasets', default=None)
212-
p.add_argument('--models', type=str, nargs='+', help='Models in the same order of datasets', default=None)
211+
p.add_argument('--datasets', type=str, nargs='+',
212+
help='Dataset used for training, it will split in training, validation and test', default=None)
213213

214-
p.add_argument('--model_name', type=str, help='Name of trained model', default=None)
215-
p.add_argument('--path_model', type=str, help='Directory to save the model', default=".")
214+
p.add_argument('--models', type=str, nargs='+',
215+
help='Model trained ready to evaluate or use, if list, the order must follow the same of datasets',
216+
default=None)
216217

217-
p.add_argument('--bert', type=str, help='Huggingface model', default="dbmdz/bert-base-italian-xxl-cased")
218+
p.add_argument('--model_name', type=str,
219+
help='Name to give to a trained model', default=None)
220+
221+
p.add_argument('--path_model', type=str,
222+
help='Directory to save the model', default=".")
223+
224+
p.add_argument('--bert', type=str,
225+
help='Bert model provided by Huggingface', default="dbmdz/bert-base-italian-xxl-cased")
226+
227+
p.add_argument('--save_model', type=int,
228+
help='set 1 if you want save the model otherwise set 0', default=1)
218229

219230
p.add_argument('--lr', type=float, help='Learning rate', default=0.010)
220231
p.add_argument('--momentum', type=float, help='Momentum', default=0.9)
221232
p.add_argument('--weight_decay', type=float, help='Weight decay', default=0.0002)
222233
p.add_argument('--batch_size', type=int, help='Batch size', default=2)
223234
p.add_argument('--max_epoch', type=int, help='Max number of epochs', default=20)
224235
p.add_argument('--early_stopping', type=float, help='Patience in early stopping', default=3)
225-
p.add_argument('--save_model', type=int, help='1 to save the model', default=1)
236+
226237

227238
return p.parse_known_args()

Prediction/Predictor.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple
2+
13
from torch import IntTensor, BoolTensor, masked_select
24
from transformers import BertTokenizerFast
35

@@ -61,7 +63,7 @@ def add_model(self, group: str, model: NERClassifier, dictionary: dict):
6163
model.eval()
6264
self.models[group] = (model, dictionary)
6365

64-
def predict(self, string: str) -> list:
66+
def predict(self, string: str) -> Tuple[list, list]:
6567

6668
token_text = self.tokenizer(string)
6769

@@ -84,4 +86,13 @@ def predict(self, string: str) -> list:
8486
[lbl[2:] if lbl != "O" else "O" for lbl in self.map_id2lab(dictionary, logits)])
8587

8688
results = self.unify_labels(results[0], results[1]) if len(results) == 2 else results[0]
87-
return results
89+
90+
# Mask is used to show only a once the entity. if true on the last word in a group of words
91+
# where it was detected as entity
92+
mask = [False] * len(results)
93+
for idx in range(len(results) - 1):
94+
if results[idx] != results[idx + 1] and results[idx] != "":
95+
mask[idx] = True
96+
mask[-1] = True if results[-1] != "" else False
97+
98+
return results, mask

server.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,24 @@
3535
predictor.add_model("a", modelA, id2lab_group_a)
3636
predictor.add_model("b", modelB, id2lab_group_b)
3737

38+
list_of_result = []
39+
3840

3941
@app.route('/', methods=('GET', 'POST'))
4042
def create():
4143
if request.method == 'POST':
44+
4245
sentence = request.form['Sentence']
43-
tag_pred = predictor.predict(sentence)
4446

45-
mask = [False] * len(tag_pred)
46-
for idx in range(len(tag_pred)-1):
47-
if tag_pred[idx] != tag_pred[idx+1] and tag_pred[idx] != "":
48-
mask[idx] = True
47+
if "predict" in request.form and sentence != "":
48+
tag_pred, mask = predictor.predict(sentence)
49+
result_ = [*zip(sentence.split(), tag_pred, mask)]
50+
list_of_result.append(result_)
51+
52+
elif "clear" in request.form:
53+
list_of_result.clear()
4954

50-
result_ = [*zip(sentence.split(), tag_pred, mask)]
51-
else:
52-
result_ = []
53-
return render_template('main.html', result=result_)
55+
return render_template('main.html', list_of_result=list_of_result)
5456

5557

5658
"""

templates/main.html

+10-5
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,22 @@
6464
<form method="POST">
6565
<div style="padding:10px">
6666
<label class="form-label"><b>Sentence</b></label>
67-
<textarea type="text" style="padding:8px;" class="form-control" name="Sentence" value=""></textarea>
67+
<textarea style="padding:8px;" class="form-control" name="Sentence"></textarea>
6868
</div>
69-
<button type="submit" class="btn" style="background-color:#AD8E70"><b>Entity extraction</b></button>
69+
<button type="submit" class="btn" name="predict" style="background-color:#AD8E70"><b>Entity extraction</b></button>
70+
<button type="submit" class="btn" name="clear" style="background-color:#AD8E70"><b>Clear</b></button>
7071
</form>
7172
</div>
7273
</div>
7374
<div class="grid-item3 text_result">
7475
<br>
75-
{% for (token, tag, mask) in result %}
76-
{% if tag %}<i><b>{{ token }}</b></i>{% else %}{{ token }}{% endif %}
77-
{% if tag and mask %}<sub style="color:red"> ({{ tag }})</sub> {% endif %}
76+
{% for result in list_of_result %}
77+
{% for (token, tag, mask) in result %}
78+
{% if tag %}<i><b>{{ token }}</b></i>{% else %}{{ token }}{% endif %}
79+
{% if tag and mask %}<sub style="color:red"> ({{ tag }})</sub> {% endif %}
80+
{% endfor %}
81+
<br>
82+
<hr style="width:100%">
7883
{% endfor %}
7984
</div>
8085
<div class="grid-item2">

train_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818

1919
handler = buildDataset(args.datasets[0], verbose=True)
2020
df_train, df_val, df_test = holdout(handler.dt)
21+
2122
model = NERClassifier(conf.bert, len(handler.set_entities), frozen=False)
22-
# model.load_state_dict(torch.load(conf.folder + "tmp/modelA2.pt"))
2323

2424
if conf.cuda:
2525
model = model.to(conf.gpu)
2626

2727
train(model, handler, df_train, df_val, conf)
2828

2929
"""
30-
C:\ProgramData\Anaconda3\envs\deeplearning\python.exe train_model.py --model_name prova.pt --max_epoch 1 --datasets .\Source\dataset.a.conll
30+
C:\ProgramData\Anaconda3\envs\deeplearning\python.exe train_model.py --model_name prova --max_epoch 1 --datasets .\Source\dataset.a.conll
3131
"""

0 commit comments

Comments
 (0)