
Commit 6a781ea

chipro and Lyken17 authored
Add VILA API server which is compatible with OpenAI SDK (NVlabs#133)
Co-authored-by: Ligeng Zhu <[email protected]>
1 parent 3710e28 commit 6a781ea

File tree

6 files changed (+404, -6 lines)


.gitignore

+2

@@ -49,3 +49,5 @@ ckpts*
 
 playground
 */visualization/*
+.env
+server.log

Dockerfile

+18 (new file)

@@ -0,0 +1,18 @@
FROM nvcr.io/nvidia/pytorch:24.06-py3

WORKDIR /app

RUN curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o ~/miniconda.sh \
    && sh ~/miniconda.sh -b -p /opt/conda \
    && rm ~/miniconda.sh

ENV PATH /opt/conda/bin:$PATH
COPY pyproject.toml pyproject.toml
COPY llava llava

COPY environment_setup.sh environment_setup.sh
RUN bash environment_setup.sh vila


COPY server.py server.py
CMD ["conda", "run", "-n", "vila", "--no-capture-output", "python", "-u", "-W", "ignore", "server.py"]

README.md

+67, -6
@@ -57,16 +57,16 @@ VILA is a visual language model (VLM) pretrained with interleaved image-text dat
(The rows removed and re-added in this hunk differ only in table-column alignment; the table below shows the resulting state.)

| $~~~~~~$               | Precision | A100  | 4090  | Orin |
| ---------------------- | --------- | ----- | ----- | ---- |
| VILA1.5-3B             | fp16      | 104.6 | 137.6 | 25.4 |
| VILA1.5-3B-AWQ         | int4      | 182.8 | 215.5 | 42.5 |
| VILA1.5-3B-S2          | fp16      | 104.3 | 137.2 | 24.6 |
| VILA1.5-3B-S2-AWQ      | int4      | 180.2 | 219.3 | 40.1 |
| Llama-3-VILA1.5-8B     | fp16      | 74.9  | 57.4  | 10.2 |
| Llama-3-VILA1.5-8B-AWQ | int4      | 168.9 | 150.2 | 28.7 |
| VILA1.5-13B            | fp16      | 50.9  | OOM   | 6.1  |
| VILA1.5-13B-AWQ        | int4      | 115.9 | 105.7 | 20.6 |
| VILA1.5-40B            | fp16      | OOM   | OOM   | --   |
| VILA1.5-40B-AWQ        | int4      | 57.0  | OOM   | --   |

<sup>NOTE: Measured using the [TinyChat](https://github.com/mit-han-lab/llm-awq/tinychat) backend at batch size = 1.</sup>

@@ -232,6 +232,67 @@ We support AWQ-quantized 4bit VILA on GPU platforms via [TinyChat](https://githu

We further support our AWQ-quantized 4bit VILA models on various CPU platforms with both x86 and ARM architectures with our [TinyChatEngine](https://github.com/mit-han-lab/TinyChatEngine). We also provide a detailed [tutorial](https://github.com/mit-han-lab/TinyChatEngine/tree/main?tab=readme-ov-file#deploy-vision-language-model-vlm-chatbot-with-tinychatengine) to help the users deploy VILA on different CPUs.

### Running the VILA API server

A simple API server is provided to serve VILA models. The server is built on top of [FastAPI](https://fastapi.tiangolo.com/) and [Huggingface Transformers](https://huggingface.co/transformers/). It can be started in either of the following ways:

#### With CLI

```bash
python -W ignore server.py \
    --port 8000 \
    --model-path Efficient-Large-Model/VILA1.5-3B \
    --conv-mode vicuna_v1
```

#### With Docker

```bash
docker build -t vila-server:latest .
docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
    -v ./hub:/root/.cache/huggingface/hub \
    -it --rm -p 8000:8000 \
    -e VILA_MODEL_PATH=Efficient-Large-Model/VILA1.5-3B \
    -e VILA_CONV_MODE=vicuna_v1 \
    vila-server:latest
```

Then you can call the endpoint with the OpenAI SDK as follows:

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000",
    api_key="fake-key",
)
response = client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What’s in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://blog.logomyway.com/wp-content/uploads/2022/01/NVIDIA-logo.jpg",
                        # Or you can pass in a base64-encoded image
                        # "url": "data:image/png;base64,<base64_encoded_image>",
                    },
                },
            ],
        }
    ],
    max_tokens=300,
    model="VILA1.5-3B",
    # You can pass in extra parameters as follows
    extra_body={"num_beams": 1, "use_cache": False},
)
print(response.choices[0].message.content)
```

<sup>NOTE: This API server is intended for evaluation purposes only and has not been optimized for production use. It has only been tested on A100 and H100 GPUs.</sup>
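The server also accepts images supplied inline as `data:` URLs (PNG or JPEG), as noted in the commented-out line of the example above. Below is a minimal sketch of that variant, not part of this commit, assuming a hypothetical local file `example.png` and the same server settings as above:

```python
import base64

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000", api_key="fake-key")

# Encode a local image as a data URL; the server accepts png and jpeg payloads.
with open("example.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

response = client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
            ],
        }
    ],
    max_tokens=300,
    model="VILA1.5-3B",
)
print(response.choices[0].message.content)
```

The data URL must match the `data:image/(png|jpe?g);base64,...` pattern that the server validates before decoding.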
## Checkpoints

We release [VILA1.5-3B](https://hf.co/Efficient-Large-Model/VILA1.5-3b), [VILA1.5-3B-S2](https://hf.co/Efficient-Large-Model/VILA1.5-3b-s2), [Llama-3-VILA1.5-8B](https://hf.co/Efficient-Large-Model/Llama-3-VILA1.5-8b), [VILA1.5-13B](https://hf.co/Efficient-Large-Model/VILA1.5-13b), [VILA1.5-40B](https://hf.co/Efficient-Large-Model/VILA1.5-40b) and the 4-bit [AWQ](https://arxiv.org/abs/2306.00978)-quantized models [VILA1.5-3B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-3b-AWQ), [VILA1.5-3B-S2-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-3b-s2-AWQ), [Llama-3-VILA1.5-8B-AWQ](https://hf.co/Efficient-Large-Model/Llama-3-VILA1.5-8b-AWQ), [VILA1.5-13B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-13b-AWQ), [VILA1.5-40B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-40b-AWQ).

server.py

+261 (new file)

@@ -0,0 +1,261 @@
import argparse
import base64
import os
import re
import time
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
from typing import List, Literal, Optional, Union, get_args

import requests
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from PIL import Image as PILImage
from PIL.Image import Image
from pydantic import BaseModel

from llava.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_PLACEHOLDER,
    IMAGE_TOKEN_INDEX,
)
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init


class TextContent(BaseModel):
    type: Literal["text"]
    text: str


class ImageURL(BaseModel):
    url: str


class ImageContent(BaseModel):
    type: Literal["image_url"]
    image_url: ImageURL


IMAGE_CONTENT_BASE64_REGEX = re.compile(r"^data:image/(png|jpe?g);base64,(.*)$")


class ChatMessage(BaseModel):
    role: Literal["user", "assistant"]
    content: Union[str, List[Union[TextContent, ImageContent]]]


class ChatCompletionRequest(BaseModel):
    model: Literal[
        "VILA1.5-3B",
        "VILA1.5-3B-AWQ",
        "VILA1.5-3B-S2",
        "VILA1.5-3B-S2-AWQ",
        "Llama-3-VILA1.5-8B",
        "Llama-3-VILA1.5-8B-AWQ",
        "VILA1.5-13B",
        "VILA1.5-13B-AWQ",
        "VILA1.5-40B",
        "VILA1.5-40B-AWQ",
    ]
    messages: List[ChatMessage]
    max_tokens: Optional[int] = 512
    top_p: Optional[float] = 0.9
    temperature: Optional[float] = 0.2
    stream: Optional[bool] = False
    use_cache: Optional[bool] = True
    num_beams: Optional[int] = 1

model = None
model_name = None
tokenizer = None
image_processor = None
context_len = None


def load_image(image_url: str) -> Image:
    if image_url.startswith("http") or image_url.startswith("https"):
        response = requests.get(image_url)
        image = PILImage.open(BytesIO(response.content)).convert("RGB")
    else:
        match_results = IMAGE_CONTENT_BASE64_REGEX.match(image_url)
        if match_results is None:
            raise ValueError(f"Invalid image url: {image_url}")
        image_base64 = match_results.groups()[1]
        image = PILImage.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
    return image


def get_literal_values(cls, field_name: str):
    field_type = cls.__annotations__.get(field_name)
    if field_type is None:
        raise ValueError(f"{field_name} is not a valid field name")
    if hasattr(field_type, "__origin__") and field_type.__origin__ is Literal:
        return get_args(field_type)
    raise ValueError(f"{field_name} is not a Literal type")


VILA_MODELS = get_literal_values(ChatCompletionRequest, "model")


def normalize_image_tags(qs: str) -> str:
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    if IMAGE_PLACEHOLDER in qs:
        if model.config.mm_use_im_start_end:
            qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
        else:
            qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)

    if DEFAULT_IMAGE_TOKEN not in qs:
        raise ValueError("No image was found in input messages.")
    return qs


@asynccontextmanager
async def lifespan(app: FastAPI):
    global model, model_name, tokenizer, image_processor, context_len
    disable_torch_init()
    model_path = app.args.model_path
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_name, None
    )
    print(f"Model {model_name} loaded successfully. Context length: {context_len}")
    yield


app = FastAPI(lifespan=lifespan)


# Load model upon startup
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    try:
        global model, tokenizer, image_processor, context_len

        if request.model != model_name:
            raise ValueError(
                f"The endpoint is configured to use the model {model_name}, "
                f"but the request model is {request.model}"
            )
        max_tokens = request.max_tokens
        temperature = request.temperature
        top_p = request.top_p
        use_cache = request.use_cache
        num_beams = request.num_beams

        messages = request.messages
        conv_mode = app.args.conv_mode

        images = []

        conv = conv_templates[conv_mode].copy()
        user_role = conv.roles[0]
        assistant_role = conv.roles[1]

        for message in messages:
            if message.role == "user":
                prompt = ""

                if isinstance(message.content, str):
                    prompt += message.content
                if isinstance(message.content, list):
                    for content in message.content:
                        if content.type == "text":
                            prompt += content.text
                        if content.type == "image_url":
                            image = load_image(content.image_url.url)
                            images.append(image)
                            prompt += IMAGE_PLACEHOLDER
                normalized_prompt = normalize_image_tags(prompt)
                conv.append_message(user_role, normalized_prompt)
            if message.role == "assistant":
                prompt = message.content
                conv.append_message(assistant_role, prompt)

        prompt_text = conv.get_prompt()
        print("Prompt input: ", prompt_text)

        images_tensor = process_images(images, image_processor, model.config).to(
            model.device, dtype=torch.float16
        )
        input_ids = (
            tokenizer_image_token(
                prompt_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            )
            .unsqueeze(0)
            .to(model.device)
        )

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=[
                    images_tensor,
                ],
                do_sample=True if temperature > 0 else False,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                max_new_tokens=max_tokens,
                use_cache=use_cache,
                stopping_criteria=[stopping_criteria],
            )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)]
        outputs = outputs.strip()
        print("\nAssistant: ", outputs)

        resp_content = [TextContent(type="text", text=outputs)]
        return {
            "id": uuid.uuid4().hex,
            "object": "chat.completion",
            "created": time.time(),
            "model": request.model,
            "choices": [
                {"message": ChatMessage(role="assistant", content=resp_content)}
            ],
        }
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )


if __name__ == "__main__":

    host = os.getenv("VILA_HOST", "0.0.0.0")
    port = os.getenv("VILA_PORT", 8000)
    model_path = os.getenv("VILA_MODEL_PATH", "Efficient-Large-Model/VILA1.5-3B")
    conv_mode = os.getenv("VILA_CONV_MODE", "vicuna_v1")
    workers = os.getenv("VILA_WORKERS", 1)

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=host)
    parser.add_argument("--port", type=int, default=port)
    parser.add_argument("--model-path", type=str, default=model_path)
    parser.add_argument("--conv-mode", type=str, default=conv_mode)
    parser.add_argument("--workers", type=int, default=workers)
    app.args = parser.parse_args()

    uvicorn.run(app, host=host, port=port, workers=workers)
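For reference, a minimal sketch (not part of this commit) of exercising the `/chat/completions` endpoint directly over HTTP with `requests`, mirroring the `ChatCompletionRequest` schema and the response shape returned by `chat_completions`. The server address assumes the default `http://localhost:8000`, and the image URL is the one used in the README example:

```python
import requests

# Request body mirroring the ChatCompletionRequest model defined in server.py.
payload = {
    "model": "VILA1.5-3B",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://blog.logomyway.com/wp-content/uploads/2022/01/NVIDIA-logo.jpg"
                    },
                },
            ],
        }
    ],
    "max_tokens": 300,
    "temperature": 0.2,
    "top_p": 0.9,
    "num_beams": 1,
    "use_cache": True,
}

resp = requests.post("http://localhost:8000/chat/completions", json=payload)
resp.raise_for_status()

# The server returns the assistant message content as a list of text parts.
print(resp.json()["choices"][0]["message"]["content"][0]["text"])
```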

0 commit comments
