1+ from ivis .data .knn import build_annoy_index , extract_knn
2+
3+ from annoy import AnnoyIndex
4+ import numpy as np
5+ from scipy .sparse import csr_matrix
6+ from sklearn import datasets
7+ import tempfile
8+ import pytest
9+ import os
10+
11+
@pytest.fixture(scope='function')
def annoy_index_file():
    """Provide the path of a fresh temporary ``.index`` file for one test.

    The file is created before the test runs and removed again once the
    test has finished.

    Yields:
        str: Filesystem path of the temporary file.
    """
    fd, filepath = tempfile.mkstemp('.index')
    # mkstemp returns an already-open OS-level descriptor alongside the path.
    # Only the path is needed here, so close the descriptor immediately:
    # otherwise every test leaks one descriptor, and an open handle can
    # prevent the file from being reopened or deleted on some platforms.
    os.close(fd)
    yield filepath
    os.remove(filepath)
17+
def test_build_sparse_annoy_index(annoy_index_file):
    """Build an Annoy index from a sparse matrix and compare it with its on-disk copy.

    A random 10 x 5 matrix of zeros and ones is wrapped in a ``csr_matrix``
    and passed to ``build_annoy_index``. The index returned by the builder
    and the one reloaded from ``annoy_index_file`` must agree on
    dimensionality, item count and the neighbours of item 0.
    """
    n_rows, n_cols = 10, 5
    dense = np.random.choice([0, 1], size=(n_rows, n_cols))

    built = build_annoy_index(csr_matrix(dense), annoy_index_file)
    # The builder is expected to leave an index file at the requested path.
    assert os.path.exists(annoy_index_file)

    reloaded = AnnoyIndex(n_cols)
    reloaded.load(annoy_index_file)

    assert built.f == reloaded.f == n_cols
    assert built.get_n_items() == reloaded.get_n_items() == n_rows
    # Ask both indexes for the 5 nearest neighbours of item 0.
    assert built.get_nns_by_item(0, 5) == reloaded.get_nns_by_item(0, 5)
31+
32+
def test_dense_annoy_index(annoy_index_file):
    """Build an Annoy index from a dense array and compare it with its on-disk copy.

    A random 10 x 5 array of zeros and ones is passed directly to
    ``build_annoy_index``. The index returned by the builder and the one
    reloaded from ``annoy_index_file`` must agree on dimensionality, item
    count and the neighbours of item 0.
    """
    n_rows, n_cols = 10, 5
    dense = np.random.choice([0, 1], size=(n_rows, n_cols))

    built = build_annoy_index(dense, annoy_index_file)
    # The builder is expected to leave an index file at the requested path.
    assert os.path.exists(annoy_index_file)

    reloaded = AnnoyIndex(n_cols)
    reloaded.load(annoy_index_file)

    assert built.f == reloaded.f == n_cols
    assert built.get_n_items() == reloaded.get_n_items() == n_rows
    # Ask both indexes for the 5 nearest neighbours of item 0.
    assert built.get_nns_by_item(0, 5) == reloaded.get_nns_by_item(0, 5)
44+
def test_knn_retrieval():
    """Check ``extract_knn`` on the iris data against a stored k=4 neighbour array.

    NOTE: both file paths are relative, so they resolve against the current
    working directory -- presumably the repository root; confirm against
    how the test suite is invoked.
    """
    index_path = 'tests/data/.test-annoy-index.index'
    reference = np.load('tests/data/test_knn_k4.npy')

    features = datasets.load_iris().data

    found = extract_knn(features, index_path, k=4, search_k=-1)

    # Every entry must match the stored reference.
    assert np.all(reference == found)
0 commit comments