Skip to content

Commit 7c50129

Browse files
authored
Perform a lazy import of optional dep nltk (#311)
1 parent 32a5444 commit 7c50129

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

Diff for: redisvl/query/aggregate.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from typing import Any, Dict, List, Optional, Set, Tuple, Union
22

3-
import nltk
4-
from nltk.corpus import stopwords as nltk_stopwords
53
from redis.commands.search.aggregation import AggregateRequest, Desc
64

75
from redisvl.query.filter import FilterExpression
@@ -164,6 +162,14 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
164162
if not stopwords:
165163
self._stopwords = set()
166164
elif isinstance(stopwords, str):
165+
# Lazy import because nltk is an optional dependency
166+
try:
167+
import nltk
168+
from nltk.corpus import stopwords as nltk_stopwords
169+
except ImportError:
170+
raise ValueError(
171+
f"Loading stopwords for {stopwords} failed: nltk is not installed."
172+
)
167173
try:
168174
nltk.download("stopwords", quiet=True)
169175
self._stopwords = set(nltk_stopwords.words(stopwords))

Diff for: redisvl/query/query.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from enum import Enum
22
from typing import Any, Dict, List, Optional, Set, Tuple, Union
33

4-
import nltk
5-
from nltk.corpus import stopwords as nltk_stopwords
64
from redis.commands.search.query import Query as RedisQuery
75

86
from redisvl.query.filter import FilterExpression
@@ -812,6 +810,14 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
812810
if not stopwords:
813811
self._stopwords = set()
814812
elif isinstance(stopwords, str):
813+
# Lazy import because nltk is an optional dependency
814+
try:
815+
import nltk
816+
from nltk.corpus import stopwords as nltk_stopwords
817+
except ImportError:
818+
raise ValueError(
819+
f"Loading stopwords for {stopwords} failed: nltk is not installed."
820+
)
815821
try:
816822
nltk.download("stopwords", quiet=True)
817823
self._stopwords = set(nltk_stopwords.words(stopwords))

0 commit comments

Comments
 (0)