Skip to content

Commit 25d48d8

Browse files
authored
Merge pull request #11 from beringresearch/remove-supervised
Remove supervised
2 parents 620cbbc + e8df57a commit 25d48d8

File tree

7 files changed

+1097
-1085
lines changed

7 files changed

+1097
-1085
lines changed

R-package/R/ivis.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#' IVIS algorithm
22
#'
33
#' @param X numerical matrix to be reduced. Columns correspond to features.
4-
#' @param y int, optional (default: NULL). Optional class vector triggering supervised triplet selection.
54
#' @param embedding_dims int, optional (default: 2) Number of dimensions in the embedding space
65
#' @param k int, optional (default: 150)
76
#' The number of neighbours to retrieve for each point
@@ -24,7 +23,7 @@
2423
#' Whether to pre-compute the nearest neighbours. Pre-computing is significantly faster, but requires more memory. If memory is limited, try setting this to False.
2524
#' @export
2625

27-
ivis <- function(X, y = NULL, embedding_dims = 2L,
26+
ivis <- function(X, embedding_dims = 2L,
2827
k = 150L,
2928
distance = "pn",
3029
batch_size = 128L,
@@ -52,7 +51,7 @@ ivis <- function(X, y = NULL, embedding_dims = 2L,
5251
epochs = epochs, n_epochs_without_progress = n_epochs_without_progress,
5352
margin = margin, ntrees = ntrees, search_k = search_k, precompute = precompute)
5453

55-
embeddings = model$fit_transform(X = X, y = y)
54+
embeddings = model$fit_transform(X = X)
5655
return(embeddings)
5756

5857
}

README.md

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ After cloning this repo run: `pip install -r requirements.txt --editable .` from
1010

1111
## Examples
1212

13-
Ivis can be run in both unsupervised and supervised mode. To run in supervised mode, simply provide an array of labels to the .fit() method.
14-
1513
### Unsupervised embeddings
1614

1715
```
@@ -20,7 +18,6 @@ from sklearn import datasets
2018
2119
iris = datasets.load_iris()
2220
X = iris.data
23-
y = iris.target
2421
2522
model = Ivis(embedding_dims=2, k=15)
2623
@@ -31,23 +28,6 @@ Plotting the embeddings results in the following visualization:
3128

3229
![](docs/ivis-iris-demo.png)
3330

34-
### Supervised embeddings
35-
36-
```
37-
from keras.datasets import mnist
38-
import numpy as np
39-
from ivis import Ivis
40-
41-
(x_train, y_train), (x_test, y_test) = mnist.load_data()
42-
x_test = np.reshape(x_test.astype('float32'), (len(x_test), 28 * 28))
43-
44-
45-
model = Ivis()
46-
embeddings = model.fit_transform(x_test, y_test)
47-
```
48-
49-
![](docs/ivis_mnist_supervised_embeddings.png)
50-
5131
### Training on a .h5 dataset
5232

5333
Load the data using a HDF5Matrix object provided by keras.

examples/.ipynb_checkpoints/iris_dimensionality_reduction-checkpoint.ipynb

Lines changed: 533 additions & 486 deletions
Large diffs are not rendered by default.

examples/iris_dimensionality_reduction.ipynb

Lines changed: 533 additions & 486 deletions
Large diffs are not rendered by default.

ivis/data/triplet_generators.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,6 @@ def create_triplet_generator_from_annoy_index(X, index, k, batch_size, search_k=
6161
return generate_knn_triplets_from_annoy_index(X, index, k=k, batch_size=batch_size, search_k=search_k)
6262

6363

64-
def create_triplet_generator_from_labels(X, y, batch_size):
65-
return generate_triplets_from_labels(X, np.array(y), batch_size=batch_size)
66-
67-
6864
def knn_triplet_from_neighbour_list(X, index, neighbour_list):
6965
""" A random (unweighted) positive example chosen. """
7066
N_ROWS = X.shape[0]
@@ -154,53 +150,6 @@ def generate_knn_triplets_from_annoy_index(X, annoy_index, k=150, batch_size=32,
154150
triplet_batch = np.array(triplet_batch)
155151
yield ([triplet_batch[:,0], triplet_batch[:,1], triplet_batch[:,2]], placeholder_labels)
156152

157-
@threadsafe_generator
158-
def generate_triplets_from_labels(X, Y, batch_size=32):
159-
N_ROWS = X.shape[0]
160-
iterations = 0
161-
row_indexes = np.array(list(range(N_ROWS)), dtype=np.uint32)
162-
np.random.shuffle(row_indexes)
163-
164-
placeholder_labels = np.array([0 for i in range(batch_size)])
165-
166-
while True:
167-
triplet_batch = []
168-
169-
for i in range(batch_size):
170-
if iterations >= N_ROWS:
171-
np.random.shuffle(row_indexes)
172-
iterations = 0
173-
174-
triplet = triplet_from_labels(X, Y, row_indexes[iterations])
175-
176-
triplet_batch += triplet
177-
iterations += 1
178-
179-
if (issparse(X)):
180-
triplet_batch = [[e.toarray()[0] for e in t] for t in triplet_batch]
181-
182-
triplet_batch = np.array(triplet_batch)
183-
yield ([triplet_batch[:,0], triplet_batch[:,1], triplet_batch[:,2]], placeholder_labels)
184-
185-
def triplet_from_labels(X, Y, index):
186-
""" A random (unweighted) positive example chosen. """
187-
N_ROWS = X.shape[0]
188-
triplets = []
189-
190-
row_label = Y[index]
191-
neighbour_indexes = np.where(Y == row_label)[0]
192-
193-
# Take a random neighbour as positive
194-
neighbour_ind = np.random.choice(neighbour_indexes)
195-
196-
# Take a random non-neighbour as negative
197-
negative_ind = np.random.randint(0, N_ROWS) # Pick a random index until one fits constraint. An optimization.
198-
while negative_ind in neighbour_indexes:
199-
negative_ind = np.random.randint(0, N_ROWS)
200-
201-
triplets += [[X[index], X[neighbour_ind], X[negative_ind]]]
202-
return triplets
203-
204153
def create_triplets_from_positive_index_dict(X, positive_index_dict):
205154
N_ROWS = X.shape[0]
206155
triplets = []

ivis/ivis.py

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
""" scikit-learn wrapper class for the Ivis algorithm. """
22

3-
from .data.triplet_generators import create_triplet_generator_from_annoy_index, create_triplet_generator_from_labels
3+
from .data.triplet_generators import create_triplet_generator_from_annoy_index
44
from .nn.network import build_network, selu_base_network
55
from .nn.losses import triplet_loss
66
from .data.knn import build_annoy_index
@@ -80,26 +80,18 @@ def __init__(self, embedding_dims=2, k=150, distance='pn', batch_size=128, epoch
8080
self.model_ = model
8181
self.annoy_index = annoy_index
8282

83-
def _fit(self, X, y, val_x, val_y, shuffle_mode=True):
84-
if y is None:
85-
self.annoy_index = self.annoy_index or build_annoy_index(X, ntrees=self.ntrees)
86-
datagen = create_triplet_generator_from_annoy_index(X, index=self.annoy_index, k=self.k, batch_size=self.batch_size, search_k=self.search_k, precompute=self.precompute)
87-
else:
88-
datagen = create_triplet_generator_from_labels(X, y, batch_size=self.batch_size)
83+
def _fit(self, X, shuffle_mode=True):
84+
85+
self.annoy_index = self.annoy_index or build_annoy_index(X, ntrees=self.ntrees)
86+
datagen = create_triplet_generator_from_annoy_index(X,
87+
index=self.annoy_index,
88+
k=self.k,
89+
batch_size=self.batch_size,
90+
search_k=self.search_k,
91+
precompute=self.precompute)
8992

90-
val_datagen = None
91-
validation_steps = None
9293
loss_monitor = 'loss'
93-
94-
if val_x is not None:
95-
if val_y is None:
96-
val_index = build_annoy_index(val_x, ntrees=self.ntrees)
97-
val_datagen = create_triplet_generator_from_annoy_index(val_x, index=val_index, k=self.k, batch_size=self.batch_size, search_k=self.search_k, precompute=self.precompute)
98-
else:
99-
val_datagen = create_triplet_generator_from_labels(X, y, batch_size=self.batch_size)
100-
101-
validation_steps = int(val_x.shape[0] / self.batch_size)
102-
loss_monitor = 'val_loss'
94+
10395
if self.model_:
10496
model = build_network(self.model_, embedding_dims=self.embedding_dims)
10597
else:
@@ -115,42 +107,40 @@ def _fit(self, X, y, val_x, val_y, shuffle_mode=True):
115107
hist = model.fit_generator(datagen,
116108
steps_per_epoch=int(X.shape[0] / self.batch_size),
117109
epochs=self.epochs,
118-
callbacks=[EarlyStopping(monitor=loss_monitor, patience=self.n_epochs_without_progress)],
119-
validation_data=val_datagen,
120-
validation_steps=validation_steps,
110+
callbacks=[EarlyStopping(monitor=loss_monitor, patience=self.n_epochs_without_progress)],
121111
shuffle=shuffle_mode,
122112
workers=multiprocessing.cpu_count() )
123113
self.loss_history_ = hist.history['loss']
124114
self.model_ = model.layers[3]
125115

126-
def fit(self, X, y=None, val_x=None, val_y=None, shuffle_mode=True):
127-
self._fit(X, y, val_x, val_y, shuffle_mode)
116+
def fit(self, X, shuffle_mode=True):
117+
self._fit(X, shuffle_mode)
128118
return self
129119

130-
def fit_transform(self, X, y=None, val_x=None, val_y=None, shuffle_mode=True):
131-
self.fit(X, y, val_x, val_y, shuffle_mode)
120+
def fit_transform(self, X, shuffle_mode=True):
121+
self.fit(X, shuffle_mode)
132122
return self.transform(X)
133123

134124
def transform(self, X):
135125
embedding = self.model_.predict(X)
136126
return embedding
137127

138-
def save(self, filepath):
128+
def save_model(self, filepath):
139129
self.model_.save(filepath)
140130

141-
def load(self, filepath):
131+
def load_model(self, filepath):
142132
model = load_model(filepath)
143133
self.model_ = model
144134
self.model_._make_predict_function()
145135
return self
146-
147-
def load_index(self, filepath):
148-
annoy_index = AnnoyIndex()
149-
annoy_index.load(filepath)
150-
self.annoy_index = annoy_index
151136

152137
def save_index(self, filepath):
153138
if self.annoy_index is not None:
154139
self.annoy_index.save(filepath)
155140
else:
156-
raise Exception('No annoy index to save.')
141+
raise Exception('No annoy index to save.')
142+
143+
def load_index(self, filepath):
144+
annoy_index = AnnoyIndex()
145+
annoy_index.load(filepath)
146+
self.annoy_index = annoy_index

requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
tensorflow==1.10.0
2-
keras==2.2.2
3-
numpy==1.14.3
4-
scikit-learn==0.20.0
5-
tqdm==4.19.4
1+
tensorflow
2+
keras
3+
numpy
4+
scikit-learn>0.20.0
5+
tqdm
66
git+https://github.com/beringresearch/annoy.git#egg=annoy

0 commit comments

Comments
 (0)