Skip to content

Commit 2edab2f

Browse files
committed
reformatted
1 parent 33dbd1b commit 2edab2f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+1738
-829
lines changed

examples/Data_Custom_Example.ipynb

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,9 @@
240240
" {\n",
241241
" \"person\": {\n",
242242
" \"labels\": data[\"categories\"][0][\"keypoints\"],\n",
243-
" \"edges\": (np.array(data[\"categories\"][0][\"skeleton\"]) - 1).tolist(),\n",
243+
" \"edges\": (\n",
244+
" np.array(data[\"categories\"][0][\"skeleton\"]) - 1\n",
245+
" ).tolist(),\n",
244246
" }\n",
245247
" },\n",
246248
" task=\"keypoints\",\n",
@@ -307,7 +309,9 @@
307309
" for kp in kps:\n",
308310
" kp = kp[1:].reshape(-1, 3)\n",
309311
" for k in kp:\n",
310-
" cv2.circle(image, (int(k[0] * w), int(k[1] * h)), 2, (0, 255, 0), 2)\n",
312+
" cv2.circle(\n",
313+
" image, (int(k[0] * w), int(k[1] * h)), 2, (0, 255, 0), 2\n",
314+
" )\n",
311315
"\n",
312316
" plt.imshow(image)\n",
313317
" plt.axis(\"off\") # Optional: Hide axis\n",

examples/Data_Parser_Example.ipynb

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@
8282
"outputs": [],
8383
"source": [
8484
"dataset_name = \"coco_test\"\n",
85-
"parser = LuxonisParser(dataset_dir, dataset_name=dataset_name, delete_existing=True)\n",
85+
"parser = LuxonisParser(\n",
86+
" dataset_dir, dataset_name=dataset_name, delete_existing=True\n",
87+
")\n",
8688
"dataset = parser.parse(random_split=True)"
8789
]
8890
},
@@ -125,7 +127,9 @@
125127
" for kp in kps:\n",
126128
" kp = kp[1:].reshape(-1, 3)\n",
127129
" for k in kp:\n",
128-
" cv2.circle(image, (int(k[0] * w), int(k[1] * h)), 2, (0, 255, 0), 2)\n",
130+
" cv2.circle(\n",
131+
" image, (int(k[0] * w), int(k[1] * h)), 2, (0, 255, 0), 2\n",
132+
" )\n",
129133
"\n",
130134
" plt.imshow(image)\n",
131135
" plt.axis(\"off\") # Optional: Hide axis\n",

examples/Embeddings_LDF_Qdrant_Example.ipynb

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@
6060
"outputs": [],
6161
"source": [
6262
"# Load the data\n",
63-
"data_loader = load_mnist_data(save_path=\"./data/mnist\", num_samples=640, batch_size=64)"
63+
"data_loader = load_mnist_data(\n",
64+
" save_path=\"./data/mnist\", num_samples=640, batch_size=64\n",
65+
")"
6466
]
6567
},
6668
{
@@ -108,7 +110,9 @@
108110
"outputs": [],
109111
"source": [
110112
"# Load the data\n",
111-
"data_loader = load_mnist_data(save_path=\"./data/mnist\", num_samples=640, batch_size=64)"
113+
"data_loader = load_mnist_data(\n",
114+
" save_path=\"./data/mnist\", num_samples=640, batch_size=64\n",
115+
")"
112116
]
113117
},
114118
{
@@ -159,7 +163,9 @@
159163
" and \"CUDAExecutionProvider\" in onnxruntime.get_available_providers()\n",
160164
" else None\n",
161165
")\n",
162-
"ort_session = onnxruntime.InferenceSession(\"./data/resnet50-1.onnx\", providers=provider)\n",
166+
"ort_session = onnxruntime.InferenceSession(\n",
167+
" \"./data/resnet50-1.onnx\", providers=provider\n",
168+
")\n",
163169
"\n",
164170
"# Extract embeddings from the dataset\n",
165171
"embeddings, labels = extract_embeddings_onnx(\n",
@@ -420,7 +426,9 @@
420426
" and \"CUDAExecutionProvider\" in onnxruntime.get_available_providers()\n",
421427
" else None\n",
422428
")\n",
423-
"ort_session = onnxruntime.InferenceSession(\"./data/resnet50-1.onnx\", providers=provider)"
429+
"ort_session = onnxruntime.InferenceSession(\n",
430+
" \"./data/resnet50-1.onnx\", providers=provider\n",
431+
")"
424432
]
425433
},
426434
{
@@ -446,7 +454,9 @@
446454
"\n",
447455
"# Create a collection\n",
448456
"qdrant_api.create_collection(\n",
449-
" collection_name=\"Mnist_LDF\", properties=[\"label\", \"image_path\"], vector_size=2048\n",
457+
" collection_name=\"Mnist_LDF\",\n",
458+
" properties=[\"label\", \"image_path\"],\n",
459+
" vector_size=2048,\n",
450460
")"
451461
]
452462
},

examples/Embeddings_LDF_Weaviate_Example.ipynb

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@
7171
"outputs": [],
7272
"source": [
7373
"# Load the data\n",
74-
"data_loader = load_mnist_data(save_path=\"./data/mnist\", num_samples=640, batch_size=64)"
74+
"data_loader = load_mnist_data(\n",
75+
" save_path=\"./data/mnist\", num_samples=640, batch_size=64\n",
76+
")"
7577
]
7678
},
7779
{
@@ -119,7 +121,9 @@
119121
"outputs": [],
120122
"source": [
121123
"# Load the data\n",
122-
"data_loader = load_mnist_data(save_path=\"./data/mnist\", num_samples=640, batch_size=64)"
124+
"data_loader = load_mnist_data(\n",
125+
" save_path=\"./data/mnist\", num_samples=640, batch_size=64\n",
126+
")"
123127
]
124128
},
125129
{
@@ -170,7 +174,9 @@
170174
" and \"CUDAExecutionProvider\" in onnxruntime.get_available_providers()\n",
171175
" else None\n",
172176
")\n",
173-
"ort_session = onnxruntime.InferenceSession(\"./data/resnet50-1.onnx\", providers=provider)\n",
177+
"ort_session = onnxruntime.InferenceSession(\n",
178+
" \"./data/resnet50-1.onnx\", providers=provider\n",
179+
")\n",
174180
"\n",
175181
"# Extract embeddings from the dataset\n",
176182
"embeddings, labels = extract_embeddings_onnx(\n",
@@ -232,7 +238,9 @@
232238
"# Insert the embeddings into the collection\n",
233239
"uuids = [str(uuid.uuid5(uuid.NAMESPACE_DNS, str(e))) for e in embeddings]\n",
234240
"label_list_dict = [{\"label\": label} for label in labels]\n",
235-
"weaviate_api.insert_embeddings(uuids, embeddings, label_list_dict, batch_size=50)"
241+
"weaviate_api.insert_embeddings(\n",
242+
" uuids, embeddings, label_list_dict, batch_size=50\n",
243+
")"
236244
]
237245
},
238246
{
@@ -254,8 +262,12 @@
254262
],
255263
"source": [
256264
"# Search for the nearest neighbors\n",
257-
"search_uuids, scores = weaviate_api.search_similar_embeddings(embeddings[0], top_k=5)\n",
258-
"payloads = weaviate_api.retrieve_payloads_by_ids(search_uuids, properties=[\"label\"])\n",
265+
"search_uuids, scores = weaviate_api.search_similar_embeddings(\n",
266+
" embeddings[0], top_k=5\n",
267+
")\n",
268+
"payloads = weaviate_api.retrieve_payloads_by_ids(\n",
269+
" search_uuids, properties=[\"label\"]\n",
270+
")\n",
259271
"\n",
260272
"# Print the search results\n",
261273
"for u, p in zip(search_uuids, payloads):\n",
@@ -414,7 +426,9 @@
414426
" and \"CUDAExecutionProvider\" in onnxruntime.get_available_providers()\n",
415427
" else None\n",
416428
")\n",
417-
"ort_session = onnxruntime.InferenceSession(\"./data/resnet50-1.onnx\", providers=provider)"
429+
"ort_session = onnxruntime.InferenceSession(\n",
430+
" \"./data/resnet50-1.onnx\", providers=provider\n",
431+
")"
418432
]
419433
},
420434
{
@@ -555,8 +569,12 @@
555569
],
556570
"source": [
557571
"# Search for the nearest neighbors\n",
558-
"search_uuids, scores = weaviate_api.search_similar_embeddings(first_emb, top_k=5)\n",
559-
"payloads = weaviate_api.retrieve_payloads_by_ids(search_uuids, properties=[\"label\"])\n",
572+
"search_uuids, scores = weaviate_api.search_similar_embeddings(\n",
573+
" first_emb, top_k=5\n",
574+
")\n",
575+
"payloads = weaviate_api.retrieve_payloads_by_ids(\n",
576+
" search_uuids, properties=[\"label\"]\n",
577+
")\n",
560578
"\n",
561579
"# Print the search results\n",
562580
"for u, p, s in zip(search_uuids, payloads, scores):\n",
@@ -743,7 +761,9 @@
743761
"source": [
744762
"# Setup Weaviate\n",
745763
"weaviate_api = WeaviateAPI(\"http://localhost:8080\")\n",
746-
"weaviate_api.create_collection(collection_name=\"Mnist_LFS\", properties=[\"image_path\"])"
764+
"weaviate_api.create_collection(\n",
765+
" collection_name=\"Mnist_LFS\", properties=[\"image_path\"]\n",
766+
")"
747767
]
748768
},
749769
{
@@ -784,7 +804,9 @@
784804
],
785805
"source": [
786806
"# Search for the nearest neighbors\n",
787-
"search_uuids, scores = weaviate_api.search_similar_embeddings(embeddings[0], top_k=5)\n",
807+
"search_uuids, scores = weaviate_api.search_similar_embeddings(\n",
808+
" embeddings[0], top_k=5\n",
809+
")\n",
788810
"\n",
789811
"# Print the search results\n",
790812
"for u, s in zip(search_uuids, scores):\n",

examples/Embeddings_Processing_Example.ipynb

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@
103103
"source": [
104104
"desired_size = int(len(embeddings) * 0.05)\n",
105105
"# desired_size = 10\n",
106-
"selected_image_indices = find_representative_kmedoids(similarity_matrix, desired_size)\n",
106+
"selected_image_indices = find_representative_kmedoids(\n",
107+
" similarity_matrix, desired_size\n",
108+
")\n",
107109
"# selected_image_indices = find_representative_greedy_qdrant(qdrant_client, desired_size, 0, \"mnist3\")"
108110
]
109111
},
@@ -356,7 +358,9 @@
356358
"metadata": {},
357359
"outputs": [],
358360
"source": [
359-
"mis_img_paths = qdrant_api.retrieve_payloads_by_ids(missing_img_uuids, [\"image_path\"])\n",
361+
"mis_img_paths = qdrant_api.retrieve_payloads_by_ids(\n",
362+
" missing_img_uuids, [\"image_path\"]\n",
363+
")\n",
360364
"mis_img_paths = [x[\"image_path\"] for x in mis_img_paths]"
361365
]
362366
},
@@ -462,10 +466,14 @@
462466
"# find representative images\n",
463467
"desired_size = int(len(embeddings) * 0.05)\n",
464468
"similarity_matrix = calculate_similarity_matrix(embeddings)\n",
465-
"selected_image_indices = find_representative_kmedoids(similarity_matrix, desired_size)\n",
469+
"selected_image_indices = find_representative_kmedoids(\n",
470+
" similarity_matrix, desired_size\n",
471+
")\n",
466472
"\n",
467473
"selcted_ids = np.array(ids)[selected_image_indices].tolist()\n",
468-
"represent_img_paths = w_api.retrieve_payloads_by_ids(selcted_ids, [\"image_path\"])\n",
474+
"represent_img_paths = w_api.retrieve_payloads_by_ids(\n",
475+
" selcted_ids, [\"image_path\"]\n",
476+
")\n",
469477
"represent_img_paths = [x[\"image_path\"] for x in represent_img_paths]\n",
470478
"\n",
471479
"# plot\n",
@@ -626,7 +634,9 @@
626634
"mis_ix, new_y = find_mismatches_centroids(X, y)\n",
627635
"\n",
628636
"missing_img_uuids = np.array(ids)[mis_ix].tolist()\n",
629-
"mis_img_paths = w_api.retrieve_payloads_by_ids(missing_img_uuids, [\"image_path\"])\n",
637+
"mis_img_paths = w_api.retrieve_payloads_by_ids(\n",
638+
" missing_img_uuids, [\"image_path\"]\n",
639+
")\n",
630640
"mis_img_paths = [x[\"image_path\"] for x in mis_img_paths]\n",
631641
"\n",
632642
"# plot\n",

