Skip to content

Commit 5325a28

Browse files
model option added for img2vec
1 parent fbd37f5 commit 5325a28

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

tasnif/calculations.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,20 @@
66
from .logger import info
77

88

9-
def get_embeddings(use_gpu=False, images=None):
9+
def get_embeddings(use_gpu=False, images=None, model='resnet-18'):
1010
"""
1111
This Python function initializes an Img2Vec object, runs it on either GPU or CPU, and retrieves
1212
image embeddings.
13+
:param use_gpu: The `use_gpu` parameter is a boolean that specifies whether to use GPU or CPU.
14+
:param images: The `images` parameter is a list of image paths to be used for generating embeddings.
15+
:param model: The `model` parameter is a string that specifies the model to use for generating.
16+
For available models, see https://github.com/christiansafka/img2vec
17+
:return: The function `get_embeddings` returns the embeddings of the images as np.ndarray.
1318
"""
1419

1520
info(f"Img2Vec is running on {'GPU' if use_gpu else 'CPU'}...")
16-
img2vec = Img2Vec(cuda=use_gpu)
17-
21+
img2vec = Img2Vec(cuda=use_gpu, model=model)
22+
print(f"Using model: {model}")
1823
embeddings = img2vec.get_vec(images, tensor=False)
1924
return embeddings
2025

tasnif/tasnif.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,23 @@ def read(self, folder_path):
4444
self.image_paths = read_images_from_directory(folder_path)
4545
self.images = read_with_pil(self.image_paths)
4646

47-
def calculate(self, pca=True, iter=10):
47+
def calculate(self, pca=True, iter=10, model="resnet-18"):
4848
"""
4949
The function calculates embeddings, performs PCA, and applies K-means clustering to the
5050
embeddings. It will not perform these operations if no images have been read.
5151
5252
:param pca: The `pca` parameter is a boolean that specifies whether to perform PCA or not. Default is True
5353
:param iter: The `iter` parameter is an integer that specifies the number of iterations for the KMeans algorithm. Default is 10.
54+
:param model: The `model` parameter is a string that specifies the model to use for generating embeddings. Default is 'resnet-18'.
55+
For available models, see https://github.com/christiansafka/img2vec
5456
"""
5557

5658
if not self.images:
5759
raise ValueError(
5860
"The images list can not be empty. Please call the read method before calculating."
5961
)
6062

61-
self.embeddings = get_embeddings(use_gpu=self.use_gpu, images=self.images)
63+
self.embeddings = get_embeddings(use_gpu=self.use_gpu, images=self.images, model=model)
6264
if pca:
6365
self.pca_embeddings = calculate_pca(self.embeddings, self.pca_dim)
6466
self.centroid, self.labels, self.counts = calculate_kmeans(

0 commit comments

Comments
 (0)