Skip to content

Commit 9e1cbb3

Browse files
Improve natural language inference performance (#1234)
* improve natural language inference performance * update release notes * update comment * add new test * lint fix
1 parent 0591279 commit 9e1cbb3

File tree

3 files changed

+70
-5
lines changed

3 files changed

+70
-5
lines changed

docs/source/release_notes.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ Future Release
88
* Enhancements
99
* Fixes
1010
* Changes
11+
* Update inference process to only check for NaturalLanguage if no other type matches are found first (:pr:`1234`)
1112
* Documentation Changes
12-
* Updating contributing doc with Spark installat instructions (:pr:`1232`)
13+
* Updating contributing doc with Spark installation instructions (:pr:`1232`)
1314
* Testing Changes
1415
* Enable auto-merge for minimum and latest dependency merge requests (:pr:`1228`, :pr:`1230`, :pr:`1233`)
1516

1617
Thanks to the following people for contributing to this release:
17-
:user:`gsheni`, :user:`willsmithorg`
18+
:user:`gsheni`, :user:`thehomebrewnerd`, :user:`willsmithorg`
1819

1920
v0.11.0 Dec 22, 2021
2021
====================

woodwork/tests/type_system/test_ltype_inference.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from mock import patch
2+
13
import woodwork as ww
24
from woodwork.accessor_utils import _is_koalas_series
35
from woodwork.logical_types import (
@@ -15,6 +17,12 @@
1517
Unknown,
1618
)
1719
from woodwork.tests.testing_utils import to_pandas
20+
from woodwork.type_sys.type_system import (
21+
DEFAULT_INFERENCE_FUNCTIONS,
22+
DEFAULT_RELATIONSHIPS,
23+
DEFAULT_TYPE,
24+
TypeSystem,
25+
)
1826
from woodwork.utils import import_or_none
1927

2028
UNSUPPORTED_KOALAS_DTYPES = [
@@ -125,6 +133,47 @@ def test_natural_language_inference(natural_language):
125133
assert isinstance(inferred_type, NaturalLanguage)
126134

127135

136+
@patch("woodwork.type_sys.inference_functions.natural_language_func")
137+
def test_nl_inference_called_on_no_other_matches(nl_mock, pandas_natural_language):
138+
assert isinstance(
139+
ww.type_system.infer_logical_type(pandas_natural_language[0]), NaturalLanguage
140+
)
141+
new_type_sys = TypeSystem(
142+
inference_functions=DEFAULT_INFERENCE_FUNCTIONS,
143+
relationships=DEFAULT_RELATIONSHIPS,
144+
default_type=DEFAULT_TYPE,
145+
)
146+
new_type_sys.inference_functions[NaturalLanguage] = nl_mock
147+
_ = new_type_sys.infer_logical_type(pandas_natural_language[0])
148+
assert nl_mock.called
149+
150+
151+
@patch("woodwork.type_sys.inference_functions.natural_language_func")
152+
def test_nl_inference_called_with_unknown_type(nl_mock, pandas_strings):
153+
assert isinstance(ww.type_system.infer_logical_type(pandas_strings[0]), Unknown)
154+
new_type_sys = TypeSystem(
155+
inference_functions=DEFAULT_INFERENCE_FUNCTIONS,
156+
relationships=DEFAULT_RELATIONSHIPS,
157+
default_type=DEFAULT_TYPE,
158+
)
159+
new_type_sys.inference_functions[NaturalLanguage] = nl_mock
160+
_ = new_type_sys.infer_logical_type(pandas_strings[0])
161+
assert nl_mock.called
162+
163+
164+
@patch("woodwork.type_sys.inference_functions.natural_language_func")
165+
def test_nl_inference_not_called_with_other_matches(nl_mock, pandas_integers):
166+
assert isinstance(ww.type_system.infer_logical_type(pandas_integers[0]), Integer)
167+
new_type_sys = TypeSystem(
168+
inference_functions=DEFAULT_INFERENCE_FUNCTIONS,
169+
relationships=DEFAULT_RELATIONSHIPS,
170+
default_type=DEFAULT_TYPE,
171+
)
172+
new_type_sys.inference_functions[NaturalLanguage] = nl_mock
173+
_ = new_type_sys.infer_logical_type(pandas_integers[0])
174+
assert not nl_mock.called
175+
176+
128177
def test_categorical_inference_based_on_dtype(categories_dtype):
129178
"""
130179
This test specifically targets the case in which a series can be inferred

woodwork/type_sys/type_system.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,26 @@ def get_inference_matches(types_to_check, series, type_matches=[]):
315315
get_inference_matches(check_next, series, type_matches)
316316
return type_matches
317317

318-
type_matches = get_inference_matches(self.root_types, series)
318+
# Don't include NaturalLanguage as we only want to check that if
319+
# no other matches are found
320+
types_to_check = [
321+
ltype for ltype in self.root_types if ltype != NaturalLanguage
322+
]
323+
type_matches = get_inference_matches(types_to_check, series)
319324

320325
if len(type_matches) == 0:
321-
# If no matches, set type to default type (Unknown)
322-
logical_type = self.default_type
326+
# Check if this is NaturalLanguage, otherwise set
327+
# type to default type (Unknown). Assume that a column
328+
# can only be natural language if it is not already a
329+
# match for another type. Also improves performance by
330+
# limiting the times the natural language inference function
331+
# is called.
332+
if self.inference_functions.get(
333+
NaturalLanguage
334+
) and self.inference_functions[NaturalLanguage](series):
335+
logical_type = NaturalLanguage
336+
else:
337+
logical_type = self.default_type
323338
elif len(type_matches) == 1:
324339
# If we match only one type, return it
325340
logical_type = type_matches[0]

0 commit comments

Comments
 (0)