-
Notifications
You must be signed in to change notification settings - Fork 581
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1376 from reina-w/image-demo
Image demo
- Loading branch information
Showing
11 changed files
with
358 additions
and
0 deletions.
There are no files selected for viewing
7 changes: 7 additions & 0 deletions
7
bootcamp/tutorials/quickstart/apps/image_search_with_milvus/.env
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
MILVUS_ENDPOINT=./milvus_demo.db | ||
# MILVUS_TOKEN=************ | ||
|
||
COLLECTION_NAME=my_image_collection | ||
|
||
MODEL_NAME=resnet34 | ||
MODEL_DIM=512 |
116 changes: 116 additions & 0 deletions
116
bootcamp/tutorials/quickstart/apps/image_search_with_milvus/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# Image Similarity Search with Milvus 🖼️ | ||
|
||
This demo implements an image similarity search application using Streamlit, Milvus, and a pre-trained ResNet model. Users can upload an image, crop it to focus on the region of interest, and search for similar images from a pre-built database. | ||
|
||
## Features | ||
- Upload and crop images to define the region of interest. | ||
- Extract features using a pre-trained ResNet model. | ||
- Search for similar images using Milvus for efficient similarity search. | ||
- Display search results along with similarity scores. | ||
|
||
## Quick Deploy | ||
|
||
Follow these steps to quickly deploy the application locally: | ||
|
||
### Preparation | ||
|
||
> Prerequisites: Python 3.8 or higher | ||
**1. Download Codes** | ||
```bash | ||
$ git clone <https://github.com/milvus-io/bootcamp.git> | ||
$ cd bootcamp/bootcamp/tutorials/quickstart/app/image_search_with_milvus | ||
``` | ||
|
||
**2. Set Environment** | ||
|
||
- Install dependencies | ||
|
||
```bash | ||
$ pip install -r requirements.txt | ||
``` | ||
|
||
- Set environment variables | ||
|
||
Modify the environment file [.env](./.env) to change environment variables: | ||
|
||
``` | ||
MILVUS_ENDPOINT=./milvus_demo.db | ||
# MILVUS_TOKEN=************ | ||
|
||
COLLECTION_NAME=my_image_collection | ||
``` | ||
|
||
- `MILVUS_ENDPOINT`: The URI to connect Milvus or Zilliz Cloud service. By default, we use a local file "./milvus_demo.db" for convenience, as it automatically utilizes [Milvus Lite](https://milvus.io/docs/milvus_lite.md) to store all data at local. | ||
> - If you have large scale of data, you can set up a more performant Milvus server on docker or kubernetes. In this setup, please use the server uri, e.g. http://localhost:19530, as your uri. | ||
> | ||
> - If you want to use Zilliz Cloud, the fully managed cloud service for Milvus, adjust the uri and token, which correspond to the Public Endpoint and Api key in Zilliz Cloud. | ||
- `MILVUS_TOKEN`: This is optional. Uncomment this line to enter your password if authentication is required by your Milvus or Zilliz Cloud service. | ||
- `COLLECTION_NAME`: The collection name in Milvus database, defaults to "my_image_collection". | ||
- `MODEL_NAME`: The name of pretrained image embedding model, defaults to "resnet34". | ||
- `MODEL_DIM`: The embedding dimension, which should change according to the MODEL_NAME. | ||
|
||
**3. Prepare Data** | ||
|
||
We are using [a subset of ImageNet](https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip) for this demo, which includes approximately 200 categories with images of animals, objects, buildings, and more.<br> | ||
|
||
```bash | ||
$ wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip | ||
$ unzip -q reverse_image_search.zip -d reverse_image_search | ||
``` | ||
|
||
Create a collection named as the environment variable `COLLECTION_NAME` and load data from "reverse_image_search/train" to get the knowledge ready by running the [insert.py](./insert.py). Please note we will use the JPEG images only from the "reverse_image_search/train" to build the database. There will be about 1000 images. | ||
|
||
```bash | ||
$ python insert.py reverse_image_search/train | ||
``` | ||
|
||
> **Note:** If the collection exists in the Milvus database, then it will be dropped first to create a new one. | ||
|
||
### Start Service | ||
|
||
Run the Streamlit application: | ||
|
||
```bash | ||
$ streamlit run app.py | ||
``` | ||
|
||
### Example Usage | ||
|
||
**Step 1:** Choose an image file to upload (JPEG format). | ||
<div style="text-align: center;"> | ||
<figure> | ||
<img src="./pics/step1.png" alt="Description of Image" width="700"/> | ||
</figure> | ||
</div> | ||
|
||
**Step 2:** Crop the image to focus on the region of interest. | ||
**Step 3:** Set the desired number of top-k results to display using the slider. | ||
<div style="text-align: center;"> | ||
<figure> | ||
<img src="./pics/step2_and_3.jpg" alt="Description of Image" width="700"/> | ||
</figure> | ||
</div> | ||
|
||
**Step 4:** View the search results along with similarity scores. | ||
<div style="text-align: center;"> | ||
<figure> | ||
<img src="./pics/step4.jpg" alt="Description of Image" width="700"/> | ||
</figure> | ||
</div> | ||
|
||
## Code Structure | ||
|
||
```text | ||
./image_search_with_milvus | ||
├── app.py | ||
├── encoder.py | ||
├── insert.py | ||
├── milvus_utils.py | ||
├── requirements.txt | ||
``` | ||
|
||
- app.py: The main file to run application by Streamlit, where the user interface is defined and the image similarity search is performed. | ||
- encoder.py: The image encoder is able to convert image to image embeddings using a pretrained model. | ||
- insert.py: This script creates a Milvus collection and inserts images with corresponding embembeddings into the collection. | ||
- milvus_utils.py: Includes functions for interacting with the Milvus database, such as setting up Milvus client. |
98 changes: 98 additions & 0 deletions
98
bootcamp/tutorials/quickstart/apps/image_search_with_milvus/app.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import os | ||
import streamlit as st | ||
from streamlit_cropper import st_cropper | ||
import streamlit_cropper | ||
from PIL import Image | ||
|
||
st.set_page_config(layout="wide") | ||
|
||
from encoder import FeatureExtractor | ||
from milvus_utils import get_milvus_client, get_search_results | ||
|
||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
|
||
COLLECTION_NAME = os.getenv("COLLECTION_NAME") | ||
MILVUS_ENDPOINT = os.getenv("MILVUS_ENDPOINT") | ||
MILVUS_TOKEN = os.getenv("MILVUS_TOKEN") | ||
MODEL_NAME = os.getenv("MODEL_NAME") | ||
|
||
|
||
def _recommended_box2(img: Image, aspect_ratio: tuple) -> dict: | ||
width, height = img.size | ||
return { | ||
"left": int(0), | ||
"top": int(0), | ||
"width": int(width - 2), | ||
"height": int(height - 2), | ||
} | ||
|
||
|
||
streamlit_cropper._recommended_box = _recommended_box2 | ||
|
||
|
||
# Get client and model ready | ||
milvus_client = get_milvus_client(uri=MILVUS_ENDPOINT, token=MILVUS_TOKEN) | ||
image_encoder = FeatureExtractor(MODEL_NAME) | ||
|
||
# Logo | ||
st.sidebar.image("./pics/Milvus_Logo_Official.png", width=200) | ||
|
||
# Title | ||
st.title("Image Similarity Search :frame_with_picture: ") | ||
|
||
query_image = "temp.jpg" | ||
cols = st.columns(5) | ||
|
||
uploaded_file = st.sidebar.file_uploader("Choose an image...", type="jpeg") | ||
|
||
if uploaded_file is not None: | ||
with open("temp.jpg", "wb") as f: | ||
f.write(uploaded_file.getbuffer()) | ||
# cropper | ||
# Get a cropped image from the frontend | ||
uploaded_img = Image.open(uploaded_file) | ||
width, height = uploaded_img.size | ||
|
||
new_width = 370 | ||
new_height = int((new_width / width) * height) | ||
uploaded_img = uploaded_img.resize((new_width, new_height)) | ||
|
||
st.sidebar.text( | ||
"Query Image", | ||
help="Edit the bounding box to change the ROI (Region of Interest).", | ||
) | ||
with st.sidebar.empty(): | ||
cropped_img = st_cropper( | ||
uploaded_img, | ||
box_color="#4fc4f9", | ||
realtime_update=True, | ||
aspect_ratio=(16, 9), | ||
) | ||
|
||
show_distance = st.sidebar.toggle("Show Distance") | ||
|
||
# top k value slider | ||
value = st.sidebar.slider("Select top k results shown", 10, 100, 20, step=1) | ||
|
||
@st.cache_resource | ||
def get_image_embedding(image: Image): | ||
return image_encoder(image) | ||
|
||
results = get_search_results( | ||
milvus_client=milvus_client, | ||
collection_name=COLLECTION_NAME, | ||
query_vector=get_image_embedding(cropped_img), | ||
output_fields=["filename"], | ||
) | ||
search_results = results[0] | ||
|
||
for i, info in enumerate(search_results): | ||
img_info = info["entity"] | ||
imgName = img_info["filename"] | ||
score = info["distance"] | ||
img = Image.open(imgName) | ||
cols[i % 5].image(img, use_column_width=True) | ||
if show_distance: | ||
cols[i % 5].write(f"Score: {score:.3f}") |
45 changes: 45 additions & 0 deletions
45
bootcamp/tutorials/quickstart/apps/image_search_with_milvus/encoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import streamlit as st | ||
import torch | ||
import timm | ||
from timm.data import resolve_data_config | ||
from timm.data.transforms_factory import create_transform | ||
from PIL import Image | ||
from sklearn.preprocessing import normalize | ||
|
||
|
||
class FeatureExtractor: | ||
def __init__(self, modelname): | ||
# Load the pre-trained model | ||
self.model = timm.create_model( | ||
modelname, pretrained=True, num_classes=0, global_pool="avg" | ||
) | ||
self.model.eval() | ||
|
||
# Get the input size required by the model | ||
self.input_size = self.model.default_cfg["input_size"] | ||
|
||
config = resolve_data_config({}, model=modelname) | ||
# Get the preprocessing function provided by TIMM for the model | ||
self.preprocess = create_transform(**config) | ||
|
||
def __call__(self, image: Image): | ||
# Preprocess the input image | ||
input_image = image.convert("RGB") # Convert to RGB if needed | ||
input_image = self.preprocess(input_image) | ||
|
||
# Convert the image to a PyTorch tensor and add a batch dimension | ||
input_tensor = input_image.unsqueeze(0) | ||
|
||
# Perform inference | ||
with torch.no_grad(): | ||
output = self.model(input_tensor) | ||
|
||
# Extract the feature vector | ||
feature_vector = output.squeeze().numpy() | ||
|
||
return normalize(feature_vector.reshape(1, -1), norm="l2").flatten() | ||
|
||
|
||
# @st.cache_resource | ||
# def load_model(image_encoder): | ||
# return image_encoder.model |
48 changes: 48 additions & 0 deletions
48
bootcamp/tutorials/quickstart/apps/image_search_with_milvus/insert.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
import sys | ||
from glob import glob | ||
from PIL import Image | ||
from tqdm import tqdm | ||
|
||
from encoder import FeatureExtractor | ||
from milvus_utils import get_milvus_client, create_collection | ||
|
||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
|
||
COLLECTION_NAME = os.getenv("COLLECTION_NAME") | ||
MILVUS_ENDPOINT = os.getenv("MILVUS_ENDPOINT") | ||
MILVUS_TOKEN = os.getenv("MILVUS_TOKEN") | ||
MODEL_NAME = os.getenv("MODEL_NAME") | ||
MODEL_DIM = os.getenv("MODEL_DIM") | ||
|
||
data_dir = sys.argv[-1] | ||
image_encoder = FeatureExtractor(MODEL_NAME) | ||
milvus_client = get_milvus_client(uri=MILVUS_ENDPOINT, token=MILVUS_TOKEN) | ||
|
||
# Create collection | ||
create_collection( | ||
milvus_client=milvus_client, collection_name=COLLECTION_NAME, dim=MODEL_DIM | ||
) | ||
|
||
# Load images from directory and generate embeddings | ||
image_paths = glob(os.path.join(data_dir, "**/*.JPEG")) | ||
data = [] | ||
for i, filepath in enumerate(tqdm(image_paths, desc="Generating embeddings ...")): | ||
try: | ||
image = Image.open(filepath) | ||
image_embedding = image_encoder(image) | ||
data.append({"vector": image_embedding, "filename": filepath}) | ||
except Exception as e: | ||
print( | ||
f"Skipping file: {filepath} due to an error occurs during the embedding process:\n{e}" | ||
) | ||
continue | ||
|
||
# Insert data into Milvus | ||
mr = milvus_client.insert( | ||
collection_name=COLLECTION_NAME, | ||
data=data, | ||
) | ||
print("Total number of inserted entities/images:", mr["insert_count"]) |
35 changes: 35 additions & 0 deletions
35
bootcamp/tutorials/quickstart/apps/image_search_with_milvus/milvus_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import streamlit as st | ||
from pymilvus import MilvusClient | ||
|
||
|
||
@st.cache_resource | ||
def get_milvus_client(uri: str, token: str = None) -> MilvusClient: | ||
return MilvusClient(uri=uri, token=token) | ||
|
||
|
||
def create_collection( | ||
milvus_client: MilvusClient, collection_name: str, dim: int, drop_old: bool = True | ||
): | ||
if milvus_client.has_collection(collection_name) and drop_old: | ||
milvus_client.drop_collection(collection_name) | ||
if milvus_client.has_collection(collection_name): | ||
raise RuntimeError( | ||
f"Collection {collection_name} already exists. Set drop_old=True to create a new one instead." | ||
) | ||
return milvus_client.create_collection( | ||
collection_name=collection_name, | ||
dimension=dim, | ||
metric_type="COSINE", | ||
consistency_level="Strong", | ||
auto_id=True, | ||
) | ||
|
||
|
||
def get_search_results(milvus_client, collection_name, query_vector, output_fields): | ||
search_res = milvus_client.search( | ||
collection_name=collection_name, | ||
data=[query_vector], | ||
search_params={"metric_type": "COSINE", "params": {}}, | ||
output_fields=output_fields, | ||
) | ||
return search_res |
Binary file added
BIN
+38 KB
...utorials/quickstart/apps/image_search_with_milvus/pics/Milvus_Logo_Official.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+215 KB
bootcamp/tutorials/quickstart/apps/image_search_with_milvus/pics/step1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+620 KB
bootcamp/tutorials/quickstart/apps/image_search_with_milvus/pics/step2_and_3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+616 KB
bootcamp/tutorials/quickstart/apps/image_search_with_milvus/pics/step4.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions
9
bootcamp/tutorials/quickstart/apps/image_search_with_milvus/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
streamlit | ||
streamlit-cropper | ||
torch | ||
timm | ||
Pillow | ||
scikit-learn | ||
pymilvus>=2.4.4 | ||
certifi | ||
requests |