Skip to content

Commit eb16626

Browse files
authored
Merge pull request #134 from stephenhky/develop
10x faster of classes
2 parents 02b1056 + 349dbd8 commit eb16626

File tree

6 files changed

+31
-13
lines changed

6 files changed

+31
-13
lines changed

.circleci/config.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,22 @@ jobs:
3636
py37:
3737
<<: *shared
3838
docker:
39-
- image: circleci/python:3.7
39+
- image: cimg/python:3.7
4040

4141
py38:
4242
<<: *shared
4343
docker:
44-
- image: circleci/python:3.8
44+
- image: cimg/python:3.8
4545

4646
py39:
4747
<<: *shared
4848
docker:
49-
- image: circleci/python:3.9
49+
- image: cimg/python:3.9
5050

5151
py310:
5252
<<: *shared
5353
docker:
54-
- image: circleci/python:3.10
54+
- image: cimg/python:3.10
5555

5656

5757
workflows:

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ If you would like to contribute, feel free to submit the pull requests. You can
9292

9393
## News
9494

95+
* 08/29/2022: `shorttext` 1.5.6 released.
96+
* 05/28/2022: `shorttext` 1.5.5 released.
9597
* 12/15/2021: `shorttext` 1.5.4 released.
9698
* 07/11/2021: `shorttext` 1.5.3 released.
9799
* 07/06/2021: `shorttext` 1.5.2 released.

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
# The short X.Y version.
5959
version = u'1.5'
6060
# The full version, including alpha/beta/rc tags.
61-
release = u'1.5.5'
61+
release = u'1.5.6'
6262

6363
# The language for content autogenerated by Sphinx. Refer to documentation
6464
# for a list of supported languages.

docs/news.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
News
22
====
33

4+
* 08/29/2022: `shorttext` 1.5.6 released.
45
* 05/28/2022: `shorttext` 1.5.5 released.
56
* 12/15/2021: `shorttext` 1.5.4 released.
67
* 07/11/2021: `shorttext` 1.5.3 released.
@@ -74,6 +75,11 @@ News
7475
What's New
7576
----------
7677

78+
Release 1.5.6 (August 29, 2022)
79+
-------------------------------
80+
81+
* Speeding up inference of `VarNNEmbeddedVecClassifier`. (Acknowledgement: Ritesh Agrawal)
82+
7783
Release 1.5.5 (May 28, 2022)
7884
-----------------------------
7985

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_requirements():
3535

3636

3737
setup(name='shorttext',
38-
version='1.5.5',
38+
version='1.5.6a1',
3939
description="Short Text Mining",
4040
long_description=package_description(),
4141
long_description_content_type='text/markdown',

shorttext/classifiers/embed/nnlib/VarNNEmbedVecClassification.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import warnings
55

66
import numpy as np
7+
import pandas as pd
78

89
import shorttext.utils.kerasmodel_io as kerasio
910
import shorttext.utils.classification_exceptions as e
1011
from shorttext.utils import tokenize
1112
from shorttext.utils.compactmodel_io import CompactIOMachine
13+
from typing import Union, List, Dict, Any
1214

1315

1416
class VarNNEmbeddedVecClassifier(CompactIOMachine):
@@ -208,7 +210,7 @@ def shorttext_to_matrix(self, shorttext):
208210
matrix[i] = self.word_to_embedvec(tokens[i])
209211
return matrix
210212

211-
def score(self, shorttext):
213+
def score(self, shorttexts: Union[str, List[str]], model_params: Dict[str, Any] = {}):
212214
""" Calculate the scores for all the class labels for the given short sentence.
213215
214216
Given a short sentence, calculate the classification scores for all class labels,
@@ -217,25 +219,33 @@ def score(self, shorttext):
217219
If neither :func:`~train` nor :func:`~loadmodel` was run, it will raise `ModelNotTrainedException`.
218220
219221
:param shorttext: a short sentence
222+
:param model_params: additional parameters to be passed to the model object
220223
:return: a dictionary with keys being the class labels, and values being the corresponding classification scores
221224
:type shorttext: str
222225
:rtype: dict
223226
:raise: ModelNotTrainedException
224227
"""
228+
is_multiple = True
229+
if isinstance(shorttexts, str):
230+
is_multiple = False
231+
shorttexts = [shorttexts]
232+
225233
if not self.trained:
226234
raise e.ModelNotTrainedException()
227235

228236
# retrieve vector
229-
matrix = np.array([self.shorttext_to_matrix(shorttext)])
237+
matrix = np.array([self.shorttext_to_matrix(shorttext) for shorttext in shorttexts])
230238

231239
# classification using the neural network
232-
predictions = self.model.predict(matrix)
240+
predictions = self.model.predict(matrix, **model_params)
233241

234242
# wrangle output result
235-
scoredict = {classlabel: predictions[0][idx]
236-
for idx, classlabel in zip(range(len(self.classlabels)), self.classlabels)}
237-
238-
return scoredict
243+
df = pd.DataFrame(predictions, columns=self.classlabels)
244+
245+
if not is_multiple:
246+
return df.to_dict('records')[0]
247+
248+
return df.to_dict('records')
239249

240250

241251
def load_varnnlibvec_classifier(wvmodel, name, compact=True, vecsize=None):

0 commit comments

Comments
 (0)