Commit eec5c69

Authored by: jeffreyjeffreywang, richardliaw, gemini-code-assist[bot], and nrghosh

[docs][data][llm] Introduce docs for serve deployment processor and cross-node parallelism (ray-project#57261)

Signed-off-by: jeffreyjeffreywang <[email protected]>
Signed-off-by: Richard Liaw <[email protected]>
Co-authored-by: jeffreyjeffreywang <[email protected]>
Co-authored-by: Richard Liaw <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Nikhil G <[email protected]>
1 parent 3cd8202 commit eec5c69

2 files changed: +176 additions, -21 deletions

doc/source/data/doc_code/working-with-llms/basic_llm_example.py

Lines changed: 131 additions & 0 deletions
@@ -197,4 +197,135 @@ def create_embedding_processor():

# __embedding_config_example_end__

# __shared_vllm_engine_config_example_start__
import ray
from ray import serve
from ray.data.llm import ServeDeploymentProcessorConfig, build_llm_processor
from ray.serve.llm import (
    LLMConfig,
    ModelLoadingConfig,
    build_llm_deployment,
)
from ray.serve.llm.openai_api_models import CompletionRequest

llm_config = LLMConfig(
    model_loading_config=ModelLoadingConfig(
        model_id="facebook/opt-1.3b",
        model_source="facebook/opt-1.3b",
    ),
    deployment_config=dict(
        name="demo_deployment_config",
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        ),
    ),
    engine_kwargs=dict(
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
        max_num_batched_tokens=4096,
    ),
)

APP_NAME = "demo_app"
DEPLOYMENT_NAME = "demo_deployment"
override_serve_options = dict(name=DEPLOYMENT_NAME)

llm_app = build_llm_deployment(
    llm_config, override_serve_options=override_serve_options
)
app = serve.run(llm_app, name=APP_NAME)
config = ServeDeploymentProcessorConfig(
    deployment_name=DEPLOYMENT_NAME,
    app_name=APP_NAME,
    dtype_mapping={
        "CompletionRequest": CompletionRequest,
    },
    concurrency=1,
    batch_size=64,
)

processor1 = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        method="completions",
        dtype="CompletionRequest",
        request_kwargs=dict(
            model="facebook/opt-1.3b",
            prompt=f"This is a prompt for {row['id']}",
            stream=False,
        ),
    ),
    postprocess=lambda row: dict(
        prompt=row["choices"][0]["text"],
    ),
)

processor2 = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        method="completions",
        dtype="CompletionRequest",
        request_kwargs=dict(
            model="facebook/opt-1.3b",
            prompt=row["prompt"],
            stream=False,
        ),
    ),
    postprocess=lambda row: row,
)

ds = ray.data.range(10)
ds = processor2(processor1(ds))
print(ds.take_all())
# __shared_vllm_engine_config_example_end__
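
# Illustrative note, not part of the committed example: processor1 and processor2 are built
# from the same ServeDeploymentProcessorConfig, so both stages send requests to the single
# Serve-managed vLLM engine deployed above instead of starting one engine per processor.
# If the app isn't needed after batch processing, it can be torn down with serve.shutdown().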

# __cross_node_parallelism_config_example_start__
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
        "pipeline_parallel_size": 4,
        "tensor_parallel_size": 4,
        "distributed_executor_backend": "ray",
    },
    batch_size=32,
    concurrency=1,
)
# __cross_node_parallelism_config_example_end__
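
# Illustrative note, not part of the committed example: with pipeline_parallel_size=4 and
# tensor_parallel_size=4, each replica requests 4 * 4 = 16 GPUs, and setting
# distributed_executor_backend="ray" lets those engine workers be scheduled across nodes.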

# __custom_placement_group_strategy_config_example_start__
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
        "pipeline_parallel_size": 2,
        "tensor_parallel_size": 2,
        "distributed_executor_backend": "ray",
    },
    batch_size=32,
    concurrency=1,
    placement_group_config={
        "bundles": [{"GPU": 1}] * 4,
        "strategy": "STRICT_PACK",
    },
)
# __custom_placement_group_strategy_config_example_end__
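
# Illustrative note, not part of the committed example: the four {"GPU": 1} bundles cover
# the 2 * 2 = 4 engine workers, and STRICT_PACK keeps all of them on a single node; a
# looser strategy such as PACK or SPREAD would allow cross-node placement.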

# __concurrent_config_example_start__
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
    },
    concurrency=10,
    batch_size=64,
)
# __concurrent_config_example_end__
# __basic_llm_example_end__
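
The sketch below isn't part of the commit; it only illustrates the GPU arithmetic the new docs describe (one GPU per engine worker, tensor_parallel_size * pipeline_parallel_size workers per replica, and concurrency replicas in parallel). The helper name is hypothetical.

# Hypothetical helper, for illustration only.
def estimate_total_gpus(
    tensor_parallel_size: int, pipeline_parallel_size: int, concurrency: int
) -> int:
    # Each replica spans TP * PP single-GPU workers; `concurrency` replicas run in parallel.
    return tensor_parallel_size * pipeline_parallel_size * concurrency


# Cross-node example: TP=4, PP=4, concurrency=1 -> 16 GPUs for the single replica.
assert estimate_total_gpus(4, 4, 1) == 16
# Concurrency example: concurrency=10 with vLLM's default TP=1 and PP=1 -> 10 GPUs in total.
assert estimate_total_gpus(1, 1, 10) == 10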

doc/source/data/working-with-llms.rst

Lines changed: 45 additions & 21 deletions
@@ -205,6 +205,51 @@ You can also make calls to deployed models that have an OpenAI compatible API en
   :start-after: __openai_example_start__
   :end-before: __openai_example_end__

Batch inference with Serve deployments
--------------------------------------

You can configure any :ref:`serve deployment <converting-to-ray-serve-application>` for batch inference. This is particularly useful for multi-turn conversations
because you can share a single vLLM engine across conversations. To achieve this, create an :ref:`LLM serve deployment <serving-llms>` and use
the :class:`ServeDeploymentProcessorConfig <ray.data.llm.ServeDeploymentProcessorConfig>` class to configure the processor.

.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py
   :language: python
   :start-after: __shared_vllm_engine_config_example_start__
   :end-before: __shared_vllm_engine_config_example_end__

Cross-node parallelism
----------------------

Ray Data LLM supports cross-node parallelism, including tensor parallelism and pipeline parallelism.
You can configure the parallelism level through the ``engine_kwargs`` argument in
:class:`vLLMEngineProcessorConfig <ray.data.llm.vLLMEngineProcessorConfig>`. Use ``ray`` as the
distributed executor backend to enable cross-node parallelism.

.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py
   :language: python
   :start-after: __cross_node_parallelism_config_example_start__
   :end-before: __cross_node_parallelism_config_example_end__

In addition, you can customize the placement group strategy to control how Ray places vLLM engine workers across nodes.
While you can specify the degree of tensor and pipeline parallelism, the vLLM engine manages the specific assignment of model ranks to GPUs, so you can't configure it directly through the Ray Data LLM API.

.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py
   :language: python
   :start-after: __custom_placement_group_strategy_config_example_start__
   :end-before: __custom_placement_group_strategy_config_example_end__

Besides cross-node parallelism, you can also horizontally scale the LLM stage across multiple nodes.
Configure the number of replicas with the ``concurrency`` argument in
:class:`vLLMEngineProcessorConfig <ray.data.llm.vLLMEngineProcessorConfig>`.

.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py
   :language: python
   :start-after: __concurrent_config_example_start__
   :end-before: __concurrent_config_example_end__

Usage Data Collection
--------------------------

@@ -227,27 +272,6 @@ to turn it off.
Frequently Asked Questions (FAQs)
--------------------------------------------------

-.. TODO(#55491): Rewrite this section once the restriction is lifted.
-.. TODO(#55405): Cross-node TP in progress.
-.. _cross_node_parallelism:
-
-How to configure LLM stage to parallelize across multiple nodes?
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-At the moment, Ray Data LLM doesn't support cross-node parallelism (either
-tensor parallelism or pipeline parallelism).
-
-The processing pipeline is designed to run on a single node. The number of
-GPUs is calculated as the product of the tensor parallel size and the pipeline
-parallel size, and apply
-[`STRICT_PACK` strategy](https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html#pgroup-strategy)
-to ensure that each replica of the LLM stage is executed on a single node.
-
-Nevertheless, you can still horizontally scale the LLM stage to multiple nodes
-as long as each replica (TP * PP) fits into a single node. The number of
-replicas is configured by the `concurrency` argument in
-:class:`vLLMEngineProcessorConfig <ray.data.llm.vLLMEngineProcessorConfig>`.
-
.. _gpu_memory_management:

GPU Memory Management and CUDA OOM Prevention
