Skip to content

Commit

Permalink
Update tests to use data from fixtures instead of relying on sklearn
Browse files Browse the repository at this point in the history
  • Loading branch information
Szubie committed Sep 28, 2024
1 parent 2e6871a commit efd5a2f
Show file tree
Hide file tree
Showing 10 changed files with 235 additions and 161 deletions.
170 changes: 170 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest
from tensorflow.keras import backend as K

Expand All @@ -11,3 +12,172 @@ def clear_session_after_test():
yield
if K.backend() == 'tensorflow' or K.backend() == 'cntk':
K.clear_session()

@pytest.fixture(scope="function")
def X():
data = np.array(
[[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.6, 1.4, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]]
)
return data

@pytest.fixture(scope="function")
def Y():
target = np.array(
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
)
return target
10 changes: 2 additions & 8 deletions tests/data/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,10 @@ def test_dense_annoy_index(annoy_index_file):
loaded_index.unload()


def test_knn_retrieval():
def test_knn_retrieval(X):
annoy_index_filepath = 'tests/data/.test-annoy-index.index'
expected_neighbour_list = np.load('tests/data/test_knn_k3.npy')

iris = datasets.load_iris()
X = iris.data

k = 3
search_k = -1

Expand All @@ -79,13 +76,10 @@ def test_knn_matrix_construction_params(annoy_index_file):
for original_row, loaded_row in zip(index, loaded_index):
assert original_row == loaded_row

def test_knn_retrieval_non_verbose():
def test_knn_retrieval_non_verbose(X):
annoy_index_filepath = 'tests/data/.test-annoy-index.index'
expected_neighbour_list = np.load('tests/data/test_knn_k3.npy')

iris = datasets.load_iris()
X = iris.data

k = 3
search_k = -1

Expand Down
4 changes: 1 addition & 3 deletions tests/data/test_triplet_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
from ivis.data.generators import UnsupervisedTripletGenerator


def test_UnsupervisedTripletGenerator():
def test_UnsupervisedTripletGenerator(X):
neighbour_list = np.load('tests/data/test_knn_k3.npy')

iris = datasets.load_iris()
X = iris.data
batch_size = 32

data_generator = UnsupervisedTripletGenerator(X, neighbour_list,
Expand Down
16 changes: 4 additions & 12 deletions tests/integration/test_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,16 @@
from sklearn import datasets


def test_iris_embedding():
iris = datasets.load_iris()
x = iris.data
y = iris.target

def test_iris_embedding(X):
ivis_iris = Ivis(epochs=5)
ivis_iris.k = 15
ivis_iris.batch_size = 16

y_pred_iris = ivis_iris.fit_transform(x)

def test_1d_iris_embedding():
iris = datasets.load_iris()
x = iris.data
y = iris.target
y_pred_iris = ivis_iris.fit_transform(X)

def test_1d_iris_embedding(X):
ivis_iris = Ivis(epochs=5, embedding_dims=1)
ivis_iris.k = 15
ivis_iris.batch_size = 16

y_pred_iris = ivis_iris.fit_transform(x)
y_pred_iris = ivis_iris.fit_transform(X)
11 changes: 4 additions & 7 deletions tests/integration/test_neighbour_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@
from ivis import Ivis


def test_custom_ndarray_neighbour_matrix():
iris = datasets.load_iris()
x = iris.data
y = iris.target
def test_custom_ndarray_neighbour_matrix(X, Y):

class_indicies = {label: np.argwhere(y == label).ravel() for label in np.unique(y)}
neighbour_matrix = np.array([class_indicies[label] for label in y])
class_indicies = {label: np.argwhere(Y == label).ravel() for label in np.unique(Y)}
neighbour_matrix = np.array([class_indicies[label] for label in Y])

ivis_iris = Ivis(epochs=5, neighbour_matrix=neighbour_matrix)
ivis_iris.k = 15
ivis_iris.batch_size = 16

y_pred_iris = ivis_iris.fit_transform(x)
y_pred_iris = ivis_iris.fit_transform(X)
50 changes: 18 additions & 32 deletions tests/integration/test_semi-supervised_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,69 +4,55 @@
from ivis import Ivis


def test_iris_embedding():
iris = datasets.load_iris()
x = iris.data
y = iris.target
mask = np.random.choice(range(len(y)), size=len(y) // 2, replace=False)
y[mask] = -1
def test_iris_embedding(X, Y):
mask = np.random.choice(range(len(Y)), size=len(Y) // 2, replace=False)
Y[mask] = -1

ivis_iris = Ivis(epochs=5)
ivis_iris.k = 15
ivis_iris.batch_size = 16

y_pred_iris = ivis_iris.fit_transform(x, y)
y_pred_iris = ivis_iris.fit_transform(X, Y)

def test_correctly_indexed_semi_supervised_classificaton_classes():
iris = datasets.load_iris()
x = iris.data
y = iris.target
def test_correctly_indexed_semi_supervised_classificaton_classes(X, Y):

# Mark points as unlabeled
mask = np.random.choice(range(len(y)), size=len(y) // 2, replace=False)
y[mask] = -1
mask = np.random.choice(range(len(Y)), size=len(Y) // 2, replace=False)
Y[mask] = -1

supervision_metric = 'sparse_categorical_crossentropy'
ivis_iris = Ivis(k=15, batch_size=16, epochs=5,
supervision_metric=supervision_metric)

embeddings = ivis_iris.fit_transform(x, y)

def test_non_zero_indexed_semi_supervised_classificaton_classes():
iris = datasets.load_iris()
x = iris.data
y = iris.target
embeddings = ivis_iris.fit_transform(X, Y)

def test_non_zero_indexed_semi_supervised_classificaton_classes(X, Y):
# Make labels non-zero indexed
y = y + 1
Y = Y + 1

# Mark points as unlabeled
mask = np.random.choice(range(len(y)), size=len(y) // 2, replace=False)
y[mask] = -1
mask = np.random.choice(range(len(Y)), size=len(Y) // 2, replace=False)
Y[mask] = -1

supervision_metric = 'sparse_categorical_crossentropy'
ivis_iris = Ivis(k=15, batch_size=16, epochs=5,
supervision_metric=supervision_metric)

with pytest.raises(ValueError):
embeddings = ivis_iris.fit_transform(x, y)

embeddings = ivis_iris.fit_transform(X, Y)

def test_non_consecutive_indexed_semi_supervised_classificaton_classes():
iris = datasets.load_iris()
x = iris.data
y = iris.target

def test_non_consecutive_indexed_semi_supervised_classificaton_classes(X, Y):
# Make labels non-consecutive indexed
y[y == max(y)] = max(y) + 1
Y[Y == max(Y)] = max(Y) + 1

# Mark points as unlabeled
mask = np.random.choice(range(len(y)), size=len(y) // 2, replace=False)
y[mask] = -1
mask = np.random.choice(range(len(Y)), size=len(Y) // 2, replace=False)
Y[mask] = -1

supervision_metric = 'sparse_categorical_crossentropy'
ivis_iris = Ivis(k=15, batch_size=16, epochs=5,
supervision_metric=supervision_metric)

with pytest.raises(ValueError):
embeddings = ivis_iris.fit_transform(x, y)
embeddings = ivis_iris.fit_transform(X, Y)
Loading

0 comments on commit efd5a2f

Please sign in to comment.