From 9987782c1902b5e27700f57db3e60e620986e006 Mon Sep 17 00:00:00 2001
From: Filipe Almeida
Date: Fri, 6 Oct 2023 16:28:46 -0700
Subject: [PATCH 1/2] For inference in model_worker, allow the device to be
 specified via a command line parameter. Right now it has only been tested
 with Apple Silicon devices via the mps device.

---
 README.md                   |  2 ++
 llava/model/builder.py      |  4 ++--
 llava/serve/model_worker.py | 11 +++++++----
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index b49d01cb5..8921dd856 100644
--- a/README.md
+++ b/README.md
@@ -116,6 +116,8 @@ You can launch as many workers as you want, and compare between different model
 python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port --worker http://localhost: --model-path
 ```
 
+I you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the --device flag: `--device mps`.
+
 #### Launch a model worker (Multiple GPUs, when GPU VRAM <= 24GB)
 
 If the VRAM of your GPU is less than 24GB (e.g., RTX 3090, RTX 4090, etc.), you may try running it with multiple GPUs. Our latest code base will automatically try to use multiple GPUs if you have more than one GPU. You can specify which GPUs to use with `CUDA_VISIBLE_DEVICES`. Below is an example of running with the first two GPUs.
diff --git a/llava/model/builder.py b/llava/model/builder.py
index f0eb052d7..46d779365 100644
--- a/llava/model/builder.py
+++ b/llava/model/builder.py
@@ -23,7 +23,7 @@
 from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 
 
-def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
     kwargs = {"device_map": device_map}
 
     if load_8bit:
@@ -137,7 +137,7 @@ def load_from_hf(repo_id, filename, subfolder=None):
         vision_tower = model.get_vision_tower()
         if not vision_tower.is_loaded:
             vision_tower.load_model()
-        vision_tower.to(device='cuda', dtype=torch.float16)
+        vision_tower.to(device=device, dtype=torch.float16)
         image_processor = vision_tower.image_processor
 
     if hasattr(model.config, "max_sequence_length"):
diff --git a/llava/serve/model_worker.py b/llava/serve/model_worker.py
index 4308fde8a..a7bcd0829 100644
--- a/llava/serve/model_worker.py
+++ b/llava/serve/model_worker.py
@@ -45,7 +45,7 @@ class ModelWorker:
     def __init__(self, controller_addr, worker_addr,
                  worker_id, no_register,
                  model_path, model_base, model_name,
-                 load_8bit, load_4bit):
+                 load_8bit, load_4bit, device):
         self.controller_addr = controller_addr
         self.worker_addr = worker_addr
         self.worker_id = worker_id
@@ -60,9 +60,10 @@ def __init__(self, controller_addr, worker_addr,
         else:
             self.model_name = model_name
 
+        self.device = device
         logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
         self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-            model_path, model_base, self.model_name, load_8bit, load_4bit)
+            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
         self.is_multimodal = 'llava' in self.model_name.lower()
 
         if not no_register:
@@ -159,7 +160,7 @@ def generate_stream(self, params):
         stop_str = params.get("stop", None)
         do_sample = True if temperature > 0.001 else False
 
-        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
         keywords = [stop_str]
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
@@ -258,6 +259,7 @@ async def get_status(request: Request):
     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
     parser.add_argument("--model-base", type=str, default=None)
     parser.add_argument("--model-name", type=str)
+    parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
     parser.add_argument("--limit-model-concurrency", type=int, default=5)
     parser.add_argument("--stream-interval", type=int, default=1)
@@ -278,5 +280,6 @@ async def get_status(request: Request):
                          args.model_base,
                          args.model_name,
                          args.load_8bit,
-                         args.load_4bit)
+                         args.load_4bit,
+                         args.device)
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")

From f06186cd49326e9f4923ef641c2d962a238d97fa Mon Sep 17 00:00:00 2001
From: Haotian Liu <6631389+haotian-liu@users.noreply.github.com>
Date: Sat, 7 Oct 2023 15:56:44 -1000
Subject: [PATCH 2/2] Fix typo

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 8921dd856..767875b82 100644
--- a/README.md
+++ b/README.md
@@ -116,7 +116,7 @@ You can launch as many workers as you want, and compare between different model
 python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port --worker http://localhost: --model-path
 ```
 
-I you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the --device flag: `--device mps`.
+If you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the `--device` flag: `--device mps`.
 
 #### Launch a model worker (Multiple GPUs, when GPU VRAM <= 24GB)
 
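For reference, a minimal usage sketch of the new flag once these patches are applied, following the worker-launch command documented in the README hunk above; the port and model path below are illustrative placeholders, not values taken from the patch:

```shell
# Launch a LLaVA model worker on an Apple Silicon machine, running inference
# on the Metal (mps) backend instead of the default CUDA device.
# The port and --model-path value are examples only.
python -m llava.serve.model_worker \
    --host 0.0.0.0 \
    --controller http://localhost:10000 \
    --port 40000 \
    --worker http://localhost:40000 \
    --model-path liuhaotian/llava-v1.5-7b \
    --device mps
```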