
Commit

Merge pull request #1376 from reina-w/image-demo
Image demo
zc277584121 authored Jul 12, 2024
2 parents e321ee0 + df5ed10 commit d1dd8ca
Showing 11 changed files with 358 additions and 0 deletions.
7 changes: 7 additions & 0 deletions bootcamp/tutorials/quickstart/apps/image_search_with_milvus/.env
@@ -0,0 +1,7 @@
MILVUS_ENDPOINT=./milvus_demo.db
# MILVUS_TOKEN=************

COLLECTION_NAME=my_image_collection

MODEL_NAME=resnet34
MODEL_DIM=512
116 changes: 116 additions & 0 deletions bootcamp/tutorials/quickstart/apps/image_search_with_milvus/README.md
@@ -0,0 +1,116 @@
# Image Similarity Search with Milvus 🖼️

This demo implements an image similarity search application using Streamlit, Milvus, and a pre-trained ResNet model. Users can upload an image, crop it to focus on the region of interest, and search for similar images from a pre-built database.

## Features
- Upload and crop images to define the region of interest.
- Extract features using a pre-trained ResNet model.
- Search for similar images using Milvus for efficient similarity search.
- Display search results along with similarity scores.

## Quick Deploy

Follow these steps to quickly deploy the application locally:

### Preparation

> Prerequisites: Python 3.8 or higher

**1. Download the Code**

```bash
$ git clone https://github.com/milvus-io/bootcamp.git
$ cd bootcamp/bootcamp/tutorials/quickstart/apps/image_search_with_milvus
```

**2. Set Environment**

- Install dependencies

```bash
$ pip install -r requirements.txt
```

- Set environment variables

Modify the environment file [.env](./.env) to change environment variables:

```
MILVUS_ENDPOINT=./milvus_demo.db
# MILVUS_TOKEN=************

COLLECTION_NAME=my_image_collection
```

- `MILVUS_ENDPOINT`: The URI for connecting to the Milvus or Zilliz Cloud service. By default, we use a local file "./milvus_demo.db" for convenience, which automatically uses [Milvus Lite](https://milvus.io/docs/milvus_lite.md) to store all data locally.
> - If you have a large amount of data, you can set up a more performant Milvus server on Docker or Kubernetes. In this setup, use the server URI, e.g. http://localhost:19530, as your endpoint.
>
> - If you want to use Zilliz Cloud, the fully managed cloud service for Milvus, adjust the URI and token, which correspond to the Public Endpoint and API key in Zilliz Cloud.
- `MILVUS_TOKEN`: Optional. Uncomment this line and enter your token or password if authentication is required by your Milvus or Zilliz Cloud service.
- `COLLECTION_NAME`: The name of the collection in the Milvus database; defaults to "my_image_collection".
- `MODEL_NAME`: The name of the pretrained image embedding model; defaults to "resnet34".
- `MODEL_DIM`: The embedding dimension, which must match the output dimension of the model specified by `MODEL_NAME` (see the snippet below).
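
If you switch `MODEL_NAME` to a different timm model, you can verify the matching `MODEL_DIM` with a short check like the one below (a minimal sketch, not part of the app; it assumes the `timm` package from requirements.txt is installed and uses "resnet34" only as an example):

```python
import timm

# Build the backbone the same way encoder.py does (no classifier head, global average pooling)
model = timm.create_model("resnet34", pretrained=False, num_classes=0, global_pool="avg")

# The dimension of the pooled feature output; for resnet34 this prints 512,
# which is the value MODEL_DIM should be set to.
print(model.num_features)
```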

**3. Prepare Data**

We are using [a subset of ImageNet](https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip) for this demo, which includes approximately 200 categories with images of animals, objects, buildings, and more.

```bash
$ wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
$ unzip -q reverse_image_search.zip -d reverse_image_search
```

Run [insert.py](./insert.py) to create a collection named after the environment variable `COLLECTION_NAME` and load the data from "reverse_image_search/train" into it. Note that only the JPEG images under "reverse_image_search/train" are used to build the database, about 1,000 images in total.

```bash
$ python insert.py reverse_image_search/train
```

> **Note:** If the collection exists in the Milvus database, then it will be dropped first to create a new one.

### Start Service

Run the Streamlit application:

```bash
$ streamlit run app.py
```
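
Streamlit will print a local URL when it starts (typically http://localhost:8501 by default); open it in your browser to use the app.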

### Example Usage

**Step 1:** Choose an image file to upload (JPEG format).
<div style="text-align: center;">
<figure>
<img src="./pics/step1.png" alt="Description of Image" width="700"/>
</figure>
</div>

**Step 2:** Crop the image to focus on the region of interest.

**Step 3:** Set the desired number of top-k results to display using the slider.
<div style="text-align: center;">
<figure>
<img src="./pics/step2_and_3.jpg" alt="Description of Image" width="700"/>
</figure>
</div>

**Step 4:** View the search results along with similarity scores.
<div style="text-align: center;">
<figure>
<img src="./pics/step4.jpg" alt="Description of Image" width="700"/>
</figure>
</div>

## Code Structure

```text
./image_search_with_milvus
├── app.py
├── encoder.py
├── insert.py
├── milvus_utils.py
└── requirements.txt
```

- app.py: The main Streamlit application, where the user interface is defined and the image similarity search is performed.
- encoder.py: The image encoder, which converts an image into an embedding using a pretrained model.
- insert.py: This script creates a Milvus collection and inserts images with their corresponding embeddings into the collection.
- milvus_utils.py: Functions for interacting with the Milvus database, such as setting up the Milvus client.
98 changes: 98 additions & 0 deletions bootcamp/tutorials/quickstart/apps/image_search_with_milvus/app.py
@@ -0,0 +1,98 @@
import os
import streamlit as st
from streamlit_cropper import st_cropper
import streamlit_cropper
from PIL import Image

st.set_page_config(layout="wide")

from encoder import FeatureExtractor
from milvus_utils import get_milvus_client, get_search_results

from dotenv import load_dotenv

load_dotenv()

COLLECTION_NAME = os.getenv("COLLECTION_NAME")
MILVUS_ENDPOINT = os.getenv("MILVUS_ENDPOINT")
MILVUS_TOKEN = os.getenv("MILVUS_TOKEN")
MODEL_NAME = os.getenv("MODEL_NAME")


def _recommended_box2(img: Image, aspect_ratio: tuple) -> dict:
    width, height = img.size
    return {
        "left": int(0),
        "top": int(0),
        "width": int(width - 2),
        "height": int(height - 2),
    }


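# Override streamlit_cropper's recommended crop box so the initial
# selection covers (nearly) the whole uploaded image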
streamlit_cropper._recommended_box = _recommended_box2


# Get client and model ready
milvus_client = get_milvus_client(uri=MILVUS_ENDPOINT, token=MILVUS_TOKEN)
image_encoder = FeatureExtractor(MODEL_NAME)

# Logo
st.sidebar.image("./pics/Milvus_Logo_Official.png", width=200)

# Title
st.title("Image Similarity Search :frame_with_picture: ")

query_image = "temp.jpg"
cols = st.columns(5)

uploaded_file = st.sidebar.file_uploader("Choose an image...", type="jpeg")

if uploaded_file is not None:
    with open("temp.jpg", "wb") as f:
        f.write(uploaded_file.getbuffer())
    # cropper
    # Get a cropped image from the frontend
    uploaded_img = Image.open(uploaded_file)
    width, height = uploaded_img.size

    new_width = 370
    new_height = int((new_width / width) * height)
    uploaded_img = uploaded_img.resize((new_width, new_height))

    st.sidebar.text(
        "Query Image",
        help="Edit the bounding box to change the ROI (Region of Interest).",
    )
    with st.sidebar.empty():
        cropped_img = st_cropper(
            uploaded_img,
            box_color="#4fc4f9",
            realtime_update=True,
            aspect_ratio=(16, 9),
        )

    show_distance = st.sidebar.toggle("Show Distance")

    # top k value slider
    value = st.sidebar.slider("Select top k results shown", 10, 100, 20, step=1)

    @st.cache_resource
    def get_image_embedding(image: Image):
        return image_encoder(image)

    results = get_search_results(
        milvus_client=milvus_client,
        collection_name=COLLECTION_NAME,
        query_vector=get_image_embedding(cropped_img),
        output_fields=["filename"],
        limit=value,
    )
    search_results = results[0]

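    # Show results in a 5-column grid; with the COSINE metric, a larger score means a more similar image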
    for i, info in enumerate(search_results):
        img_info = info["entity"]
        imgName = img_info["filename"]
        score = info["distance"]
        img = Image.open(imgName)
        cols[i % 5].image(img, use_column_width=True)
        if show_distance:
            cols[i % 5].write(f"Score: {score:.3f}")
45 changes: 45 additions & 0 deletions bootcamp/tutorials/quickstart/apps/image_search_with_milvus/encoder.py
@@ -0,0 +1,45 @@
import streamlit as st
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image
from sklearn.preprocessing import normalize


class FeatureExtractor:
    def __init__(self, modelname):
        # Load the pre-trained model
        self.model = timm.create_model(
            modelname, pretrained=True, num_classes=0, global_pool="avg"
        )
        self.model.eval()

        # Get the input size required by the model
        self.input_size = self.model.default_cfg["input_size"]

        config = resolve_data_config({}, model=modelname)
        # Get the preprocessing function provided by TIMM for the model
        self.preprocess = create_transform(**config)

    def __call__(self, image: Image):
        # Preprocess the input image
        input_image = image.convert("RGB")  # Convert to RGB if needed
        input_image = self.preprocess(input_image)

        # Convert the image to a PyTorch tensor and add a batch dimension
        input_tensor = input_image.unsqueeze(0)

        # Perform inference
        with torch.no_grad():
            output = self.model(input_tensor)

        # Extract the feature vector
        feature_vector = output.squeeze().numpy()

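        # L2-normalize the embedding so cosine-similarity comparisons in Milvus
        # depend only on direction, not on the vector's magnitude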
        return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()


# @st.cache_resource
# def load_model(image_encoder):
# return image_encoder.model
48 changes: 48 additions & 0 deletions bootcamp/tutorials/quickstart/apps/image_search_with_milvus/insert.py
@@ -0,0 +1,48 @@
import os
import sys
from glob import glob
from PIL import Image
from tqdm import tqdm

from encoder import FeatureExtractor
from milvus_utils import get_milvus_client, create_collection

from dotenv import load_dotenv

load_dotenv()

COLLECTION_NAME = os.getenv("COLLECTION_NAME")
MILVUS_ENDPOINT = os.getenv("MILVUS_ENDPOINT")
MILVUS_TOKEN = os.getenv("MILVUS_TOKEN")
MODEL_NAME = os.getenv("MODEL_NAME")
MODEL_DIM = int(os.getenv("MODEL_DIM"))

data_dir = sys.argv[-1]
image_encoder = FeatureExtractor(MODEL_NAME)
milvus_client = get_milvus_client(uri=MILVUS_ENDPOINT, token=MILVUS_TOKEN)

# Create collection
create_collection(
    milvus_client=milvus_client, collection_name=COLLECTION_NAME, dim=MODEL_DIM
)

# Load images from directory and generate embeddings
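# Note: without recursive=True, "**" matches exactly one directory level,
# i.e. this picks up train/<category>/*.JPEG files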
image_paths = glob(os.path.join(data_dir, "**/*.JPEG"))
data = []
for i, filepath in enumerate(tqdm(image_paths, desc="Generating embeddings ...")):
    try:
        image = Image.open(filepath)
        image_embedding = image_encoder(image)
        data.append({"vector": image_embedding, "filename": filepath})
    except Exception as e:
        print(
            f"Skipping file {filepath} due to an error during the embedding process:\n{e}"
        )
        continue

# Insert data into Milvus
mr = milvus_client.insert(
    collection_name=COLLECTION_NAME,
    data=data,
)
print("Total number of inserted entities/images:", mr["insert_count"])
35 changes: 35 additions & 0 deletions bootcamp/tutorials/quickstart/apps/image_search_with_milvus/milvus_utils.py
@@ -0,0 +1,35 @@
import streamlit as st
from pymilvus import MilvusClient


@st.cache_resource
def get_milvus_client(uri: str, token: str = None) -> MilvusClient:
    return MilvusClient(uri=uri, token=token)


def create_collection(
    milvus_client: MilvusClient, collection_name: str, dim: int, drop_old: bool = True
):
    if milvus_client.has_collection(collection_name) and drop_old:
        milvus_client.drop_collection(collection_name)
    if milvus_client.has_collection(collection_name):
        raise RuntimeError(
            f"Collection {collection_name} already exists. Set drop_old=True to create a new one instead."
        )
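    # Quick-setup collection: pymilvus creates an auto-id primary key and a "vector"
    # field of the given dimension; extra keys in inserted rows (e.g. "filename")
    # are stored via the dynamic field.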
    return milvus_client.create_collection(
        collection_name=collection_name,
        dimension=dim,
        metric_type="COSINE",
        consistency_level="Strong",
        auto_id=True,
    )


def get_search_results(
    milvus_client, collection_name, query_vector, output_fields, limit=10
):
    search_res = milvus_client.search(
        collection_name=collection_name,
        data=[query_vector],
        limit=limit,
        search_params={"metric_type": "COSINE", "params": {}},
        output_fields=output_fields,
    )
    return search_res
Binary files are not shown in the diff: the remaining four changed files are images (presumably the screenshots and Milvus logo under ./pics/ referenced by the README and app.py).
9 changes: 9 additions & 0 deletions bootcamp/tutorials/quickstart/apps/image_search_with_milvus/requirements.txt
@@ -0,0 +1,9 @@
streamlit
streamlit-cropper
torch
timm
Pillow
scikit-learn
pymilvus>=2.4.4
certifi
requests
