support llm.generate in vllm_module_v2. (#185)
charles9304 authored Dec 24, 2024
1 parent 36c2794 commit bd75b69
Showing 5 changed files with 192 additions and 92 deletions.
2 changes: 2 additions & 0 deletions chatlearn/models/vllm/hooks/__init__.py
@@ -25,6 +25,8 @@
from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
from chatlearn.models.vllm.hooks import input_preprocess
from chatlearn.models.vllm.hooks import async_llm_engine
from chatlearn.models.vllm.hooks import llm
from chatlearn.models.vllm.hooks import loader
else:
if importlib.util.find_spec("vllm"):
54 changes: 54 additions & 0 deletions chatlearn/models/vllm/hooks/async_llm_engine.py
@@ -0,0 +1,54 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.6.3 del init_ray_cluster in AsyncLLMEngine."""

from typing import Dict, Optional

# pylint: disable=unused-import,wildcard-import,unused-argument,not-callable
from vllm.config import EngineConfig
from vllm.engine import async_llm_engine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.usage.usage_lib import UsageContext

@classmethod
def from_engine_args(
    cls,
    engine_args: AsyncEngineArgs,
    engine_config: Optional[EngineConfig] = None,
    start_engine_loop: bool = True,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
    """Creates an async LLM engine from the engine arguments."""
    # Create the engine configs.
    if engine_config is None:
        engine_config = engine_args.create_engine_config()

    executor_class = cls._get_executor_cls(engine_config)

    # Create the async LLM engine.
    engine = cls(
        **engine_config.to_dict(),
        executor_class=executor_class,
        log_requests=not engine_args.disable_log_requests,
        log_stats=not engine_args.disable_log_stats,
        start_engine_loop=start_engine_loop,
        usage_context=usage_context,
        stat_loggers=stat_loggers,
    )
    return engine

async_llm_engine.AsyncLLMEngine.from_engine_args = from_engine_args
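
The sketch below is not part of this commit; it only illustrates how the hook takes effect. Importing the hooks module patches AsyncLLMEngine.from_engine_args in place, so building an engine afterwards skips vLLM's own Ray cluster initialization and leaves actor placement to ChatLearn. The model name is illustrative, and vLLM 0.6.3 is assumed to be installed.

# illustrative usage sketch, not part of this commit
from chatlearn.models.vllm.hooks import async_llm_engine  # side-effect import applies the patch
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True)
# The patched classmethod builds the engine from the engine config directly,
# without calling initialize_ray_cluster.
engine = AsyncLLMEngine.from_engine_args(engine_args)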
87 changes: 87 additions & 0 deletions chatlearn/models/vllm/hooks/llm.py
@@ -0,0 +1,87 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.6.3 llm init with AsyncLLMEngine and AsyncEngineArgs."""

from typing import Any, Dict, Optional

# pylint: disable=unused-import,wildcard-import,unused-argument
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints import llm
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter

def init(
    self,
    model: str,
    tokenizer: Optional[str] = None,
    tokenizer_mode: str = "auto",
    skip_tokenizer_init: bool = False,
    trust_remote_code: bool = False,
    tensor_parallel_size: int = 1,
    dtype: str = "auto",
    quantization: Optional[str] = None,
    revision: Optional[str] = None,
    tokenizer_revision: Optional[str] = None,
    seed: int = 0,
    gpu_memory_utilization: float = 0.9,
    swap_space: float = 4,
    cpu_offload_gb: float = 0,
    enforce_eager: Optional[bool] = None,
    max_context_len_to_capture: Optional[int] = None,
    max_seq_len_to_capture: int = 8192,
    disable_custom_all_reduce: bool = False,
    disable_async_output_proc: bool = False,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> None:
    '''
    LLM constructor.
    Note: if enforce_eager is unset (enforce_eager is None)
    it defaults to False.
    '''
    if "disable_log_stats" not in kwargs:
        kwargs["disable_log_stats"] = True

    engine_args = AsyncEngineArgs(
        model=model,
        tokenizer=tokenizer,
        tokenizer_mode=tokenizer_mode,
        skip_tokenizer_init=skip_tokenizer_init,
        trust_remote_code=trust_remote_code,
        tensor_parallel_size=tensor_parallel_size,
        dtype=dtype,
        quantization=quantization,
        revision=revision,
        tokenizer_revision=tokenizer_revision,
        seed=seed,
        gpu_memory_utilization=gpu_memory_utilization,
        swap_space=swap_space,
        cpu_offload_gb=cpu_offload_gb,
        enforce_eager=enforce_eager,
        max_context_len_to_capture=max_context_len_to_capture,
        max_seq_len_to_capture=max_seq_len_to_capture,
        disable_custom_all_reduce=disable_custom_all_reduce,
        disable_async_output_proc=disable_async_output_proc,
        mm_processor_kwargs=mm_processor_kwargs,
        **kwargs,
    )

    self.llm_engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.LLM_CLASS).engine
    self.request_counter = Counter()

llm.LLM.__init__ = init
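
With LLM.__init__ patched this way, constructing vllm.LLM builds an AsyncLLMEngine through AsyncEngineArgs but exposes its inner engine as self.llm_engine, so the synchronous LLM.generate API keeps working. The sketch below is not part of this commit; the model name and sampling values are illustrative, and vLLM 0.6.3 is assumed.

# illustrative usage sketch, not part of this commit
from chatlearn.models.vllm import hooks  # side-effect import applies the LLM.__init__ patch
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)  # now backed by AsyncLLMEngine's inner engine
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)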
137 changes: 47 additions & 90 deletions chatlearn/models/vllm_module_v2.py
@@ -14,22 +14,15 @@
# ==============================================================================
"""VLLM module v2"""

import asyncio
import inspect
import os
import sys
from typing import Optional

import torch
from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.config import EngineConfig
from vllm.config import LoadFormat
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import RayWorkerWrapper
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

from chatlearn.utils.global_vars import set_vllm_actors
from chatlearn.utils.vllm_import_helper import TextTokensPrompt
@@ -46,7 +39,7 @@ def __init__(self, *args, **kwargs):
if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
os.environ['VLLM_HOST_IP'] = self.get_address()
self.engine = None
self.llm_engine = None
self.tokenizer = None

def setup(self):
@@ -56,87 +49,47 @@ def setup(self):
tokenizer.tokenizer = tokenizer
self.tokenizer = tokenizer

def _init_args(self, args):
# scheduler config
args.max_num_seqs = self.module_args.generation_batch_size
args.max_num_batched_tokens = self.model_args.get("max_num_batched_tokens")
args.num_scheduler_steps = self.model_args.get("num_scheduler_steps", 1)

# model config
args.max_seq_len = self.model_args.get("seq_length")

# logger config
args.disable_log_requests = True

# load format: 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
args.load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY)
if args.load_format == LoadFormat.DUMMY:
args.model_loader_extra_config = self.model_args
self.model_args["need_load_ckpt"] = self.src_parameter_model is None

# engine config
args.enforce_eager = self.model_args.get("enforce_eager", False)

def setup_vllm(self, workers):
# setup vllm engine in rank 0
os.environ['VLLM_HOST_IP'] = self.get_address()
set_vllm_actors(workers)
parser = FlexibleArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
backup_sys_argv = sys.argv

dtype = self.model_args.get("dtype", "bfloat16")
if self.model_args.get("fp16", False):
dtype = "float16"
vllm_sys_argv = ["",
f"--model={self.model_args['tokenizer']}",
f"--tensor_parallel_size={self.module_args.tensor_model_parallel_size}",
f"--pipeline_parallel_size={self.module_args.pipeline_model_parallel_size}",
f"--dtype={dtype}",
"--worker_use_ray",
"--disable_custom_all_reduce"]
sys.argv = vllm_sys_argv
args = parser.parse_args()
self._init_args(args)
engine_args = AsyncEngineArgs.from_cli_args(args)
self.engine = self.from_engine_args(engine_args)

sys.argv = backup_sys_argv
self.tokenizer = self.engine.engine.tokenizer

def from_engine_args(
self,
engine_args: AsyncEngineArgs,
engine_config: Optional[EngineConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.

if engine_config is None:
engine_config = engine_args.create_engine_config()

executor_class = AsyncLLMEngine._get_executor_cls(engine_config)

# Create the async LLM engine.
engine = AsyncLLMEngine(
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
return engine

async def generate_one_sample(self, prompt, sampling_param, request_id):
results_generator = self.engine.generate(prompt, sampling_param, request_id)
final_output = None
async for request_output in results_generator:
final_output = request_output
return final_output
        load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY)
        if load_format == LoadFormat.DUMMY:
            self.model_args["need_load_ckpt"] = self.src_parameter_model is None
            model_loader_extra_config = self.model_args
        else:
            model_loader_extra_config = None

        self.llm = LLM(
            model=self.model_args['tokenizer'],
            tokenizer=self.model_args['tokenizer'],
            max_seq_len_to_capture=self.model_args.get("seq_length"),
            # load model: 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
            load_format=load_format,
            model_loader_extra_config=model_loader_extra_config,
            # parallelism strategy
            tensor_parallel_size=self.module_args.tensor_model_parallel_size,
            pipeline_parallel_size=self.module_args.pipeline_model_parallel_size,
            dtype=dtype,
            # scheduling strategy
            max_num_seqs=self.module_args.generation_batch_size,
            max_num_batched_tokens=self.model_args.get("max_num_batched_tokens", None),
            num_scheduler_steps=self.model_args.get("num_scheduler_steps", 1),
            gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90),
            # logger
            disable_log_requests=self.model_args.get("disable_log_requests", True),
            disable_log_stats=self.model_args.get("disable_log_stats", True),
            trust_remote_code=True,
            # TODO(jiangle.jl): support non-eager mode.
            enforce_eager=True,
            disable_custom_all_reduce=True,
            distributed_executor_backend="ray")
        self.tokenizer = self.llm.llm_engine.tokenizer

def _get_sampling_params(self, is_eval):
temperature = 0.0
@@ -173,7 +126,7 @@ def _get_sampling_params(self, is_eval):
sampling_params.use_beam_search = self.model_args.get("use_beam_search")
return sampling_params

def convert_v1_inputs(self, prompts, prompt_token_ids):
def _convert_v1_inputs(self, prompts, prompt_token_ids):
num_requests = len(prompts)
assert num_requests == len(prompt_token_ids), \
("The lengths of prompts and prompt_token_ids must be the same.")
@@ -201,9 +154,9 @@ async def generate_vllm(self, query, is_eval):
prompts_token_ids = query[input_ids_key]
seq_len = self.model_args.get("seq_length")
final_outputs = []
tasks = []
parsed_prompts = []
sampling_params = []
for i, prompt in enumerate(prompts):
request_id = i
prompt_token_ids = prompts_token_ids[i]
if 'sampling_param' in query:
sampling_param = query['sampling_param'][i]
@@ -215,14 +168,18 @@
max_tokens = self.model_args.get("max_new_tokens")
assert max_tokens < seq_len, "max_new_tokens must less than seq length."
sampling_param.max_tokens = max_tokens
inputs = self.convert_v1_inputs(
item = self._convert_v1_inputs(
prompts=[prompt],
prompt_token_ids=[prompt_token_ids],
)[0]
parsed_prompts.append(item)
sampling_params.append(sampling_param)

task = asyncio.create_task(self.generate_one_sample(inputs, sampling_param, request_id))
tasks.append(task)
outputs = await asyncio.gather(*tasks)
outputs = self.llm.generate(
parsed_prompts,
sampling_params,
use_tqdm=True,
)
final_outputs = sorted(outputs, key=lambda x: int(x.request_id))
return final_outputs
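
The loop above now collects one tokens-style prompt and one SamplingParams per request and hands the whole batch to llm.generate in a single call, instead of spawning one asyncio task per prompt against the async engine. Below is a standalone sketch of that batched call, not part of this commit: it assumes an already constructed LLM instance named llm (as built in setup_vllm), uses illustrative token ids, and uses vLLM's public TokensPrompt as a stand-in for the TextTokensPrompt inputs that _convert_v1_inputs produces.

# illustrative sketch of the batched generate path, not part of this commit
from vllm import SamplingParams
from vllm.inputs import TokensPrompt

parsed_prompts = [TokensPrompt(prompt_token_ids=[1, 2, 3]),
                  TokensPrompt(prompt_token_ids=[4, 5, 6])]
sampling_params = [SamplingParams(temperature=0.0, max_tokens=32) for _ in parsed_prompts]
outputs = llm.generate(parsed_prompts, sampling_params, use_tqdm=True)
# request ids are assigned in submission order, so sorting restores input order
outputs = sorted(outputs, key=lambda x: int(x.request_id))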

@@ -253,7 +210,7 @@ def __init__(self, *args, **kwargs):
if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
os.environ['VLLM_HOST_IP'] = self.get_address()
self.engine = None
self.llm_engine = None

def peak_memory(self):
"""
4 changes: 2 additions & 2 deletions chatlearn/runtime/decorator.py
@@ -164,7 +164,7 @@ def get_kwarg(key):
# for model with DP/EP, we need to return results from all ranks
# for model with TP/PP, only return the results from last rank
if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \
or isinstance(self, VLLMModuleV2):
or isinstance(self, VLLMModuleV2):
final_results = concat_along_batch(results)
else:
if 'iteration' in inspect.signature(func).parameters:
@@ -176,7 +176,7 @@ def get_kwarg(key):
# for model with DP/EP, we need to return results from all ranks
# for model with TP/PP, only return the results from last rank
if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \
or isinstance(self, VLLMModuleV2):
or isinstance(self, VLLMModuleV2):
final_results = ret
else:
if 'iteration' in inspect.signature(func).parameters: