diff --git a/chatlearn/models/vllm/hooks/__init__.py b/chatlearn/models/vllm/hooks/__init__.py
index 2a4036b..9bedfe5 100644
--- a/chatlearn/models/vllm/hooks/__init__.py
+++ b/chatlearn/models/vllm/hooks/__init__.py
@@ -25,6 +25,8 @@ from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
 
 if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
     from chatlearn.models.vllm.hooks import input_preprocess
+    from chatlearn.models.vllm.hooks import async_llm_engine
+    from chatlearn.models.vllm.hooks import llm
     from chatlearn.models.vllm.hooks import loader
 else:
     if importlib.util.find_spec("vllm"):
diff --git a/chatlearn/models/vllm/hooks/async_llm_engine.py b/chatlearn/models/vllm/hooks/async_llm_engine.py
new file mode 100644
index 0000000..4542824
--- /dev/null
+++ b/chatlearn/models/vllm/hooks/async_llm_engine.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Hooks of vllm-0.6.3 del init_ray_cluster in AsyncLLMEngine."""
+
+from typing import Dict, Optional
+
+# pylint: disable=unused-import,wildcard-import,unused-argument,not-callable
+from vllm.config import EngineConfig
+from vllm.engine import async_llm_engine
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.metrics_types import StatLoggerBase
+from vllm.usage.usage_lib import UsageContext
+
+@classmethod
+def from_engine_args(
+    cls,
+    engine_args: AsyncEngineArgs,
+    engine_config: Optional[EngineConfig] = None,
+    start_engine_loop: bool = True,
+    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+    stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+) -> "AsyncLLMEngine":
+    """Creates an async LLM engine from the engine arguments."""
+    # Create the engine configs.
+    if engine_config is None:
+        engine_config = engine_args.create_engine_config()
+
+    executor_class = cls._get_executor_cls(engine_config)
+
+    # Create the async LLM engine.
+    engine = cls(
+        **engine_config.to_dict(),
+        executor_class=executor_class,
+        log_requests=not engine_args.disable_log_requests,
+        log_stats=not engine_args.disable_log_stats,
+        start_engine_loop=start_engine_loop,
+        usage_context=usage_context,
+        stat_loggers=stat_loggers,
+    )
+    return engine
+
+async_llm_engine.AsyncLLMEngine.from_engine_args = from_engine_args
diff --git a/chatlearn/models/vllm/hooks/llm.py b/chatlearn/models/vllm/hooks/llm.py
new file mode 100644
index 0000000..8824e94
--- /dev/null
+++ b/chatlearn/models/vllm/hooks/llm.py
@@ -0,0 +1,87 @@
+# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Hooks of vllm-0.6.3 llm init with AsyncLLMEngine and AsyncEngineArgs."""
+
+from typing import Any, Dict, Optional
+
+# pylint: disable=unused-import,wildcard-import,unused-argument
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints import llm
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import Counter
+
+def init(
+    self,
+    model: str,
+    tokenizer: Optional[str] = None,
+    tokenizer_mode: str = "auto",
+    skip_tokenizer_init: bool = False,
+    trust_remote_code: bool = False,
+    tensor_parallel_size: int = 1,
+    dtype: str = "auto",
+    quantization: Optional[str] = None,
+    revision: Optional[str] = None,
+    tokenizer_revision: Optional[str] = None,
+    seed: int = 0,
+    gpu_memory_utilization: float = 0.9,
+    swap_space: float = 4,
+    cpu_offload_gb: float = 0,
+    enforce_eager: Optional[bool] = None,
+    max_context_len_to_capture: Optional[int] = None,
+    max_seq_len_to_capture: int = 8192,
+    disable_custom_all_reduce: bool = False,
+    disable_async_output_proc: bool = False,
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+    **kwargs,
+) -> None:
+    '''
+    LLM constructor.
+
+    Note: if enforce_eager is unset (enforce_eager is None)
+    it defaults to False.
+    '''
+    if "disable_log_stats" not in kwargs:
+        kwargs["disable_log_stats"] = True
+
+    engine_args = AsyncEngineArgs(
+        model=model,
+        tokenizer=tokenizer,
+        tokenizer_mode=tokenizer_mode,
+        skip_tokenizer_init=skip_tokenizer_init,
+        trust_remote_code=trust_remote_code,
+        tensor_parallel_size=tensor_parallel_size,
+        dtype=dtype,
+        quantization=quantization,
+        revision=revision,
+        tokenizer_revision=tokenizer_revision,
+        seed=seed,
+        gpu_memory_utilization=gpu_memory_utilization,
+        swap_space=swap_space,
+        cpu_offload_gb=cpu_offload_gb,
+        enforce_eager=enforce_eager,
+        max_context_len_to_capture=max_context_len_to_capture,
+        max_seq_len_to_capture=max_seq_len_to_capture,
+        disable_custom_all_reduce=disable_custom_all_reduce,
+        disable_async_output_proc=disable_async_output_proc,
+        mm_processor_kwargs=mm_processor_kwargs,
+        **kwargs,
+    )
+
+    self.llm_engine = AsyncLLMEngine.from_engine_args(
+        engine_args, usage_context=UsageContext.LLM_CLASS).engine
+    self.request_counter = Counter()
+
+llm.LLM.__init__ = init
diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py
index 2daf2cc..8da1c9a 100644
--- a/chatlearn/models/vllm_module_v2.py
+++ b/chatlearn/models/vllm_module_v2.py
@@ -14,22 +14,15 @@
 # ==============================================================================
 """VLLM module v2"""
 
-import asyncio
 import inspect
 import os
-import sys
-from typing import Optional
 
 import torch
 from transformers import AutoTokenizer
 from vllm import SamplingParams
-from vllm.config import EngineConfig
 from vllm.config import LoadFormat
+from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import RayWorkerWrapper
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser
 
 from chatlearn.utils.global_vars import set_vllm_actors
 from chatlearn.utils.vllm_import_helper import TextTokensPrompt
@@ -46,7 +39,7 @@ def __init__(self, *args, **kwargs):
         if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
             RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
             os.environ['VLLM_HOST_IP'] = self.get_address()
-        self.engine = None
+        self.llm_engine = None
         self.tokenizer = None
 
     def setup(self):
@@ -56,87 +49,47 @@ def setup(self):
         tokenizer.tokenizer = tokenizer
         self.tokenizer = tokenizer
 
-    def _init_args(self, args):
-        # scheduler config
-        args.max_num_seqs = self.module_args.generation_batch_size
-        args.max_num_batched_tokens = self.model_args.get("max_num_batched_tokens")
-        args.num_scheduler_steps = self.model_args.get("num_scheduler_steps", 1)
-
-        # model config
-        args.max_seq_len = self.model_args.get("seq_length")
-
-        # logger config
-        args.disable_log_requests = True
-
-        # load format: 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
-        args.load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY)
-        if args.load_format == LoadFormat.DUMMY:
-            args.model_loader_extra_config = self.model_args
-            self.model_args["need_load_ckpt"] = self.src_parameter_model is None
-
-        # engine config
-        args.enforce_eager = self.model_args.get("enforce_eager", False)
-
     def setup_vllm(self, workers):
         # setup vllm engine in rank 0
         os.environ['VLLM_HOST_IP'] = self.get_address()
         set_vllm_actors(workers)
-        parser = FlexibleArgumentParser()
-        parser = AsyncEngineArgs.add_cli_args(parser)
-        backup_sys_argv = sys.argv
+
         dtype = self.model_args.get("dtype", "bfloat16")
         if self.model_args.get("fp16", False):
            dtype = "float16"
-        vllm_sys_argv = ["",
-                         f"--model={self.model_args['tokenizer']}",
-                         f"--tensor_parallel_size={self.module_args.tensor_model_parallel_size}",
-                         f"--pipeline_parallel_size={self.module_args.pipeline_model_parallel_size}",
-                         f"--dtype={dtype}",
-                         "--worker_use_ray",
-                         "--disable_custom_all_reduce"]
-        sys.argv = vllm_sys_argv
-        args = parser.parse_args()
-        self._init_args(args)
-        engine_args = AsyncEngineArgs.from_cli_args(args)
-        self.engine = self.from_engine_args(engine_args)
-
-        sys.argv = backup_sys_argv
-        self.tokenizer = self.engine.engine.tokenizer
-
-    def from_engine_args(
-        self,
-        engine_args: AsyncEngineArgs,
-        engine_config: Optional[EngineConfig] = None,
-        start_engine_loop: bool = True,
-        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers = None,
-    ) -> "AsyncLLMEngine":
-        """Creates an async LLM engine from the engine arguments."""
-        # Create the engine configs.
-
-        if engine_config is None:
-            engine_config = engine_args.create_engine_config()
-
-        executor_class = AsyncLLMEngine._get_executor_cls(engine_config)
-
-        # Create the async LLM engine.
-        engine = AsyncLLMEngine(
-            **engine_config.to_dict(),
-            executor_class=executor_class,
-            log_requests=not engine_args.disable_log_requests,
-            log_stats=not engine_args.disable_log_stats,
-            start_engine_loop=start_engine_loop,
-            usage_context=usage_context,
-            stat_loggers=stat_loggers,
-        )
-        return engine
-
-    async def generate_one_sample(self, prompt, sampling_param, request_id):
-        results_generator = self.engine.generate(prompt, sampling_param, request_id)
-        final_output = None
-        async for request_output in results_generator:
-            final_output = request_output
-        return final_output
+        load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY)
+        if load_format == LoadFormat.DUMMY:
+            self.model_args["need_load_ckpt"] = self.src_parameter_model is None
+            model_loader_extra_config = self.model_args
+        else:
+            model_loader_extra_config = None
+
+        self.llm = LLM(
+            model=self.model_args['tokenizer'],
+            tokenizer=self.model_args['tokenizer'],
+            max_seq_len_to_capture=self.model_args.get("seq_length"),
+            # load model: 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
+            load_format=load_format,
+            model_loader_extra_config=model_loader_extra_config,
+            # parallelism strategy
+            tensor_parallel_size=self.module_args.tensor_model_parallel_size,
+            pipeline_parallel_size=self.module_args.pipeline_model_parallel_size,
+            dtype=dtype,
+            # scheduling strategy
+            max_num_seqs=self.module_args.generation_batch_size,
+            max_num_batched_tokens = self.model_args.get("max_num_batched_tokens", None),
+            num_scheduler_steps=self.model_args.get("num_scheduler_steps", 1),
+            gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90),
+            # logger
+            disable_log_requests=self.model_args.get("disable_log_requests", True),
+            disable_log_stats=self.model_args.get("disable_log_stats", True),
+            trust_remote_code=True,
+            # TODO(jiangle.jl): support non-eager mode.
+            enforce_eager=True,
+            disable_custom_all_reduce=True,
+            distributed_executor_backend="ray")
+        self.tokenizer = self.llm.llm_engine.tokenizer
 
     def _get_sampling_params(self, is_eval):
         temperature = 0.0
