Skip to content

Commit 3084283

Browse files
authored
[Router] Support Batch API part 2 (#109)
* feat: Support batch Signed-off-by: Ce Gao <[email protected]> * fix: Add TODO Signed-off-by: Ce Gao <[email protected]> --------- Signed-off-by: Ce Gao <[email protected]>
1 parent aec9264 commit 3084283

File tree

8 files changed

+586
-51
lines changed

8 files changed

+586
-51
lines changed

examples/batch.py

-42
This file was deleted.

examples/openai_api_client_batch.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""
2+
This script uploads JSONL files to the server, which can be used to run
3+
batch inference on the VLLM model.
4+
"""
5+
6+
import argparse
7+
import time
8+
from pathlib import Path
9+
10+
import rich
11+
from openai import OpenAI
12+
13+
# get the current directory
14+
current_dir = Path(__file__).parent
15+
16+
if __name__ == "__main__":
17+
parser = argparse.ArgumentParser(
18+
description="CLI arguments for OpenAI API configuration."
19+
)
20+
parser.add_argument(
21+
"--openai-api-key", type=str, default="NULL", help="Your OpenAI API key"
22+
)
23+
parser.add_argument(
24+
"--openai-api-base",
25+
type=str,
26+
default="http://localhost:8000/v1",
27+
help="Base URL for OpenAI API",
28+
)
29+
parser.add_argument(
30+
"--file-path",
31+
type=str,
32+
default="batch.jsonl",
33+
help="Path to the JSONL file to upload",
34+
)
35+
args = parser.parse_args()
36+
37+
openai_api_key = args.openai_api_key
38+
openai_api_base = args.openai_api_base
39+
40+
# generate this file using `./generate_file.sh`
41+
filepath = current_dir / args.file_path
42+
43+
client = OpenAI(
44+
api_key=openai_api_key,
45+
base_url=openai_api_base,
46+
)
47+
48+
file = client.files.create(
49+
file=filepath.read_bytes(),
50+
purpose="batch",
51+
)
52+
53+
# get the file according to the file id
54+
retrieved = client.files.retrieve(file.id)
55+
print("Retrieved file:")
56+
rich.print(retrieved)
57+
58+
file_content = client.files.content(file.id)
59+
print("File content:")
60+
rich.print(file_content.read().decode())
61+
file_content.close()
62+
63+
# create a batch job
64+
batch = client.batches.create(
65+
input_file_id=file.id,
66+
endpoint="/completions",
67+
completion_window="1h",
68+
)
69+
print("Created batch job:")
70+
rich.print(batch)
71+
72+
# retrieve the batch job
73+
retrieved_batch = client.batches.retrieve(batch.id)
74+
print("Retrieved batch job:")
75+
rich.print(retrieved_batch)
76+
77+
# list all batch jobs
78+
batches = client.batches.list()
79+
print("List of batch jobs:")
80+
rich.print(batches)
81+
82+
# wait for the batch job to complete
83+
while retrieved_batch.status == "pending":
84+
time.sleep(5)
85+
retrieved_batch = client.batches.retrieve(batch.id)
86+
87+
# get the output file content
88+
output_file = client.files.retrieve(retrieved_batch.output_file_id)
89+
print("Output file:")
90+
rich.print(output_file)
91+
92+
output_file_content = client.files.content(output_file.id)
93+
print("Output file content:")
94+
rich.print(output_file_content.read().decode())

src/vllm_router/batch/__init__.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from vllm_router.batch.batch import BatchEndpoint, BatchInfo, BatchRequest, BatchStatus
2+
from vllm_router.batch.processor import BatchProcessor
3+
from vllm_router.files import Storage
4+
5+
6+
def initialize_batch_processor(
    batch_processor_name: str, storage_path: str, storage: Storage
) -> BatchProcessor:
    """Build the batch processor backend selected by name.

    Args:
        batch_processor_name: Backend identifier; only "local" is supported.
        storage_path: Filesystem path handed to the processor.
        storage: File storage used to read inputs and write results.

    Returns:
        A ready-to-use BatchProcessor instance.

    Raises:
        ValueError: If the name does not match a known backend.
    """
    if batch_processor_name != "local":
        raise ValueError(f"Unknown batch processor: {batch_processor_name}")

    # Imported lazily so the backend module is only loaded when selected.
    from vllm_router.batch.local_processor import LocalBatchProcessor

    return LocalBatchProcessor(storage_path, storage)
15+
16+
17+
# Public API of the vllm_router.batch package.
__all__ = [
    "BatchEndpoint",
    "BatchInfo",
    "BatchRequest",
    "BatchStatus",
    "BatchProcessor",
    "initialize_batch_processor",
]

src/vllm_router/batch/batch.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from dataclasses import dataclass
2+
from enum import Enum
3+
from typing import Any, Dict, Optional
4+
5+
6+
class BatchStatus(str, Enum):
    """
    Represents the status of a batch job.

    Values are plain strings (str-based Enum) so members compare equal to
    the raw status strings carried in API responses.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    # States mirrored from the OpenAI batch object; they correspond to
    # the cancelling_at/cancelled_at/expired_at timestamps on BatchInfo.
    CANCELLING = "cancelling"
    CANCELLED = "cancelled"
    EXPIRED = "expired"
15+
16+
17+
class BatchEndpoint(str, Enum):
    """
    Represents the available OpenAI API endpoints for batch requests.

    Ref https://platform.openai.com/docs/api-reference/batch/create#batch-create-endpoint.
    """

    # Fully-qualified API paths as required by the OpenAI batch-create API.
    CHAT_COMPLETION = "/v1/chat/completions"
    EMBEDDING = "/v1/embeddings"
    COMPLETION = "/v1/completions"
27+
28+
29+
@dataclass
class BatchRequest:
    """Represents a single request in a batch"""

    # ID of the uploaded JSONL file containing the request bodies.
    input_file_id: str
    # Target OpenAI-compatible endpoint for every request in the batch.
    endpoint: BatchEndpoint
    # Time window within which the batch should complete (e.g. "1h").
    completion_window: str
    # Optional caller-supplied key/value pairs attached to the batch.
    metadata: Optional[Dict[str, Any]] = None
37+
38+
39+
@dataclass
class BatchInfo:
    """
    Represents batch job information

    Ref https://platform.openai.com/docs/api-reference/batch/object
    """

    # Core identity and configuration of the job.
    id: str
    status: BatchStatus
    input_file_id: str
    created_at: int
    endpoint: str
    completion_window: str
    # Result/error files are only populated once the job produces them.
    output_file_id: Optional[str] = None
    error_file_id: Optional[str] = None
    # Lifecycle timestamps (Unix seconds), filled in as the job advances.
    in_progress_at: Optional[int] = None
    expires_at: Optional[int] = None
    finalizing_at: Optional[int] = None
    completed_at: Optional[int] = None
    failed_at: Optional[int] = None
    expired_at: Optional[int] = None
    cancelling_at: Optional[int] = None
    cancelled_at: Optional[int] = None
    # Per-request progress counters.
    total_requests: Optional[int] = None
    completed_requests: int = 0
    failed_requests: int = 0
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert the instance to a dictionary."""
        # Emit "id" and the flattened status string first, then copy the
        # remaining attributes verbatim in declaration order so the dict
        # mirrors the OpenAI batch object layout.
        serialized: Dict[str, Any] = {
            "id": self.id,
            "status": self.status.value,
        }
        for attr_name in (
            "input_file_id",
            "created_at",
            "endpoint",
            "completion_window",
            "output_file_id",
            "error_file_id",
            "in_progress_at",
            "expires_at",
            "finalizing_at",
            "completed_at",
            "failed_at",
            "expired_at",
            "cancelling_at",
            "cancelled_at",
            "total_requests",
            "completed_requests",
            "failed_requests",
            "metadata",
        ):
            serialized[attr_name] = getattr(self, attr_name)
        return serialized

0 commit comments

Comments (0)