-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbeam_inference_server.py
55 lines (45 loc) · 1.32 KB
/
beam_inference_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os

# Hugging Face cache locations. HF_HOME must be set BEFORE `transformers`
# is imported so the library reads it at import time; this directory is
# backed by the Beam Volume mounted further down in this file.
HF_HOME_DIR = './checkpoints'
CACHE_DIR = './checkpoints/cache'  # explicit cache dir for model downloads
os.environ['HF_HOME'] = HF_HOME_DIR

from beam import Image, endpoint, Volume
from transformers import pipeline
from PIL import Image as PILImage
from io import BytesIO
import base64
# This function runs once when the container first starts (wired up via
# `on_start` below); its return value reaches the endpoint handler through
# `context.on_start_value`.
def load_models():
    """Load the OWLv2 zero-shot object-detection pipeline onto the GPU.

    Returns:
        A transformers ``zero-shot-object-detection`` pipeline for the
        ``google/owlv2-base-patch16-ensemble`` checkpoint, placed on cuda:0.
    """
    checkpoint = "google/owlv2-base-patch16-ensemble"
    # FIX: `cache_dir` is not a top-level `pipeline()` argument — model-loading
    # options must be routed through `model_kwargs`, otherwise the setting is
    # silently ignored and only the HF_HOME env var controls the cache location.
    detector = pipeline(
        model=checkpoint,
        task="zero-shot-object-detection",
        device='cuda:0',
        model_kwargs={"cache_dir": CACHE_DIR},
    )
    return detector
@endpoint(
    name="owlv2",
    cpu=1,
    memory="8Gi",
    # Either GPU type is acceptable; Beam schedules whichever is available.
    gpu=["T4", "RTX4090"],
    image=Image(
        python_version="python3.10",
        python_packages=[
            "pillow==11.0.0",
            "scikit-learn==1.5.2",
            "scikit-image==0.24.0",
            "transformers==4.45.2",
            "torch==2.5.0",
        ],
    ),
    volumes=[
        # checkpoints is used to save fine-tuned models
        Volume(name="owlv2-checkpoints", mount_path=HF_HOME_DIR),
    ],
    # `load_models` runs once at container start; its return value is exposed
    # to the handler as `context.on_start_value`.
    on_start=load_models,
    # Keep the container warm for 5 minutes after the last request.
    keep_warm_seconds=300
)
def predict_owlv2(context, base64_image, labels):
    """Run OWLv2 zero-shot object detection on a base64-encoded image.

    Args:
        context: Beam invocation context; ``context.on_start_value`` holds
            the detector pipeline returned by ``load_models``.
        base64_image: The input image as a base64-encoded string.
        labels: Candidate object labels for zero-shot detection
            (presumably a list of strings — verify against callers).

    Returns:
        A dict with a single ``"predictions"`` key holding the raw
        pipeline output.
    """
    detector = context.on_start_value
    # Decode the base64 payload into a PIL image entirely in memory.
    im = PILImage.open(BytesIO(base64.b64decode(base64_image)))
    predictions = detector(
        im,
        labels,
    )
    print(predictions)  # NOTE(review): debug print of raw predictions
    return { "predictions": predictions }