
Commit 8478283 (parent: e60a504)

5 files changed: +118 -29 lines

README.md (+29 -2)

````diff
@@ -1,4 +1,31 @@
-# Keras Metric Learning
+# Keras Metric Learning Library
 Deep Metric Learning Library for Keras
 
-## Under Construction...
+## Welcome
+The Keras Metric Learning Library lets Keras users train models with the
+metric-learning losses published in the research literature. See this post
+for more info.
+
+## Getting Started
+Go ahead and clone the repo:
+```
+git clone http://github.com/jricheimer/keras-metric-learning
+```
+If you'd like to experiment with the Stanford Online Products dataset (a
+nicely sized dataset made for testing metric-learning approaches), take a
+minute to download it here.
+
+Then use our script to generate the HDF5 files for the dataset:
+```
+cd keras-metric-learning
+mkdir dataset && cd dataset
+python kml_create_stanford_hdf5.py --root_path /path/to/stanford/dataset
+```
+
+Once that finishes processing, you should have two HDF5 files (one for
+train, one for test) in the dataset directory.
+
+## Start training
+
+Take a look at the example notebooks to see how to use the library.
+Enjoy!
````
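For orientation, here is a minimal sketch of loading the generated train file and drawing batches with the data utilities touched by this commit. The file name `stanford_train.h5` and the one-HDF5-dataset-per-class layout are assumptions about the script's output, not guaranteed by this commit.

```python
# Hypothetical quick check of the generated data; the file name and the
# per-class dataset layout are assumed, not confirmed by this commit.
import h5py
from kml_data_utils import random_sample_generator

train_data = h5py.File('dataset/stanford_train.h5', 'r')
gen = random_sample_generator(train_data, batch_size=32, classes_per_batch=8)
images, labels = next(gen)  # one batch of images with integer class labels
```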

kml_callbacks.py (+19 -19)

```diff
@@ -5,7 +5,7 @@
 from keras.callbacks import Callback
 import numpy as np
 from types import GeneratorType
-from utils import recall_at_k, nmi
+from kml_utils import recall_at_k, nmi
 
 class RecallAtK(Callback):
     """Callback that computes the Recall@k metric for a given validation set at the end of each epoch.
@@ -25,11 +25,12 @@ def __init__(self, validation_data, validation_steps=1, k=1, metric='euclidean',
         self.model_name = model_name
         self.k = k
         self.metric = metric
-        self.validation_data = validation_data
-        self.validation_steps = validation_steps
+        # self.validation_data = validation_data
+        # self.validation_steps = validation_steps
         self.verbose = verbose
 
     def on_epoch_end(self, epoch, logs=None):
+
         logs = logs or {}
         if 'recall_at_{}'.format(self.k) not in logs:
             logs['recall_at_{}'.format(self.k)] = []
@@ -38,24 +39,23 @@ def on_epoch_end(self, epoch, logs=None):
             self.model = self.model.get_layer(self.model_name)
         else:
             sub_models = [l for l in self.model.layers if isinstance(l, Model)]
-            if len(sub_models) != 1:
-                raise ValueError('Training network must contain exactly one sub-model')
-            self.model = sub_models[0]
-        if isinstance(self.validation_data, GeneratorType):
-            val_embeddings = []
-            labels = []
-            for i in range(self.validation_steps):
-                data, targets = self.validation_data.next()
-                val_embeddings.append(self.model.predict(data))
-                labels.extend(targets)
-            val_embeddings = np.concatenate(val_embeddings, axis=0)
+            if len(sub_models) == 1:
+                self.model = sub_models[0]
+        # if isinstance(self.validation_data, GeneratorType):
+        #     val_embeddings = []
+        #     labels = []
+        #     for i in range(self.validation_steps):
+        #         data, targets = self.validation_data.next()
+        #         val_embeddings.append(self.model.predict(data))
+        #         labels.extend(targets)
+        #     val_embeddings = np.concatenate(val_embeddings, axis=0)
 
-        elif isinstance(self.validation_data, tuple) and len(self.validation_data) == 2:
-            val_embeddings = self.model.predict(self.validation_data[0])
-            labels = self.validation_data[1]
+        # elif isinstance(self.validation_data, tuple) and len(self.validation_data) == 2:
+        val_embeddings = self.model.predict(self.validation_data[0])
+        labels = self.validation_data[1]
 
-        else:
-            raise ValueError('validation_data must be either a generator object or a tuple (X,Y)')
+        # else:
+        #     raise ValueError('validation_data must be either a generator object or a tuple (X,Y)')
 
         recall = recall_at_k(val_embeddings, labels, k=self.k, metric=self.metric)
         logs['recall_at_{}'.format(self.k)].append(recall)
```
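After this change the callback no longer keeps its own copy of the validation data; it reads `self.validation_data`, which Keras populates on callbacks when `validation_data` is passed to `fit`. A sketch of the wiring under that assumption, where `siamese_model`, the training arrays, and the sub-model name `'embedding_model'` are placeholders, not part of this commit:

```python
# Sketch only: `siamese_model`, the arrays, and the sub-model name
# 'embedding_model' are placeholders, not defined in this commit.
from kml_callbacks import RecallAtK

recall_cb = RecallAtK(validation_data=None, k=1, metric='euclidean',
                      model_name='embedding_model')
siamese_model.fit([x1_train, x2_train], y_train,
                  validation_data=(x_val, y_val),  # Keras hands this to the callback
                  callbacks=[recall_cb])
```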

kml_data_utils.py (+53 -3)

```diff
@@ -60,7 +60,8 @@ def triplet_generator(data, batch_size=32,
     else:
         yield x_list
 
-def pair_generator(data, batch_size):
+def pair_generator(data, batch_size, all_similar=False,
+                   all_dissimilar=False):
     """Generates pair samples randomly for training Siamese network
 
     # Arguments
@@ -71,13 +72,20 @@ def pair_generator(data, batch_size):
     Yields batches of pairs of the form ([batch_1, batch_2], pairwise_labels)
     """
     class_ids = data.keys()
+    if all_dissimilar and all_similar:
+        raise ValueError('all_similar and all_dissimilar cannot both be True')
 
     while True:
 
         batch_list_1 = []
         batch_list_2 = []
-
-        labels = np.random.randint(2, size=(batch_size,))
+
+        if all_similar:
+            labels = np.ones(shape=(batch_size,))
+        elif all_dissimilar:
+            labels = np.zeros(shape=(batch_size,))
+        else:
+            labels = np.random.randint(2, size=(batch_size,))
         for batch_ind in range(batch_size):
 
             if labels[batch_ind] == 1:
@@ -130,3 +138,45 @@ def structured_batch_generator(data, num_classes_per_batch, num_samples_per_class
             yield (np.stack(batch_list), None)
         else:
             yield np.stack(batch_list)
+
+def random_sample_generator(data, batch_size=32, label_map=None, classes_per_batch=None, class_to_batch_ratio=None):
+    """
+    # Arguments
+        data: dict containing a numpy array for each class, or h5py Group containing an h5py Dataset for each class.
+        batch_size: number of samples per yielded batch.
+        label_map: A function that maps the class names (keys of the dataset dict or hdf5 datasets)
+            to a class index 0 - (num_classes-1).
+        classes_per_batch: restricts the sampling to a fixed number of classes in each batch.
+        class_to_batch_ratio: Alternative to `classes_per_batch`. If both are specified, `classes_per_batch` is used.
+
+    # Returns
+        Yields a batch of random samples from the dataset with corresponding integer class labels.
+    """
+    class_ids = data.keys()
+
+    if not label_map:
+        # If the class ids are ints 0..(num_classes-1), use them as labels directly
+        if all([type(i) is int for i in class_ids]) and (max(class_ids) == len(class_ids)-1):
+            label_map = lambda i: i
+        # Otherwise, assign each class its index in the class_ids list
+        else:
+            label_map = lambda i: class_ids.index(i)
+
+    while True:
+        batch_list = []
+        label_list = []
+        if class_to_batch_ratio and not classes_per_batch:
+            classes_per_batch = int(class_to_batch_ratio * batch_size)
+        if not classes_per_batch:
+            batch_class_ids = [rand.choice(class_ids) for _ in range(batch_size)]
+        else:
+            batch_class_ids = rand.sample(class_ids, classes_per_batch)
+
+        for _ in range(batch_size):
+            class_id = rand.choice(batch_class_ids)
+            sample_ind = np.random.randint(data[class_id].shape[0])
+            batch_list.append(data[class_id][sample_ind, ...])
+            # This works for the Stanford products dataset
+            label_list.append(label_map(class_id))
+
+        yield (np.stack(batch_list), np.stack(label_list))
```
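A toy illustration of the new `all_similar`/`all_dissimilar` flags, using an in-memory two-class dict of the kind the generators document; the array shapes here are arbitrary:

```python
# Toy data: two classes of 100 64-dim vectors each (shapes are arbitrary).
import numpy as np
from kml_data_utils import pair_generator

data = {0: np.random.rand(100, 64), 1: np.random.rand(100, 64)}
pos_gen = pair_generator(data, batch_size=16, all_similar=True)
(x1, x2), labels = next(pos_gen)  # each pair comes from one class; labels all 1
```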

kml_layers.py (+9 -3)

```diff
@@ -80,15 +80,21 @@ class PairDistances(Layer):
 
     """
 
-    def __init__(self, epsilon=1e-6, **kwargs):
+    def __init__(self, metric='l2', epsilon=1e-6, **kwargs):
+        self.metric = metric
         self.epsilon = epsilon
         super(PairDistances, self).__init__(**kwargs)
 
     def build(self, input_shape):
         super(PairDistances, self).build(input_shape)
 
     def call(self, x):
-        dists = K.sqrt(K.relu(K.sum(K.square(x[0]-x[1]), axis=1))+self.epsilon)
+        if self.metric == 'l2':
+            dists = K.sqrt(K.relu(K.sum(K.square(x[0]-x[1]), axis=1))+self.epsilon)
+        elif self.metric == 'l1':
+            dists = K.sum(K.abs(x[0]-x[1]), axis=1)
+        else:
+            raise ValueError("metric must be 'l1' or 'l2'")
         return K.expand_dims(dists, axis=-1)
 
 
@@ -262,7 +268,7 @@ def call(self, x):
         F = K.tf.boolean_mask(F, K.tf.logical_not(K.cast(K.eye(2*self.p), K.tf.bool)))
         F = K.reshape(F, [2*self.p, 2*self.p-1])
 
-        return K.mean(K.categorical_crossentropy(target=self.labels, output=F, from_logits=True))
+        return K.mean(K.categorical_crossentropy(target=self.labels, output=F, from_logits=True)) \
             + self.reg_coeff * K.mean(embedding_norms)
 
     def compute_output_shape(self, input_shape):
```
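A sketch of where `PairDistances` sits in a Siamese network, with the new `metric` argument; `base_network` (a shared embedding `Model`) and the input shape are placeholders:

```python
# Sketch: `base_network` and the input shape are placeholders,
# not defined in this commit.
from keras.layers import Input
from keras.models import Model
from kml_layers import PairDistances

input_a = Input(shape=(224, 224, 3))
input_b = Input(shape=(224, 224, 3))
emb_a = base_network(input_a)   # same weights applied to both branches
emb_b = base_network(input_b)
dists = PairDistances(metric='l1')([emb_a, emb_b])  # output shape (batch, 1)
siamese = Model([input_a, input_b], dists)
```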

kml_utils.py (+9 -2)

```diff
@@ -5,6 +5,7 @@
 from scipy.spatial.distance import pdist, squareform
 from sklearn.cluster import KMeans
 from sklearn.metrics import normalized_mutual_info_score
+from kml_data_utils import pair_generator, organize_by_class
 
 def recall_at_k(embeddings, labels, k=1, metric='euclidean'):
     """Computes the Recall@K metric
@@ -39,5 +40,11 @@ def nmi(embeddings, labels, metric='euclidean'):
     kmeans.fit(embeddings)
     return normalized_mutual_info_score(labels, kmeans.labels_)
 
-def plot_distance_distributions(test_data, num_pairs=10000, num_bins=100):
-    pass
+def plot_distance_distributions(embeddings, labels, num_pairs=10000, num_bins=100):
+    negative_distances = []
+    positive_distances = []
+    pos_gen = pair_generator(organize_by_class(embeddings, labels), batch_size=32, all_similar=True)
+    while len(positive_distances) < num_pairs // 2:
+        pairs = pos_gen.next()
+
```
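`plot_distance_distributions` is still unfinished at this commit. A sketch of one way it might be completed, assuming `pair_generator` yields `([batch_1, batch_2], labels)` as documented and that matplotlib is acceptable as a dependency; the `_collect_distances` helper and the `_sketch` function are hypothetical, not part of the library:

```python
# Hypothetical completion sketch; _collect_distances and
# plot_distance_distributions_sketch are not part of this commit.
import numpy as np
import matplotlib.pyplot as plt
from kml_data_utils import pair_generator, organize_by_class

def _collect_distances(gen, n):
    dists = []
    while len(dists) < n:
        (x1, x2), _ = next(gen)  # pairs arrive as ([batch_1, batch_2], labels)
        dists.extend(np.sqrt(np.sum((x1 - x2) ** 2, axis=1)))
    return dists[:n]

def plot_distance_distributions_sketch(embeddings, labels, num_pairs=10000, num_bins=100):
    data = organize_by_class(embeddings, labels)
    pos = _collect_distances(pair_generator(data, batch_size=32, all_similar=True), num_pairs // 2)
    neg = _collect_distances(pair_generator(data, batch_size=32, all_dissimilar=True), num_pairs // 2)
    plt.hist([pos, neg], bins=num_bins, label=['positive pairs', 'negative pairs'])
    plt.legend()
    plt.show()
```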
