Skip to content

Commit f83f387

Browse files
Replace messy neighbor search with Affinity class
1 parent 005982d commit f83f387

File tree

7 files changed

+229
-204
lines changed

7 files changed

+229
-204
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ If we want finer control of the optimization process, we can run individual opti
5555

5656
```python
5757
tsne = TSNE()
58-
embedding = tsne.get_initial_embedding_for(x)
58+
embedding = tsne.prepare_initial(x)
5959
embedding.optimize(n_iter=250, exaggeration=12, momentum=0.5)
6060
embedding.optimize(n_iter=750, momentum=0.8)
6161
```

tests/test_correctness.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def setUpClass(cls):
1515
cls.x = np.random.randn(100, 4)
1616

1717
def test_error_exaggeration_correction(self):
18-
embedding = self.tsne.get_initial_embedding_for(self.x)
18+
embedding = self.tsne.prepare_initial(self.x)
1919

2020
# The callback raises if the KL divergence does not match the true one
2121
embedding.optimize(

tests/test_tsne.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_embedding_optimize(self, param_name, param_value, gradient_descent):
132132
params = {'n_iter': 50, param_name: param_value}
133133

134134
tsne = TSNE()
135-
embedding = tsne.get_initial_embedding_for(self.x)
135+
embedding = tsne.prepare_initial(self.x)
136136
embedding.optimize(**params, inplace=True)
137137

138138
self.assertEqual(1, gradient_descent.call_count)
@@ -204,7 +204,7 @@ def test_partial_embedding_optimize(self, param_name, param_value, gradient_desc
204204
# `optimize` requires us to specify the `n_iter`
205205
params = {'n_iter': 50, param_name: param_value}
206206

207-
partial_embedding = embedding.get_partial_embedding_for(self.x_test)
207+
partial_embedding = embedding.prepare_partial(self.x_test)
208208
partial_embedding.optimize(**params, inplace=True)
209209

210210
self.assertEqual(1, gradient_descent.call_count)
@@ -219,7 +219,7 @@ def setUpClass(cls):
219219
cls.x_test = np.random.randn(25, 4)
220220

221221
def test_embedding_inplace_optimization(self):
222-
embedding1 = self.tsne.get_initial_embedding_for(self.x)
222+
embedding1 = self.tsne.prepare_initial(self.x)
223223

224224
embedding2 = embedding1.optimize(n_iter=5, inplace=True)
225225
embedding3 = embedding2.optimize(n_iter=5, inplace=True)
@@ -228,7 +228,7 @@ def test_embedding_inplace_optimization(self):
228228
self.assertIs(embedding2.base, embedding3.base)
229229

230230
def test_embedding_not_inplace_optimization(self):
231-
embedding1 = self.tsne.get_initial_embedding_for(self.x)
231+
embedding1 = self.tsne.prepare_initial(self.x)
232232

233233
embedding2 = embedding1.optimize(n_iter=5, inplace=False)
234234
embedding3 = embedding2.optimize(n_iter=5, inplace=False)
@@ -239,10 +239,10 @@ def test_embedding_not_inplace_optimization(self):
239239

240240
def test_partial_embedding_inplace_optimization(self):
241241
# Prepare reference embedding
242-
embedding = self.tsne.get_initial_embedding_for(self.x)
242+
embedding = self.tsne.prepare_initial(self.x)
243243
embedding.optimize(10, inplace=True)
244244

245-
partial_embedding1 = embedding.get_partial_embedding_for(self.x_test)
245+
partial_embedding1 = embedding.prepare_partial(self.x_test)
246246
partial_embedding2 = partial_embedding1.optimize(5, inplace=True)
247247
partial_embedding3 = partial_embedding2.optimize(5, inplace=True)
248248

@@ -251,10 +251,10 @@ def test_partial_embedding_inplace_optimization(self):
251251

252252
def test_partial_embedding_not_inplace_optimization(self):
253253
# Prepare reference embedding
254-
embedding = self.tsne.get_initial_embedding_for(self.x)
254+
embedding = self.tsne.prepare_initial(self.x)
255255
embedding.optimize(10, inplace=True)
256256

257-
partial_embedding1 = embedding.get_partial_embedding_for(self.x_test)
257+
partial_embedding1 = embedding.prepare_partial(self.x_test)
258258
partial_embedding2 = partial_embedding1.optimize(5, inplace=False)
259259
partial_embedding3 = partial_embedding2.optimize(5, inplace=False)
260260

@@ -298,7 +298,7 @@ def test_can_pass_callbacks_to_tsne_object(self):
298298
callback2.assert_called_once()
299299

300300
def test_can_pass_callbacks_to_embedding_optimize(self):
301-
embedding = self.tsne.get_initial_embedding_for(self.x)
301+
embedding = self.tsne.prepare_initial(self.x)
302302

303303
# We don't want the callback to be iterable
304304
callback = MagicMock()
@@ -314,7 +314,7 @@ def test_can_pass_callbacks_to_embedding_optimize(self):
314314
callback.assert_called_once()
315315

316316
def test_can_pass_callbacks_to_embedding_transform(self):
317-
embedding = self.tsne.get_initial_embedding_for(self.x)
317+
embedding = self.tsne.prepare_initial(self.x)
318318

319319
# We don't want the callback to be iterable
320320
callback = MagicMock()
@@ -332,14 +332,14 @@ def test_can_pass_callbacks_to_embedding_transform(self):
332332
callback.assert_called_once()
333333

334334
def test_can_pass_callbacks_to_partial_embedding_optimize(self):
335-
embedding = self.tsne.get_initial_embedding_for(self.x)
335+
embedding = self.tsne.prepare_initial(self.x)
336336

337337
# We don't want the callback to be iterable
338338
callback = MagicMock()
339339
del callback.__iter__
340340

341341
# Should be able to pass a single callback
342-
partial_embedding = embedding.get_partial_embedding_for(self.x_test)
342+
partial_embedding = embedding.prepare_partial(self.x_test)
343343
partial_embedding.optimize(1, callbacks=callback, callbacks_every_iters=1)
344344
callback.assert_called_once()
345345

tsne/affinity.py

+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import logging
2+
3+
import numpy as np
4+
from scipy.sparse import csr_matrix
5+
6+
from tsne import _tsne
7+
from tsne.nearest_neighbors import KDTree, NNDescent, KNNIndex
8+
9+
log = logging.getLogger(__name__)
10+
11+
12+
class Affinities:
    """Base class holding the affinity matrix P that tSNE optimizes.

    tSNE only consumes an affinity matrix P; it is indifferent to the space
    the original data points live in. Consequently we are not restricted to
    numeric data matrices (the most common use-case) and can, for example,
    also optimize graph layouts.

    Perplexity, as defined by Van der Maaten in the original paper, is used as
    a continuous analogue to the number of neighbor affinities we want to
    preserve during optimization.

    Attributes are plain instance fields:
      * ``perplexity`` -- the value passed to the constructor.
      * ``P`` -- initialised to ``None`` here; filling it in is left to
        subclasses.

    """
    def __init__(self, perplexity=30):
        self.P = None
        self.perplexity = perplexity

    def to_new(self, data, perplexity=None, return_distances=False):
        """Compute the affinities of new data points to the existing ones.

        Particularly useful for `transform`, which needs the conditional
        probabilities from the existing to the new data.

        The base implementation does nothing and returns ``None``.

        """
        return None
36+
37+
38+
class NearestNeighborAffinities(Affinities):
    """Compute affinities using the nearest neighbors defined by perplexity.

    Parameters
    ----------
    data : array-like with a ``shape`` attribute
        The reference data. ``data.shape[0]`` is taken as the number of
        samples; the object itself is handed to the KNN index.
    perplexity : float
        Desired perplexity. Lowered automatically (with a warning) when
        ``3 * perplexity`` exceeds ``n_samples - 1``.
    method : str or KNNIndex
        Either ``'exact'`` (uses `KDTree`), ``'approx'`` (uses `NNDescent`) or
        an already constructed `KNNIndex` instance.
    metric : str
        Forwarded to the built-in index constructor. Not used when a
        `KNNIndex` instance is supplied as `method`.
    symmetrize : bool
        Forwarded to `joint_probabilities_nn`.
    n_jobs : int
        Forwarded to the built-in index constructor and to
        `joint_probabilities_nn`.

    Raises
    ------
    ValueError
        If `method` is neither a `KNNIndex` instance nor one of the supported
        shortcut names.

    """
    def __init__(self, data, perplexity=30, method='approx', metric='euclidean',
                 symmetrize=True, n_jobs=1):
        # `check_perplexity` reads `self.n_samples`, so it must be set first
        self.n_samples = data.shape[0]

        perplexity = self.check_perplexity(perplexity)
        k_neighbors = min(self.n_samples - 1, int(3 * perplexity))

        # Support shortcuts for built-in nearest neighbor methods
        methods = {'exact': KDTree, 'approx': NNDescent}
        if isinstance(method, KNNIndex):
            knn_index = method

        elif method not in methods:
            # Wrap in a 1-tuple so tuple-valued `method`s are formatted as a
            # single value rather than being unpacked by the `%` operator
            raise ValueError('Unrecognized nearest neighbor algorithm `%s`. '
                             'Please choose one of the supported methods or '
                             'provide a valid `KNNIndex` instance.' % (method,))
        else:
            knn_index = methods[method](metric=metric, n_jobs=n_jobs)

        knn_index.build(data)
        neighbors, distances = knn_index.query_train(data, k=k_neighbors)

        # Store the results on the object
        self.perplexity = perplexity
        self.knn_index = knn_index
        self.P = joint_probabilities_nn(
            neighbors, distances, perplexity, symmetrize=symmetrize, n_jobs=n_jobs)

        self.n_jobs = n_jobs

    def to_new(self, data, perplexity=None, return_distances=False):
        """Compute the affinities of new data points to the reference data.

        Parameters
        ----------
        data
            New samples; passed to ``self.knn_index.query``.
        perplexity : float, optional
            Falls back to the perplexity stored on the object when falsy.
        return_distances : bool
            If true, also return the neighbor indices and distances obtained
            from the KNN index.

        Returns
        -------
        P, or the tuple ``(P, neighbors, distances)`` when `return_distances`
        is true. P is built without symmetrization.

        """
        perplexity = perplexity or self.perplexity
        perplexity = self.check_perplexity(perplexity)
        k_neighbors = min(self.n_samples - 1, int(3 * perplexity))

        neighbors, distances = self.knn_index.query(data, k_neighbors)

        P = joint_probabilities_nn(
            neighbors, distances, perplexity, symmetrize=False,
            n_reference_samples=self.n_samples, n_jobs=self.n_jobs,
        )

        if return_distances:
            return P, neighbors, distances

        return P

    def check_perplexity(self, perplexity):
        """Check for valid perplexity value.

        Returns `perplexity` unchanged when ``3 * perplexity`` fits within
        ``n_samples - 1``; otherwise logs a warning and returns
        ``(n_samples - 1) / 3``.

        """
        if self.n_samples - 1 < 3 * perplexity:
            old_perplexity, perplexity = perplexity, (self.n_samples - 1) / 3
            # Lazy logger arguments: same message text, formatted only if the
            # record is actually emitted
            log.warning('Perplexity value %d is too high. Using perplexity %.2f',
                        old_perplexity, perplexity)

        return perplexity
95+
96+
97+
class GraphAffinities(Affinities):
    """Placeholder for affinities derived from a graph.

    NOTE(review): this class is currently a stub. The constructor arguments
    `data`, `use_directed` and `use_weights` are accepted but not used, and
    `to_new` returns ``None``.

    """
    def __init__(self, data, use_directed=True, use_weights=True):
        super().__init__()

    def to_new(self, data, perplexity=None, return_distances=False):
        """Compute the affinities of new data points (not implemented).

        The signature mirrors `Affinities.to_new` so that callers may pass
        `perplexity` and `return_distances` regardless of the concrete
        affinity class. Both are ignored for now and ``None`` is returned.

        """
        return None
103+
104+
105+
def joint_probabilities_nn(neighbors, distances, perplexity, symmetrize=True,
                           n_reference_samples=None, n_jobs=1):
    """Build the sparse probability matrix P from nearest neighbor results.

    The conditional similarities P_{j|i} are approximated using only each
    point's nearest neighbors, optionally symmetrized, and finally normalized
    so that all entries sum to one.

    Parameters
    ----------
    neighbors : np.ndarray
        A `n_samples * k_neighbors` matrix containing the indices of each
        point's nearest neighbors.
    distances : np.ndarray
        A `n_samples * k_neighbors` matrix containing the distances to the
        neighbors at the indices given in `neighbors`.
    perplexity : double
        The desired perplexity of the probability distribution.
    symmetrize : bool
        Whether to symmetrize the probability matrix or not. Symmetrizing is
        used for typical t-SNE, but does not make sense when embedding new data
        into an existing embedding.
    n_reference_samples : int
        The number of samples in the existing (reference) embedding. Needed to
        properly construct the sparse P matrix. Defaults to the number of rows
        of `distances`.
    n_jobs : int
        Number of threads.

    Returns
    -------
    csr_matrix
        A `n_samples * n_reference_samples` matrix containing the probabilities
        that a new sample would appear as a neighbor of a reference point.

    """
    n_rows, n_per_row = distances.shape
    n_cols = n_rows if n_reference_samples is None else n_reference_samples

    # Asymmetric pairwise input similarities, one row per sample
    similarities = np.asarray(_tsne.compute_gaussian_perplexity(
        distances, perplexity, num_threads=n_jobs))

    # Every row holds exactly `n_per_row` entries, so the CSR row pointers are
    # consecutive multiples of `n_per_row`
    row_pointers = range(0, n_rows * n_per_row + 1, n_per_row)
    P = csr_matrix(
        (similarities.ravel(), neighbors.ravel(), row_pointers),
        shape=(n_rows, n_cols),
    )

    # Symmetrize the probability matrix
    if symmetrize:
        P = (P + P.T) / 2

    # Convert weights to probabilities using pair-wise normalization scheme
    P /= np.sum(P)

    return P

tsne/callbacks.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@ class VerifyExaggerationError:
3939
def __init__(self, embedding: TSNEEmbedding) -> None:
4040
self.embedding = embedding
4141
# Keep a copy of the unexaggerated affinity matrix
42-
self.P = self.embedding.P.copy()
42+
self.P = self.embedding.affinities.P.copy()
4343

4444
def __call__(self, iteration: int, corrected_error: float, embedding: TSNEEmbedding):
4545
params = self.embedding.gradient_descent_params
4646
method = params['negative_gradient_method']
4747

48-
if np.sum(embedding.P) <= 1:
48+
if np.sum(embedding.affinities.P) <= 1:
4949
log.warning('Are you sure you are testing an exaggerated P matrix?')
5050

5151
if method == 'fft':

tsne/metrics.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ def pBIC(embedding: TSNEEmbedding) -> float:
77
n_samples = embedding.shape[0]
88

99
return 2 * embedding.kl_divergence + np.log(n_samples) * \
10-
embedding.perplexity / n_samples
10+
embedding.affinities.perplexity / n_samples

0 commit comments

Comments
 (0)