File tree 2 files changed +12
-5
lines changed
2 files changed +12
-5
lines changed Original file line number Diff line number Diff line change 6
6
from .logger import info
7
7
8
8
9
- def get_embeddings (use_gpu = False , images = None ):
9
+ def get_embeddings (use_gpu = False , images = None , model = 'resnet-18' ):
10
10
"""
11
11
This Python function initializes an Img2Vec object, runs it on either GPU or CPU, and retrieves
12
12
image embeddings.
13
+ :param use_gpu: The `use_gpu` parameter is a boolean that specifies whether to use GPU or CPU.
14
+ :param images: The `images` parameter is a list of image paths to be used for generating embeddings.
15
+ :param model: The `model` parameter is a string that specifies the model to use for generating.
16
+ For available models, see https://github.com/christiansafka/img2vec
17
+ :return: The function `get_embeddings` returns the embeddings of the images as np.ndarray.
13
18
"""
14
19
15
20
info (f"Img2Vec is running on { 'GPU' if use_gpu else 'CPU' } ..." )
16
- img2vec = Img2Vec (cuda = use_gpu )
17
-
21
+ img2vec = Img2Vec (cuda = use_gpu , model = model )
22
+ print ( f"Using model: { model } " )
18
23
embeddings = img2vec .get_vec (images , tensor = False )
19
24
return embeddings
20
25
Original file line number Diff line number Diff line change @@ -44,21 +44,23 @@ def read(self, folder_path):
44
44
self .image_paths = read_images_from_directory (folder_path )
45
45
self .images = read_with_pil (self .image_paths )
46
46
47
- def calculate (self , pca = True , iter = 10 ):
47
+ def calculate (self , pca = True , iter = 10 , model = "resnet-18" ):
48
48
"""
49
49
The function calculates embeddings, performs PCA, and applies K-means clustering to the
50
50
embeddings. It will not perform these operations if no images have been read.
51
51
52
52
:param pca: The `pca` parameter is a boolean that specifies whether to perform PCA or not. Default is True
53
53
:param iter: The `iter` parameter is an integer that specifies the number of iterations for the KMeans algorithm. Default is 10.
54
+ :param model: The `model` parameter is a string that specifies the model to use for generating embeddings. Default is 'resnet-18'.
55
+ For available models, see https://github.com/christiansafka/img2vec
54
56
"""
55
57
56
58
if not self .images :
57
59
raise ValueError (
58
60
"The images list can not be empty. Please call the read method before calculating."
59
61
)
60
62
61
- self .embeddings = get_embeddings (use_gpu = self .use_gpu , images = self .images )
63
+ self .embeddings = get_embeddings (use_gpu = self .use_gpu , images = self .images , model = model )
62
64
if pca :
63
65
self .pca_embeddings = calculate_pca (self .embeddings , self .pca_dim )
64
66
self .centroid , self .labels , self .counts = calculate_kmeans (
You can’t perform that action at this time.
0 commit comments