[LLM Batch][5/N] vLLM Engine Processor (ray-project#50494)
comaniac authored Feb 18, 2025
1 parent e37e942 commit 5ab0e70
Showing 24 changed files with 785 additions and 124 deletions.
1 change: 1 addition & 0 deletions doc/source/data/api/llm.rst
@@ -32,3 +32,4 @@ Processor Configs

~ProcessorConfig
~HttpRequestProcessorConfig
~vLLMEngineProcessorConfig
72 changes: 68 additions & 4 deletions python/ray/data/llm.py
@@ -1,7 +1,10 @@
from typing import Optional
from ray.data.block import UserDefinedFunction
from ray.llm._internal.batch.processor import (
ProcessorConfig as _ProcessorConfig,
Processor,
HttpRequestProcessorConfig as _HttpRequestProcessorConfig,
vLLMEngineProcessorConfig as _vLLMEngineProcessorConfig,
)
from ray.util.annotations import PublicAPI

@@ -55,25 +58,86 @@ class HttpRequestProcessorConfig(_HttpRequestProcessorConfig):


@PublicAPI(stability="alpha")
def build_llm_processor(config: ProcessorConfig, **kwargs) -> Processor:
class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
"""The configuration for the vLLM engine processor.
Examples:
.. testcode::
:skipif: True
import ray
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor
config = vLLMEngineProcessorConfig(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
engine_kwargs=dict(
enable_prefix_caching=True,
enable_chunked_prefill=True,
max_num_batched_tokens=4096,
),
accelerator_type="L4",
concurrency=1,
batch_size=64,
)
processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[
{"role": "system", "content": "You are a calculator"},
{"role": "user", "content": f"{row['id']} ** 3 = ?"},
],
sampling_params=dict(
temperature=0.3,
max_tokens=20,
detokenize=False,
),
),
postprocess=lambda row: dict(
resp=row["generated_text"],
),
)
ds = ray.data.range(300)
ds = processor(ds)
for row in ds.take_all():
print(row)
"""

pass


@PublicAPI(stability="alpha")
def build_llm_processor(
config: ProcessorConfig,
preprocess: Optional[UserDefinedFunction] = None,
postprocess: Optional[UserDefinedFunction] = None,
) -> Processor:
"""Build a LLM processor using the given config.
Args:
config: The processor config.
**kwargs: Additional keyword arguments to pass to the processor.
See `Processor` for argument details.
preprocess: An optional lambda function that takes a row (dict) as input
and returns a preprocessed row (dict). The output row must contain the
required fields for the following processing stages.
postprocess: An optional lambda function that takes a row (dict) as input
and returns a postprocessed row (dict).

Returns:
The built processor.
"""
from ray.llm._internal.batch.processor import ProcessorBuilder

return ProcessorBuilder.build(config, **kwargs)
return ProcessorBuilder.build(
config,
preprocess=preprocess,
postprocess=postprocess,
)


__all__ = [
"ProcessorConfig",
"Processor",
"HttpRequestProcessorConfig",
"vLLMEngineProcessorConfig",
"build_llm_processor",
]
2 changes: 2 additions & 0 deletions python/ray/llm/_internal/batch/processor/__init__.py
@@ -1,9 +1,11 @@
from .base import ProcessorConfig, ProcessorBuilder, Processor
from .http_request_proc import HttpRequestProcessorConfig
from .vllm_engine_proc import vLLMEngineProcessorConfig

__all__ = [
"ProcessorConfig",
"ProcessorBuilder",
"HttpRequestProcessorConfig",
"vLLMEngineProcessorConfig",
"Processor",
]
8 changes: 4 additions & 4 deletions python/ray/llm/_internal/batch/processor/base.py
@@ -57,7 +57,7 @@ class Processor:
# The internally used data column name ("__data"). Your input
# dataset should not contain this column. If you want to use this column
# in your input dataset, you have to subclass and customize Processor.
data_column: str = "__data"
DATA_COLUMN: str = "__data"

def __init__(
self,
@@ -79,12 +79,12 @@ def __init__(

self.preprocess = wrap_preprocess(
preprocess,
self.data_column,
self.DATA_COLUMN,
)

self.postprocess = wrap_postprocess(
postprocess,
self.data_column,
self.DATA_COLUMN,
)

for stage in stages:
@@ -108,7 +108,7 @@ def __call__(self, dataset: Dataset) -> Dataset:
for stage in self.stages.values():
kwargs = stage.get_dataset_map_batches_kwargs(
batch_size=self.config.batch_size,
data_column=self.data_column,
data_column=self.DATA_COLUMN,
)
dataset = dataset.map_batches(stage.fn, **kwargs)

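A note on the rename above: data_column becomes the class-level constant DATA_COLUMN because the processor funnels all intermediate stage state through one reserved column ("__data"), letting user columns pass through untouched. Below is a minimal sketch of the wrapping contract, assuming wrap_preprocess nests the user function's output under the reserved column and wrap_postprocess unwraps it; this is an illustration, not the actual Ray helpers.

from typing import Any, Callable, Dict

DATA_COLUMN = "__data"  # Reserved column; input datasets must not contain it.

RowFn = Callable[[Dict[str, Any]], Dict[str, Any]]

def wrap_preprocess(fn: RowFn, data_column: str) -> RowFn:
    # Illustrative: nest the user's preprocessed fields under the reserved
    # column so every downstream stage finds its inputs in one place.
    def _wrapped(row: Dict[str, Any]) -> Dict[str, Any]:
        return {**row, data_column: fn(row)}
    return _wrapped

def wrap_postprocess(fn: RowFn, data_column: str) -> RowFn:
    # Illustrative: hand the accumulated stage outputs to the user's
    # postprocess function to produce the final output row.
    def _wrapped(row: Dict[str, Any]) -> Dict[str, Any]:
        return fn(row[data_column])
    return _wrapped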
19 changes: 16 additions & 3 deletions python/ray/llm/_internal/batch/processor/http_request_proc.py
@@ -4,6 +4,8 @@

from pydantic import Field

from ray.data.block import UserDefinedFunction

from ray.llm._internal.batch.processor.base import (
Processor,
ProcessorConfig,
@@ -36,13 +38,19 @@ class HttpRequestProcessorConfig(ProcessorConfig):


def build_http_request_processor(
config: HttpRequestProcessorConfig, **kwargs
config: HttpRequestProcessorConfig,
preprocess: Optional[UserDefinedFunction] = None,
postprocess: Optional[UserDefinedFunction] = None,
) -> Processor:
"""Construct a Processor and configure stages.
Args:
config: The configuration for the processor.
**kwargs: The keyword arguments for the processor.
preprocess: An optional lambda function that takes a row (dict) as input
and returns a preprocessed row (dict). The output row must contain the
required fields for the following processing stages.
postprocess: An optional lambda function that takes a row (dict) as input
and returns a postprocessed row (dict).

Returns:
The constructed processor.
@@ -59,7 +67,12 @@ def build_http_request_processor(
),
)
]
processor = Processor(config, stages, **kwargs)
processor = Processor(
config,
stages,
preprocess=preprocess,
postprocess=postprocess,
)
return processor


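For comparison with the vLLM example in the docstring above, here is a hedged usage sketch of the HTTP processor. The url, headers, and qps config fields and the payload/http_response column names are assumptions not shown in this diff; check HttpRequestProcessorConfig for the actual fields.

import ray
from ray.data.llm import HttpRequestProcessorConfig, build_llm_processor

# Field names below (url, headers, qps) are assumed for illustration.
config = HttpRequestProcessorConfig(
    url="https://api.openai.com/v1/chat/completions",
    headers={"Authorization": "Bearer <API_KEY>"},
    qps=1,
    batch_size=64,
)
processor = build_llm_processor(
    config,
    # preprocess must emit the fields the HTTP stage expects ("payload"
    # here is an assumed column name).
    preprocess=lambda row: dict(
        payload=dict(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": row["prompt"]}],
        ),
    ),
    # "http_response" is likewise an assumed output column name.
    postprocess=lambda row: dict(resp=row["http_response"]),
)

ds = processor(ray.data.from_items([{"prompt": "2 ** 10 = ?"}]))
print(ds.take_all())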
197 changes: 197 additions & 0 deletions python/ray/llm/_internal/batch/processor/vllm_engine_proc.py
@@ -0,0 +1,197 @@
"""The vLLM engine processor."""
from typing import Any, Dict, Optional

from pydantic import Field, root_validator

from ray.data.block import UserDefinedFunction

import ray
from ray.llm._internal.batch.processor.base import (
Processor,
ProcessorConfig,
ProcessorBuilder,
)
from ray.llm._internal.batch.stages import (
vLLMEngineStage,
ChatTemplateStage,
PrepareImageStage,
TokenizeStage,
DetokenizeStage,
)
from ray.llm._internal.batch.stages.vllm_engine_stage import vLLMTaskType


class vLLMEngineProcessorConfig(ProcessorConfig):
"""The configuration for the vLLM engine processor."""

# vLLM stage configurations.
model: str = Field(
description="The model to use for the vLLM engine.",
)
engine_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description="The kwargs to pass to the vLLM engine.",
)
task_type: vLLMTaskType = Field(
default=vLLMTaskType.GENERATE,
description="The task type to use. If not specified, will use "
"'generate' by default.",
)
runtime_env: Optional[Dict[str, Any]] = Field(
default=None,
description="The runtime environment to use for the vLLM engine.",
)
max_pending_requests: Optional[int] = Field(
default=None,
description="The maximum number of pending requests. If not specified, "
"will use the default value from the vLLM engine.",
)
max_concurrent_batches: int = Field(
default=4,
description="The maximum number of concurrent batches in the engine. "
"This overlaps batch processing to hide the tail latency of each "
"batch. The default value may not be optimal when the batch size or "
"the batch processing latency is small, but it should be good enough "
"for batch sizes >= 64.",
)

# Processor stage configurations.
apply_chat_template: bool = Field(
default=True, description="Whether to apply chat template."
)
chat_template: Optional[str] = Field(
default=None,
description="The chat template to use. This is usually not needed if the "
"model checkpoint already contains the chat template.",
)
tokenize: bool = Field(
default=True,
description="Whether to tokenize the input before passing it to the "
"vLLM engine. If not, vLLM will tokenize the prompt in the engine.",
)
detokenize: bool = Field(
default=True,
description="Whether to detokenize the output.",
)
has_image: bool = Field(
default=False,
description="Whether the input messages have images.",
)

@root_validator(pre=True)
def validate_task_type(cls, values):
task_type_str = values.get("task_type", "generate")
values["task_type"] = vLLMTaskType(task_type_str)
return values


def build_vllm_engine_processor(
config: vLLMEngineProcessorConfig,
preprocess: Optional[UserDefinedFunction] = None,
postprocess: Optional[UserDefinedFunction] = None,
) -> Processor:
"""Construct a Processor and configure stages.
Args:
config: The configuration for the processor.
preprocess: An optional lambda function that takes a row (dict) as input
and returns a preprocessed row (dict). The output row must contain the
required fields for the following processing stages.
postprocess: An optional lambda function that takes a row (dict) as input
and returns a postprocessed row (dict).

Returns:
The constructed processor.
"""
ray.init(runtime_env=config.runtime_env, ignore_reinit_error=True)

stages = []

if config.has_image:
stages.append(
PrepareImageStage(
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=(1, config.concurrency),
batch_size=config.batch_size,
),
)
)
if config.apply_chat_template:
stages.append(
ChatTemplateStage(
fn_constructor_kwargs=dict(
model=config.model,
chat_template=config.chat_template,
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=(1, config.concurrency),
batch_size=config.batch_size,
),
)
)

if config.tokenize:
stages.append(
TokenizeStage(
fn_constructor_kwargs=dict(
model=config.model,
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=(1, config.concurrency),
batch_size=config.batch_size,
),
)
)

# Core stage -- the vLLM engine.

stages.append(
vLLMEngineStage(
fn_constructor_kwargs=dict(
model=config.model,
engine_kwargs=config.engine_kwargs,
task_type=config.task_type,
max_pending_requests=config.max_pending_requests,
),
map_batches_kwargs=dict(
zero_copy_batch=True,
# The number of running replicas. Note that we use a single
# integer to let Ray Data prepare all replicas before kicking
# off the processing for now.
concurrency=config.concurrency,
# The number of running batches "per actor" at the Ray Core level.
# This is used to make sure we overlap batches to avoid the tail
# latency of each batch.
max_concurrency=config.max_concurrent_batches,
accelerator_type=config.accelerator_type,
runtime_env=config.runtime_env,
),
)
)

if config.detokenize:
stages.append(
DetokenizeStage(
fn_constructor_kwargs=dict(
model=config.model,
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=(1, config.concurrency),
batch_size=config.batch_size,
),
)
)

processor = Processor(
config,
stages,
preprocess=preprocess,
postprocess=postprocess,
)
return processor


ProcessorBuilder.register(vLLMEngineProcessorConfig, build_vllm_engine_processor)
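The final ProcessorBuilder.register call is what lets the public build_llm_processor dispatch on the config's concrete type. A minimal sketch of that registry pattern, assuming a dict keyed by config class; this is illustrative, not Ray's exact internals.

# Illustrative registry dispatch, not Ray's exact internals.
_BUILDERS = {}

def register(config_cls, builder):
    # Associate a ProcessorConfig subclass with its stage-building function.
    _BUILDERS[config_cls] = builder

def build(config, preprocess=None, postprocess=None):
    # type(config) selects the builder, so a vLLMEngineProcessorConfig
    # routes to the build_vllm_engine_processor registered above.
    builder = _BUILDERS[type(config)]
    return builder(config, preprocess=preprocess, postprocess=postprocess)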