|
4 | 4 | import numpy as np |
5 | 5 | import scipy.sparse as sp |
6 | 6 | from scipy.spatial.distance import pdist, cdist, squareform |
7 | | -import pynndescent |
8 | | -import hnswlib |
9 | 7 | from sklearn import datasets |
10 | 8 |
|
11 | | -from numba import njit |
12 | | -from numba.core.registry import CPUDispatcher |
13 | 9 | from sklearn.utils import check_random_state |
14 | 10 |
|
15 | 11 | from openTSNE import nearest_neighbors |
| 12 | +from openTSNE.utils import is_package_installed |
16 | 13 | from .test_tsne import check_mock_called_with_kwargs |
17 | 14 |
|
18 | 15 |
|
@@ -76,10 +73,6 @@ class TestAnnoy(KNNIndexTestMixin, unittest.TestCase): |
76 | 73 | knn_index = nearest_neighbors.Annoy |
77 | 74 |
|
78 | 75 |
|
79 | | -class TestHNSW(KNNIndexTestMixin, unittest.TestCase): |
80 | | - knn_index = nearest_neighbors.HNSW |
81 | | - |
82 | | - |
83 | 76 | class TestBallTree(KNNIndexTestMixin, unittest.TestCase): |
84 | 77 | knn_index = nearest_neighbors.BallTree |
85 | 78 |
|
@@ -145,43 +138,37 @@ def manhattan(x, y): |
145 | 138 | distances, true_distances_, err_msg="Distances do not match" |
146 | 139 | ) |
147 | 140 |
|
148 | | - def test_numba_compiled_callable_metric_same_result(self): |
149 | | - k = 15 |
150 | 141 |
|
151 | | - knn_index = self.knn_index("manhattan", random_state=1) |
152 | | - knn_index.build(self.x1, k=k) |
153 | | - true_indices_, true_distances_ = knn_index.query(self.x2, k=k) |
154 | | - |
155 | | - @njit(fastmath=True) |
156 | | - def manhattan(x, y): |
157 | | - result = 0.0 |
158 | | - for i in range(x.shape[0]): |
159 | | - result += np.abs(x[i] - y[i]) |
160 | | - |
161 | | - return result |
| 142 | +@unittest.skipIf(not is_package_installed("hnswlib"), "`hnswlib`is not installed") |
| 143 | +class TestHNSW(KNNIndexTestMixin, unittest.TestCase): |
| 144 | + knn_index = nearest_neighbors.HNSW |
162 | 145 |
|
163 | | - knn_index = self.knn_index(manhattan, random_state=1) |
164 | | - knn_index.build(self.x1, k=k) |
165 | | - indices, distances = knn_index.query(self.x2, k=k) |
166 | | - np.testing.assert_array_equal( |
167 | | - indices, true_indices_, err_msg="Nearest neighbors do not match" |
168 | | - ) |
169 | | - np.testing.assert_allclose( |
170 | | - distances, true_distances_, err_msg="Distances do not match" |
171 | | - ) |
| 146 | + @classmethod |
| 147 | + def setUpClass(cls): |
| 148 | + global hnswlib |
| 149 | + import hnswlib |
172 | 150 |
|
173 | 151 |
|
| 152 | +@unittest.skipIf(not is_package_installed("pynndescent"), "`pynndescent`is not installed") |
174 | 153 | class TestNNDescent(KNNIndexTestMixin, unittest.TestCase): |
175 | 154 | knn_index = nearest_neighbors.NNDescent |
176 | 155 |
|
177 | | - @patch("pynndescent.NNDescent", wraps=pynndescent.NNDescent) |
178 | | - def test_random_state_being_passed_through(self, nndescent): |
| 156 | + @classmethod |
| 157 | + def setUpClass(cls): |
| 158 | + global pynndescent, njit, CPUDispatcher |
| 159 | + |
| 160 | + import pynndescent |
| 161 | + from numba import njit |
| 162 | + from numba.core.registry import CPUDispatcher |
| 163 | + |
| 164 | + def test_random_state_being_passed_through(self): |
179 | 165 | random_state = 1 |
180 | | - knn_index = nearest_neighbors.NNDescent("euclidean", random_state=random_state) |
181 | | - knn_index.build(self.x1, k=30) |
| 166 | + with patch("pynndescent.NNDescent", wraps=pynndescent.NNDescent) as nndescent: |
| 167 | + knn_index = nearest_neighbors.NNDescent("euclidean", random_state=random_state) |
| 168 | + knn_index.build(self.x1, k=30) |
182 | 169 |
|
183 | | - nndescent.assert_called_once() |
184 | | - check_mock_called_with_kwargs(nndescent, {"random_state": random_state}) |
| 170 | + nndescent.assert_called_once() |
| 171 | + check_mock_called_with_kwargs(nndescent, {"random_state": random_state}) |
185 | 172 |
|
186 | 173 | def test_uncompiled_callable_is_compiled(self): |
187 | 174 | knn_index = nearest_neighbors.NNDescent("manhattan") |
@@ -245,47 +232,47 @@ def manhattan(x, y): |
245 | 232 | distances, true_distances_, err_msg="Distances do not match" |
246 | 233 | ) |
247 | 234 |
|
248 | | - @patch("pynndescent.NNDescent", wraps=pynndescent.NNDescent) |
249 | | - def test_building_with_lt15_builds_proper_graph(self, nndescent): |
250 | | - knn_index = nearest_neighbors.NNDescent("euclidean") |
251 | | - indices, distances = knn_index.build(self.x1, k=10) |
| 235 | + def test_building_with_lt15_builds_proper_graph(self): |
| 236 | + with patch("pynndescent.NNDescent", wraps=pynndescent.NNDescent) as nndescent: |
| 237 | + knn_index = nearest_neighbors.NNDescent("euclidean") |
| 238 | + indices, distances = knn_index.build(self.x1, k=10) |
252 | 239 |
|
253 | | - self.assertEqual(indices.shape, (self.x1.shape[0], 10)) |
254 | | - self.assertEqual(distances.shape, (self.x1.shape[0], 10)) |
255 | | - self.assertFalse(np.all(indices[:, 0] == np.arange(self.x1.shape[0]))) |
| 240 | + self.assertEqual(indices.shape, (self.x1.shape[0], 10)) |
| 241 | + self.assertEqual(distances.shape, (self.x1.shape[0], 10)) |
| 242 | + self.assertFalse(np.all(indices[:, 0] == np.arange(self.x1.shape[0]))) |
256 | 243 |
|
257 | 244 | # Should be called with 11 because nearest neighbor in pynndescent is itself |
258 | 245 | check_mock_called_with_kwargs(nndescent, dict(n_neighbors=11)) |
259 | 246 |
|
260 | | - @patch("pynndescent.NNDescent", wraps=pynndescent.NNDescent) |
261 | | - def test_building_with_gt15_calls_query(self, nndescent): |
262 | | - nndescent.query = MagicMock(wraps=nndescent.query) |
263 | | - knn_index = nearest_neighbors.NNDescent("euclidean") |
264 | | - indices, distances = knn_index.build(self.x1, k=30) |
265 | | - |
266 | | - self.assertEqual(indices.shape, (self.x1.shape[0], 30)) |
267 | | - self.assertEqual(distances.shape, (self.x1.shape[0], 30)) |
268 | | - self.assertFalse(np.all(indices[:, 0] == np.arange(self.x1.shape[0]))) |
269 | | - |
270 | | - # The index should be built with 15 neighbors |
271 | | - check_mock_called_with_kwargs(nndescent, dict(n_neighbors=15)) |
272 | | - # And subsequently queried with the correct number of neighbors. Check |
273 | | - # for 31 neighbors because query will return the original point as well, |
274 | | - # which we don't consider. |
275 | | - check_mock_called_with_kwargs(nndescent.query, dict(k=31)) |
276 | | - |
277 | | - @patch("pynndescent.NNDescent", wraps=pynndescent.NNDescent) |
278 | | - def test_runs_with_correct_njobs_if_dense_input(self, nndescent): |
279 | | - knn_index = nearest_neighbors.NNDescent("euclidean", n_jobs=2) |
280 | | - knn_index.build(self.x1, k=5) |
281 | | - check_mock_called_with_kwargs(nndescent, dict(n_jobs=2)) |
282 | | - |
283 | | - @patch("pynndescent.NNDescent", wraps=pynndescent.NNDescent) |
284 | | - def test_runs_with_correct_njobs_if_sparse_input(self, nndescent): |
285 | | - x_sparse = sp.csr_matrix(self.x1) |
286 | | - knn_index = nearest_neighbors.NNDescent("euclidean", n_jobs=2) |
287 | | - knn_index.build(x_sparse, k=5) |
288 | | - check_mock_called_with_kwargs(nndescent, dict(n_jobs=2)) |
| 247 | + def test_building_with_gt15_calls_query(self): |
| 248 | + with patch("pynndescent.NNDescent", wraps=pynndescent.NNDescent) as nndescent: |
| 249 | + nndescent.query = MagicMock(wraps=nndescent.query) |
| 250 | + knn_index = nearest_neighbors.NNDescent("euclidean") |
| 251 | + indices, distances = knn_index.build(self.x1, k=30) |
| 252 | + |
| 253 | + self.assertEqual(indices.shape, (self.x1.shape[0], 30)) |
| 254 | + self.assertEqual(distances.shape, (self.x1.shape[0], 30)) |
| 255 | + self.assertFalse(np.all(indices[:, 0] == np.arange(self.x1.shape[0]))) |
| 256 | + |
| 257 | + # The index should be built with 15 neighbors |
| 258 | + check_mock_called_with_kwargs(nndescent, dict(n_neighbors=15)) |
| 259 | + # And subsequently queried with the correct number of neighbors. Check |
| 260 | + # for 31 neighbors because query will return the original point as well, |
| 261 | + # which we don't consider. |
| 262 | + check_mock_called_with_kwargs(nndescent.query, dict(k=31)) |
| 263 | + |
| 264 | + def test_runs_with_correct_njobs_if_dense_input(self): |
| 265 | + with patch("pynndescent.NNDescent", wraps=pynndescent.NNDescent) as nndescent: |
| 266 | + knn_index = nearest_neighbors.NNDescent("euclidean", n_jobs=2) |
| 267 | + knn_index.build(self.x1, k=5) |
| 268 | + check_mock_called_with_kwargs(nndescent, dict(n_jobs=2)) |
| 269 | + |
| 270 | + def test_runs_with_correct_njobs_if_sparse_input(self): |
| 271 | + with patch("pynndescent.NNDescent", wraps=pynndescent.NNDescent) as nndescent: |
| 272 | + x_sparse = sp.csr_matrix(self.x1) |
| 273 | + knn_index = nearest_neighbors.NNDescent("euclidean", n_jobs=2) |
| 274 | + knn_index.build(x_sparse, k=5) |
| 275 | + check_mock_called_with_kwargs(nndescent, dict(n_jobs=2)) |
289 | 276 |
|
290 | 277 | def test_random_cluster_when_invalid_indices(self): |
291 | 278 | class MockIndex: |
|
0 commit comments