@@ -173,7 +126,7 @@ def _get_sampling_params(self, is_eval):
             sampling_params.use_beam_search = self.model_args.get("use_beam_search")
         return sampling_params
 
-    def convert_v1_inputs(self, prompts, prompt_token_ids):
+    def _convert_v1_inputs(self, prompts, prompt_token_ids):
         num_requests = len(prompts)
         assert num_requests == len(prompt_token_ids), \
             ("The lengths of prompts and prompt_token_ids must be the same.")
@@ -201,9 +154,9 @@ async def generate_vllm(self, query, is_eval):
         prompts_token_ids = query[input_ids_key]
         seq_len = self.model_args.get("seq_length")
         final_outputs = []
-        tasks = []
+        parsed_prompts = []
+        sampling_params = []
         for i, prompt in enumerate(prompts):
-            request_id = i
             prompt_token_ids = prompts_token_ids[i]
             if 'sampling_param' in query:
                 sampling_param = query['sampling_param'][i]
@@ -215,14 +168,18 @@ async def generate_vllm(self, query, is_eval):
                 max_tokens = self.model_args.get("max_new_tokens")
                 assert max_tokens < seq_len, "max_new_tokens must less than seq length."
                 sampling_param.max_tokens = max_tokens
-            inputs = self.convert_v1_inputs(
+            item = self._convert_v1_inputs(
                 prompts=[prompt],
                 prompt_token_ids=[prompt_token_ids],
             )[0]
+            parsed_prompts.append(item)
+            sampling_params.append(sampling_param)
 
-            task = asyncio.create_task(self.generate_one_sample(inputs, sampling_param, request_id))
-            tasks.append(task)
-        outputs = await asyncio.gather(*tasks)
+        outputs = self.llm.generate(
+            parsed_prompts,
+            sampling_params,
+            use_tqdm=True,
+        )
         final_outputs = sorted(outputs, key=lambda x: int(x.request_id))
         return final_outputs
 
@@ -253,7 +210,7 @@ def __init__(self, *args, **kwargs):
         if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
             RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
             os.environ['VLLM_HOST_IP'] = self.get_address()
-        self.engine = None
+        self.llm_engine = None
 
     def peak_memory(self):
         """
diff --git a/chatlearn/runtime/decorator.py b/chatlearn/runtime/decorator.py
index 0291ee7..56d8997 100644
--- a/chatlearn/runtime/decorator.py
+++ b/chatlearn/runtime/decorator.py
@@ -164,7 +164,7 @@ def get_kwarg(key):
             # for model with DP/EP, we need to return results from all ranks
             # for model with TP/PP, only return the results from last rank
             if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \
-                or isinstance(self, VLLMModuleV2):
+                    or isinstance(self, VLLMModuleV2):
                 final_results = concat_along_batch(results)
             else:
                 if 'iteration' in inspect.signature(func).parameters:
@@ -176,7 +176,7 @@ def get_kwarg(key):
             # for model with DP/EP, we need to return results from all ranks
             # for model with TP/PP, only return the results from last rank
             if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \
-                or isinstance(self, VLLMModuleV2):
+                    or isinstance(self, VLLMModuleV2):
                 final_results = ret
             else:
                 if 'iteration' in inspect.signature(func).parameters: