|
71 | 71 | "outputs": [], |
72 | 72 | "source": [ |
73 | 73 | "# Load the data\n", |
74 | | - "data_loader = load_mnist_data(save_path=\"./data/mnist\", num_samples=640, batch_size=64)" |
| 74 | + "data_loader = load_mnist_data(\n", |
| 75 | + " save_path=\"./data/mnist\", num_samples=640, batch_size=64\n", |
| 76 | + ")" |
75 | 77 | ] |
76 | 78 | }, |
77 | 79 | { |
|
119 | 121 | "outputs": [], |
120 | 122 | "source": [ |
121 | 123 | "# Load the data\n", |
122 | | - "data_loader = load_mnist_data(save_path=\"./data/mnist\", num_samples=640, batch_size=64)" |
| 124 | + "data_loader = load_mnist_data(\n", |
| 125 | + " save_path=\"./data/mnist\", num_samples=640, batch_size=64\n", |
| 126 | + ")" |
123 | 127 | ] |
124 | 128 | }, |
125 | 129 | { |
|
170 | 174 | " and \"CUDAExecutionProvider\" in onnxruntime.get_available_providers()\n", |
171 | 175 | " else None\n", |
172 | 176 | ")\n", |
173 | | - "ort_session = onnxruntime.InferenceSession(\"./data/resnet50-1.onnx\", providers=provider)\n", |
| 177 | + "ort_session = onnxruntime.InferenceSession(\n", |
| 178 | + " \"./data/resnet50-1.onnx\", providers=provider\n", |
| 179 | + ")\n", |
174 | 180 | "\n", |
175 | 181 | "# Extract embeddings from the dataset\n", |
176 | 182 | "embeddings, labels = extract_embeddings_onnx(\n", |
|
232 | 238 | "# Insert the embeddings into the collection\n", |
233 | 239 | "uuids = [str(uuid.uuid5(uuid.NAMESPACE_DNS, str(e))) for e in embeddings]\n", |
234 | 240 | "label_list_dict = [{\"label\": label} for label in labels]\n", |
235 | | - "weaviate_api.insert_embeddings(uuids, embeddings, label_list_dict, batch_size=50)" |
| 241 | + "weaviate_api.insert_embeddings(\n", |
| 242 | + " uuids, embeddings, label_list_dict, batch_size=50\n", |
| 243 | + ")" |
236 | 244 | ] |
237 | 245 | }, |
238 | 246 | { |
|
254 | 262 | ], |
255 | 263 | "source": [ |
256 | 264 | "# Search for the nearest neighbors\n", |
257 | | - "search_uuids, scores = weaviate_api.search_similar_embeddings(embeddings[0], top_k=5)\n", |
258 | | - "payloads = weaviate_api.retrieve_payloads_by_ids(search_uuids, properties=[\"label\"])\n", |
| 265 | + "search_uuids, scores = weaviate_api.search_similar_embeddings(\n", |
| 266 | + " embeddings[0], top_k=5\n", |
| 267 | + ")\n", |
| 268 | + "payloads = weaviate_api.retrieve_payloads_by_ids(\n", |
| 269 | + " search_uuids, properties=[\"label\"]\n", |
| 270 | + ")\n", |
259 | 271 | "\n", |
260 | 272 | "# Print the search results\n", |
261 | 273 | "for u, p in zip(search_uuids, payloads):\n", |
|
414 | 426 | " and \"CUDAExecutionProvider\" in onnxruntime.get_available_providers()\n", |
415 | 427 | " else None\n", |
416 | 428 | ")\n", |
417 | | - "ort_session = onnxruntime.InferenceSession(\"./data/resnet50-1.onnx\", providers=provider)" |
| 429 | + "ort_session = onnxruntime.InferenceSession(\n", |
| 430 | + " \"./data/resnet50-1.onnx\", providers=provider\n", |
| 431 | + ")" |
418 | 432 | ] |
419 | 433 | }, |
420 | 434 | { |
|
555 | 569 | ], |
556 | 570 | "source": [ |
557 | 571 | "# Search for the nearest neighbors\n", |
558 | | - "search_uuids, scores = weaviate_api.search_similar_embeddings(first_emb, top_k=5)\n", |
559 | | - "payloads = weaviate_api.retrieve_payloads_by_ids(search_uuids, properties=[\"label\"])\n", |
| 572 | + "search_uuids, scores = weaviate_api.search_similar_embeddings(\n", |
| 573 | + " first_emb, top_k=5\n", |
| 574 | + ")\n", |
| 575 | + "payloads = weaviate_api.retrieve_payloads_by_ids(\n", |
| 576 | + " search_uuids, properties=[\"label\"]\n", |
| 577 | + ")\n", |
560 | 578 | "\n", |
561 | 579 | "# Print the search results\n", |
562 | 580 | "for u, p, s in zip(search_uuids, payloads, scores):\n", |
|
743 | 761 | "source": [ |
744 | 762 | "# Setup Weaviate\n", |
745 | 763 | "weaviate_api = WeaviateAPI(\"http://localhost:8080\")\n", |
746 | | - "weaviate_api.create_collection(collection_name=\"Mnist_LFS\", properties=[\"image_path\"])" |
| 764 | + "weaviate_api.create_collection(\n", |
| 765 | + " collection_name=\"Mnist_LFS\", properties=[\"image_path\"]\n", |
| 766 | + ")" |
747 | 767 | ] |
748 | 768 | }, |
749 | 769 | { |
|
784 | 804 | ], |
785 | 805 | "source": [ |
786 | 806 | "# Search for the nearest neighbors\n", |
787 | | - "search_uuids, scores = weaviate_api.search_similar_embeddings(embeddings[0], top_k=5)\n", |
| 807 | + "search_uuids, scores = weaviate_api.search_similar_embeddings(\n", |
| 808 | + " embeddings[0], top_k=5\n", |
| 809 | + ")\n", |
788 | 810 | "\n", |
789 | 811 | "# Print the search results\n", |
790 | 812 | "for u, s in zip(search_uuids, scores):\n", |
|
0 commit comments