Add ImageBind classifier
Kanazawanaoaki committed Oct 15, 2023
1 parent 9d879cf commit e04d891
Showing 7 changed files with 225 additions and 8 deletions.
16 changes: 13 additions & 3 deletions doc/jsk_perception/nodes/classification_node.md
@@ -2,7 +2,7 @@

![](images/clip.png)

-The ROS node for Classification with CLIP.
+The ROS node for Classification with CLIP or ImageBind.

## System Configuration
![](images/large_scale_vil_system.png)
@@ -63,19 +63,29 @@ make
Default categories used for the subscribed image topic.

### Run inference container on another host or another terminal
You can now use either CLIP or ImageBind.

#### If you want to use CLIP
On the remote GPU machine,
```shell
cd jsk_recognition/jsk_perception/docker
./run_jsk_vil_api clip --port (Your vacant port)
```

#### If you want to use ImageBind
On the remote GPU machine,
```shell
cd jsk_recognition/jsk_perception/docker
./run_jsk_vil_api image-bind --port (Your vacant port)
```

On the ROS machine,
```shell
-roslaunch jsk_perception classification.launch port:=(Your inference container port) host:=(Your inference container host) CLASSIFICATION_INPUT_IMAGE:=(Your image topic name) gui:=true
+roslaunch jsk_perception classification.launch port:=(Your inference container port) host:=(Your inference container host) CLASSIFICATION_INPUT_IMAGE:=(Your image topic name) model:=(Your model name) gui:=true
```
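
For example, with the ImageBind container on a hypothetical host `gpuserver` at port 8888 and a camera topic `/camera/rgb/image_raw` (both placeholders for your setup):
```shell
roslaunch jsk_perception classification.launch port:=8888 host:=gpuserver CLASSIFICATION_INPUT_IMAGE:=/camera/rgb/image_raw model:=image-bind gui:=true
```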
### Run both inference container and ROS node on a single host
```shell
-roslaunch jsk_perception classification.launch run_api:=true CLASSIFICATION_INPUT_IMAGE:=(Your image topic name) gui:=true
+roslaunch jsk_perception classification.launch run_api:=true CLASSIFICATION_INPUT_IMAGE:=(Your image topic name) model:=(Your model name) gui:=true
```
9 changes: 7 additions & 2 deletions jsk_perception/docker/Makefile
@@ -5,9 +5,11 @@
# api directories
OFAPROJECT = ofa
CLIPPROJECT = clip
IMAGEBINDPROJECT = image-bind
# image names
OFAIMAGE = jsk-ofa-server
CLIPIMAGE = jsk-clip-server
IMAGEBINDIMAGE = jsk-image-bind-server
# commands
BUILDIMAGE = docker build
REMOVEIMAGE = docker rmi
@@ -23,7 +25,7 @@ PARAMURLS = parameter_urls.txt
# OFA parameters
OFAPARAMFILES = $(foreach param, $(OFAPARAMS), $(PARAMDIR)/$(param))

-all: ofa clip
+all: ofa clip image-bind

# TODO check command wget exists, nvidia-driver version

@@ -41,11 +43,14 @@ ofa: $(PARAMDIR)/.download
clip: $(PARAMDIR)/.download
	$(BUILDIMAGE) $(CLIPPROJECT) -t $(CLIPIMAGE) -f $(CLIPPROJECT)/Dockerfile

image-bind: $(PARAMDIR)/.download
	$(BUILDIMAGE) $(IMAGEBINDPROJECT) -t $(IMAGEBINDIMAGE) -f $(IMAGEBINDPROJECT)/Dockerfile

# TODO add clip, glip
clean:
	@$(REMOVEIMAGE) $(OFAIMAGE)

wipe: clean
	rm -fr $(PARAMDIR)

-.PHONY: clean wipe ofa clip
+.PHONY: clean wipe ofa clip image-bind
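
As a usage sketch (assuming Docker and an NVIDIA driver are already installed), building only the ImageBind server image expands to a plain `docker build`:
```shell
cd jsk_recognition/jsk_perception/docker
# equivalent to: docker build image-bind -t jsk-image-bind-server -f image-bind/Dockerfile
make image-bind
```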
32 changes: 32 additions & 0 deletions jsk_perception/docker/image-bind/Dockerfile
@@ -0,0 +1,32 @@
FROM pytorch/pytorch:1.9.1-cuda11.1-cudnn8-devel
ARG DEBIAN_FRONTEND=noninteractive
RUN apt -o Acquire::AllowInsecureRepositories=true update \
    && apt-get install -y \
    curl \
    git \
    libopencv-dev \
    wget \
    emacs \
    python3.8 \
    python3-dev \
    libproj-dev \
    proj-data \
    proj-bin \
    libgeos-dev \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*
ENV CUDA_HOME /usr/local/cuda
ENV TORCH_CUDA_ARCH_LIST 8.0+PTX
RUN git clone https://github.com/Kanazawanaoaki/ImageBind.git -b update-data-load
RUN echo 'export CUDA_HOME=/usr/local/cuda' >> ~/.bashrc
RUN echo 'export TORCH_CUDA_ARCH_LIST=8.0+PTX' >> ~/.bashrc
RUN pip install flask opencv-python \
    && pip install soundfile \
    && pip install --upgrade pip setuptools wheel \
    && pip install cartopy==0.19.0.post1
RUN cd ImageBind \
    && git pull origin update-data-load \
    && pip install -r requirements.txt \
    && pip install -e .
COPY server.py /workspace/ImageBind
ENTRYPOINT cd /workspace/ImageBind && python server.py
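
For reference, a minimal sketch of starting the resulting image by hand (`run_jsk_vil_api` normally does this); the host port 8888 is a placeholder, while 8080 is the port `server.py` binds inside the container:
```shell
docker run --rm --gpus all -p 8888:8080 jsk-image-bind-server
```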
166 changes: 166 additions & 0 deletions jsk_perception/docker/image-bind/server.py
@@ -0,0 +1,166 @@
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
import io

import cv2
import numpy as np
from PIL import Image as PLImage
import torch

# web server
from flask import Flask, request, Response
import json
import base64


def apply_half(t):
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t


class Inference:
    def __init__(self, modal, gpu_id=None):
        self.gpu_id = gpu_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.modal_name = modal

        self.model = imagebind_model.imagebind_huge(pretrained=True)
        self.model.eval()
        self.model.to(self.device)

        self.video_sample_rate = 16000

    def convert_to_string(self, input_list):
        output_string = ""
        for item in input_list:
            output_string += item + " . "
        return output_string.strip()

    def infer(self, msg, texts):
        text_inputs = texts

        if self.modal_name == "image":
            # get cv2 image
            # image = cv2.resize(img, dsize=(640, 480))  # NOTE: forced resize
            # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = cv2.cvtColor(msg, cv2.COLOR_BGR2RGB)
            image = PLImage.fromarray(image)

            image_input = [image]

            inputs = {
                ModalityType.TEXT: data.load_and_transform_text(text_inputs, self.device),
                ModalityType.VISION: data.load_and_transform_vision_data(None, self.device, image_input),
            }
            modal_data_type = ModalityType.VISION

        elif self.modal_name == "video":
            import decord
            decord.bridge.set_bridge("torch")
            video_io = io.BytesIO(msg)
            video = EncodedVideoDecord(file=video_io,
                                       video_name="current_video_data",
                                       decode_video=True,
                                       decode_audio=False,
                                       **{"sample_rate": self.video_sample_rate},
                                       )

            inputs = {
                ModalityType.TEXT: data.load_and_transform_text(text_inputs, self.device),
                ModalityType.VISION: data.load_and_transform_video_data(None, self.device, videos=[video]),
            }
            modal_data_type = ModalityType.VISION

        elif self.modal_name == "audio":
            waveform = msg["waveform"]
            sr = msg["sr"]
            waveform_np = np.frombuffer(waveform, dtype=np.float32)
            waveform_torch = torch.tensor(waveform_np.reshape(1, -1))

            inputs = {
                ModalityType.TEXT: data.load_and_transform_text(text_inputs, self.device),
                ModalityType.AUDIO: data.load_and_transform_audio_data(None, self.device, audios=[{"waveform": waveform_torch, "sr": sr}]),
            }
            modal_data_type = ModalityType.AUDIO

        # Calculate features
        with torch.no_grad():
            embeddings = self.model(inputs)

        # Raw input-vs-text similarities averaged over the batch, plus softmax probabilities
        similarity = np.average((embeddings[modal_data_type] @ embeddings[ModalityType.TEXT].T).tolist(), axis=0)
        probability = torch.softmax(embeddings[modal_data_type] @ embeddings[ModalityType.TEXT].T, dim=-1)

        values, indices = probability[0].topk(len(texts))
        results = {}
        for value, index in zip(values, indices):
            results[texts[index]] = (value.item(), float(similarity[index]))
        return results


# run
if __name__ == "__main__":
    app = Flask(__name__)

    image_infer = Inference("image")
    video_infer = Inference("video")
    audio_infer = Inference("audio")

    try:
        @app.route("/inference", methods=['POST'])
        def image_request():
            data = request.data.decode("utf-8")
            data_json = json.loads(data)
            # process image
            image_b = data_json['image']
            image_dec = base64.b64decode(image_b)
            data_np = np.frombuffer(image_dec, dtype=np.uint8)
            img = cv2.imdecode(data_np, 1)
            # get text
            texts = data_json['queries']
            infer_results = image_infer.infer(img, texts)
            results = []
            for q in infer_results:
                results.append({"question": q, "probability": infer_results[q][0], "similarity": infer_results[q][1]})
            return Response(response=json.dumps({"results": results}), status=200)
    except NameError:
        print("Skipping create inference app")

    try:
        @app.route("/video_class", methods=['POST'])
        def video_request():
            data = request.data.decode("utf-8")
            data_json = json.loads(data)
            # process video
            video_b = data_json['video']
            video_dec = base64.b64decode(video_b)
            # get text
            texts = data_json['queries']
            infer_results = video_infer.infer(video_dec, texts)
            results = []
            for q in infer_results:
                results.append({"question": q, "probability": infer_results[q][0], "similarity": infer_results[q][1]})
            return Response(response=json.dumps({"results": results}), status=200)
    except NameError:
        print("Skipping create video_class app")

    try:
        @app.route("/audio_class", methods=['POST'])
        def audio_request():
            data = request.data.decode("utf-8")
            data_json = json.loads(data)
            # process audio
            audio_b = data_json['audio']
            sr = data_json['sr']
            audio_dec = base64.b64decode(audio_b)
            # get text
            texts = data_json['queries']
            infer_results = audio_infer.infer({"waveform": audio_dec, "sr": sr}, texts)
            results = []
            for q in infer_results:
                results.append({"question": q, "probability": infer_results[q][0], "similarity": infer_results[q][1]})
            return Response(response=json.dumps({"results": results}), status=200)
    except NameError:
        print("Skipping create audio_class app")

    app.run("0.0.0.0", 8080, threaded=True)
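
To illustrate the JSON contract of the `/inference` route above (a base64-encoded image plus a list of text queries), here is a minimal client sketch assuming the server is reachable on `localhost:8080` and `test.jpg` is any local image (both placeholders):
```shell
curl -s -X POST http://localhost:8080/inference \
  -d "{\"image\": \"$(base64 -w0 test.jpg)\", \"queries\": [\"a photo of a dog\", \"a photo of a cat\"]}"
# returns {"results": [{"question": ..., "probability": ..., "similarity": ...}, ...]}
```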
3 changes: 2 additions & 1 deletion jsk_perception/docker/run_jsk_vil_api
@@ -10,7 +10,8 @@ import subprocess
import sys

CONTAINERS = {"ofa": "jsk-ofa-server",
-              "clip": "jsk-clip-server"}
+              "clip": "jsk-clip-server",
+              "image-bind": "jsk-image-bind-server"}
OFA_MODEL_SCALES = ["base", "large", "huge"]

parser = argparse.ArgumentParser(description="JSK Vision and Language API runner")
4 changes: 3 additions & 1 deletion jsk_perception/launch/classification.launch
@@ -4,16 +4,18 @@
<arg name="port" default="8888" />
<arg name="gui" default="false" />
<arg name="run_api" default="false" />
<arg name="model" default="clip" />
<arg name="CLASSIFICATION_INPUT_IMAGE" default="image" />

<node name="classification_api" pkg="jsk_perception" type="run_jsk_vil_api" output="log"
args="clip -p $(arg port)" if="$(arg run_api)" />
args="(arg model) -p $(arg port)" if="$(arg run_api)" />

<node name="classification" pkg="jsk_perception" type="classification_node.py" output="screen">
<remap from="~image" to="$(arg CLASSIFICATION_INPUT_IMAGE)" />
<rosparam subst_value="true">
host: $(arg host)
port: $(arg port)
model: $(arg model)
</rosparam>
</node>

3 changes: 2 additions & 1 deletion jsk_perception/src/jsk_perception/vil_inference_client.py
@@ -138,6 +138,7 @@ def __init__(self):
            ClassificationTaskFeedback,
            ClassificationTaskResult,
            "inference")
        self.model_name = rospy.get_param("~model", default="clip")

    def topic_cb(self, data):
        if not self.config: rospy.logwarn("No queries"); return
@@ -185,7 +186,7 @@ def inference(self, img_msg, queries):
        msg.label_names = labels
        msg.label_proba = similarities  # cosine similarities
        msg.probabilities = probabilities  # sum(probabilities) is 1
-       msg.classifier = 'clip'
+       msg.classifier = self.model_name
        msg.target_names = queries
        return msg
