Merge pull request #479 from filipe-m-almeida/mps
Support Apple Silicon devices for inference in model_worker.
haotian-liu authored Oct 8, 2023
2 parents cec3511 + f06186c commit 0a7f494
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -116,6 +116,8 @@ You can launch as many workers as you want, and compare between different model
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port <different from 40000, say 40001> --worker http://localhost:<change accordingly, i.e. 40001> --model-path <ckpt2>
```

If you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the `--device` flag: `--device mps`.
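For example, a worker on an M1/M2 machine can be launched with a command like `python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port <port> --worker http://localhost:<port> --model-path <ckpt> --device mps` (the port and checkpoint path are placeholders).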

#### Launch a model worker (Multiple GPUs, when GPU VRAM <= 24GB)

If the VRAM of your GPU is less than 24GB (e.g., RTX 3090, RTX 4090), you can run the model across multiple GPUs. The latest code base automatically uses multiple GPUs when more than one is available; you can specify which GPUs to use with `CUDA_VISIBLE_DEVICES`. Below is an example of running with the first two GPUs.
4 changes: 2 additions & 2 deletions llava/model/builder.py
@@ -23,7 +23,7 @@
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
kwargs = {"device_map": device_map}

if load_8bit:
@@ -137,7 +137,7 @@ def load_from_hf(repo_id, filename, subfolder=None):
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model()
vision_tower.to(device='cuda', dtype=torch.float16)
vision_tower.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor

if hasattr(model.config, "max_sequence_length"):
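With this change, `load_pretrained_model` takes an explicit `device` argument (default `"cuda"`) and moves the vision tower to that device instead of hard-coding CUDA. A minimal sketch of calling the updated function on Apple Silicon; the checkpoint path and model name below are illustrative placeholders:

```python
from llava.model.builder import load_pretrained_model

# Placeholder checkpoint; substitute the model you actually serve.
model_path = "liuhaotian/llava-v1.5-7b"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name="llava-v1.5-7b",
    device="mps",  # new keyword added by this commit; defaults to "cuda"
)
```

This is the same call path the model worker takes once `--device mps` is passed on the command line.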
11 changes: 7 additions & 4 deletions llava/serve/model_worker.py
@@ -45,7 +45,7 @@ class ModelWorker:
def __init__(self, controller_addr, worker_addr,
worker_id, no_register,
model_path, model_base, model_name,
load_8bit, load_4bit):
load_8bit, load_4bit, device):
self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
@@ -60,9 +60,10 @@ def __init__(self, controller_addr, worker_addr,
else:
self.model_name = model_name

self.device = device
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path, model_base, self.model_name, load_8bit, load_4bit)
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
self.is_multimodal = 'llava' in self.model_name.lower()

if not no_register:
@@ -159,7 +160,7 @@ def generate_stream(self, params):
stop_str = params.get("stop", None)
do_sample = True if temperature > 0.001 else False

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
@@ -258,6 +259,7 @@ async def get_status(request: Request):
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--model-name", type=str)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
parser.add_argument("--limit-model-concurrency", type=int, default=5)
parser.add_argument("--stream-interval", type=int, default=1)
@@ -278,5 +280,6 @@ async def get_status(request: Request):
args.model_base,
args.model_name,
args.load_8bit,
args.load_4bit)
args.load_4bit,
args.device)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
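The worker forwards `--device` verbatim to `load_pretrained_model` and to the `.to(self.device)` calls, so the value must be a device string PyTorch accepts (`cuda`, `mps`, or `cpu`). A small, hypothetical helper, not part of this commit, that a launcher script could use to choose a safe value:

```python
import torch

def pick_device(preferred: str = "cuda") -> str:
    """Return `preferred` if that backend is usable, otherwise fall back to CPU."""
    if preferred == "cuda" and torch.cuda.is_available():
        return "cuda"
    if preferred == "mps" and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

# On an Apple Silicon Mac with a recent PyTorch build this prints "mps".
print(pick_device("mps"))
```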