examples/utils/data_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Note: This loader is particularly useful when you want to use MNIST data with models that were
2828
pre-trained on datasets like ImageNet and expect 3-channel RGB input.
2929
"""
30+
3031
import torch
3132
import torchvision
3233
import torchvision.transforms as transforms
@@ -35,15 +36,18 @@
3536
def mnist_transformations() -> transforms.Compose:
3637
"""Returns composed transformations for the MNIST dataset.
3738
38-
Transforms the images from 1 channel grayscale to 3 channels RGB and resizes them.
39+
Transforms the images from 1 channel grayscale to 3 channels RGB and
40+
resizes them.
3941
"""
4042
return transforms.Compose(
4143
[
4244
transforms.Grayscale(num_output_channels=3),
4345
transforms.Lambda(lambda x: x.convert("RGB")),
4446
transforms.Resize((224, 224)),
4547
transforms.ToTensor(),
46-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
48+
transforms.Normalize(
49+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
50+
),
4751
]
4852
)
4953

@@ -70,7 +74,9 @@ def load_mnist_data(
7074
)
7175

7276
# If num_samples is set to -1, use the entire dataset
73-
num_samples = min(num_samples, len(dataset)) if num_samples != -1 else len(dataset)
77+
num_samples = (
78+
min(num_samples, len(dataset)) if num_samples != -1 else len(dataset)
79+
)
7480

7581
# Create a subset of the dataset using Subset class
7682
subset = torch.utils.data.Subset(dataset, torch.arange(num_samples))

examples/utils/torch_utils.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
# PyTorch and ONNX model loading and exporting functions
1313
def load_model_resnet50(discard_last_layer: bool) -> nn.Module:
14-
"""Load a pre-trained ResNet-50 model with the last fully connected layer
15-
removed."""
14+
"""Load a pre-trained ResNet-50 model with the last fully connected
15+
layer removed."""
1616
model = models.resnet50(weights=resnet.ResNet50_Weights.IMAGENET1K_V1)
1717
if discard_last_layer:
1818
model = nn.Sequential(
@@ -83,7 +83,9 @@ def save_embeddings(
8383
torch.save(labels, save_path + "labels.pth")
8484

8585

86-
def load_embeddings(save_path: str = "./") -> Tuple[torch.Tensor, torch.Tensor]:
86+
def load_embeddings(
87+
save_path: str = "./",
88+
) -> Tuple[torch.Tensor, torch.Tensor]:
8789
"""Load embeddings and labels tensors from the specified path."""
8890
embeddings = torch.load(save_path + "embeddings.pth")
8991
labels = torch.load(save_path + "labels.pth")
@@ -99,18 +101,21 @@ def generate_new_embeddings(
99101
emb_batch_size: int = 64,
100102
transform: transforms.Compose = None,
101103
):
102-
"""Generate embeddings for new images using a given ONNX runtime session.
104+
"""Generate embeddings for new images using a given ONNX runtime
105+
session.
103106
104107
@type img_paths: List[str]
105108
@param img_paths: List of image paths for new images.
106109
@type ort_session: L{InferenceSession}
107110
@param ort_session: ONNX runtime session.
108111
@type output_layer_name: str
109-
@param output_layer_name: Name of the output layer in the ONNX model.
112+
@param output_layer_name: Name of the output layer in the ONNX
113+
model.
110114
@type emb_batch_size: int
111115
@param emb_batch_size: Batch size for generating embeddings.
112116
@type transform: torchvision.transforms
113-
@param transform: Optional torchvision transform for preprocessing images.
117+
@param transform: Optional torchvision transform for preprocessing
118+
images.
114119
@rtype: List[List[float]]
115120
@return: List of embeddings for the new images.
116121
"""
@@ -141,7 +146,9 @@ def generate_new_embeddings(
141146
batch_tensor = torch.stack(batch_tensors).cuda()
142147

143148
# Run the ONNX model on the batch
144-
ort_inputs = {ort_session.get_inputs()[0].name: batch_tensor.cpu().numpy()}
149+
ort_inputs = {
150+
ort_session.get_inputs()[0].name: batch_tensor.cpu().numpy()
151+
}
145152
ort_outputs = ort_session.run([output_layer_name], ort_inputs)
146153

147154
# Append the embeddings from the batch to the new_embeddings list

luxonis_ml/data/__main__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ def get_dataset_info(name: str) -> Tuple[int, List[str], List[str]]:
5959
def print_info(name: str) -> None:
6060
dataset = LuxonisDataset(name)
6161
_, classes = dataset.get_classes()
62-
table = Table(title="Classes", box=rich.box.ROUNDED, row_styles=["yellow", "cyan"])
62+
table = Table(
63+
title="Classes", box=rich.box.ROUNDED, row_styles=["yellow", "cyan"]
64+
)
6365
table.add_column("Task", header_style="magenta i", max_width=30)
6466
table.add_column("Class Names", header_style="magenta i", max_width=50)
6567
for task, c in classes.items():

0 commit comments

Comments
 (0)