[data/llm/docs] Initial draft of user guide for Data LLM APIs (ray-project#50674)

## Why are these changes needed?

Adds a user guide for the Ray Data LLM APIs and link-ins from the existing Ray Data documentation.

This is part of the ray-project#50639 thread of work.

This is based on ray-project#50494 

cc @comaniac @gvspraveen @kouroshHakha  

## Related issue number

<!-- For example: "Closes ray-project#1234" -->

## Checks

- [ ] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a
  method in Tune, I've added it in `doc/source/tune/api/` under the
  corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [ ] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: Richard Liaw <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
richardliaw and comaniac authored Feb 19, 2025
1 parent ee2ed35 commit 6aa079a
Showing 6 changed files with 274 additions and 1 deletion.
1 change: 1 addition & 0 deletions .vale/styles/config/vocabularies/Data/accept.txt
@@ -12,6 +12,7 @@ dtype
[Ii]ngest
[Ii]nqueue(s)?
[Ll]ookup(s)?
LLM(s)?
Modin
[Mm]ultiget(s)?
ndarray(s)?
52 changes: 52 additions & 0 deletions doc/source/data/batch_inference.rst
@@ -12,6 +12,9 @@ Offline batch inference is a process for generating model predictions on a fixed
:width: 650px
:align: center

.. note::
    This guide is primarily focused on batch inference with deep learning frameworks.
    For more information on batch inference with LLMs, see :ref:`Working with LLMs <working-with-llms>`.

.. _batch_inference_quickstart:

@@ -177,6 +180,55 @@ For how to configure batch inference, see :ref:`the configuration guide<batch_in

{'output': array([0.625576], dtype=float32)}

.. tab-item:: LLM Inference
    :sync: vLLM

    Ray Data offers native integration with vLLM, a high-performance inference engine for large language models (LLMs).

    .. testcode::
        :skipif: True

        import ray
        from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor

        config = vLLMEngineProcessorConfig(
            model="unsloth/Llama-3.1-8B-Instruct",
            engine_kwargs={
                "enable_chunked_prefill": True,
                "max_num_batched_tokens": 4096,
                "max_model_len": 16384,
            },
            concurrency=1,
            batch_size=64,
        )
        processor = build_llm_processor(
            config,
            preprocess=lambda row: dict(
                messages=[
                    {"role": "system", "content": "You are a bot that responds with haikus."},
                    {"role": "user", "content": row["item"]},
                ],
                sampling_params=dict(
                    temperature=0.3,
                    max_tokens=250,
                ),
            ),
            postprocess=lambda row: dict(
                answer=row["generated_text"],
            ),
        )

        ds = ray.data.from_items(["Start of the haiku is: Complete this for me..."])

        ds = processor(ds)
        ds.show(limit=1)

    .. testoutput::
        :options: +MOCK

        {'answer': 'Snowflakes gently fall\nBlanketing the winter scene\nFrozen peaceful hush'}

.. _batch_inference_configuration:

Configuration and troubleshooting
1 change: 1 addition & 0 deletions doc/source/data/user-guide.rst
@@ -22,6 +22,7 @@ show you how to achieve several tasks.
working-with-text
working-with-tensors
working-with-pytorch
working-with-llms
monitoring-your-workload
execution-configurations
batch_inference
150 changes: 150 additions & 0 deletions doc/source/data/working-with-llms.rst
@@ -0,0 +1,150 @@
.. _working-with-llms:

Working with LLMs
=================

The `ray.data.llm` module integrates with key large language model (LLM) inference engines and deployed models to enable LLM batch inference.

This guide shows you how to use `ray.data.llm` to:

* :ref:`Perform batch inference with LLMs <batch_inference_llm>`
* :ref:`Configure vLLM for LLM inference <vllm_llm>`
* :ref:`Query deployed models with an OpenAI-compatible API endpoint <openai_compatible_api_endpoint>`

.. _batch_inference_llm:

Perform batch inference with LLMs
---------------------------------

At a high level, the `ray.data.llm` module provides a `Processor` object which encapsulates
logic for performing batch inference with LLMs on a Ray Data dataset.

You can use the `build_llm_processor` API to construct a processor. The following example uses the `vLLMEngineProcessorConfig` to construct a processor for the `unsloth/Llama-3.1-8B-Instruct` model.

The `vLLMEngineProcessorConfig` is a configuration object for the vLLM engine.
It contains the model name, the keyword arguments to pass to the vLLM engine (such as parallelism and memory settings), and Ray Data options such as concurrency (the number of engine replicas) and batch size. Upon execution, the `Processor` object instantiates replicas of the vLLM engine (using `map_batches` under the hood).

.. testcode::

    import ray
    from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor

    config = vLLMEngineProcessorConfig(
        model="unsloth/Llama-3.1-8B-Instruct",
        engine_kwargs={
            "enable_chunked_prefill": True,
            "max_num_batched_tokens": 4096,
            "max_model_len": 16384,
        },
        concurrency=1,
        batch_size=64,
    )
    processor = build_llm_processor(
        config,
        preprocess=lambda row: dict(
            messages=[
                {"role": "system", "content": "You are a bot that responds with haikus."},
                {"role": "user", "content": row["item"]},
            ],
            sampling_params=dict(
                temperature=0.3,
                max_tokens=250,
            ),
        ),
        postprocess=lambda row: dict(
            answer=row["generated_text"],
        ),
    )

    ds = ray.data.from_items(["Start of the haiku is: Complete this for me..."])

    ds = processor(ds)
    ds.show(limit=1)

.. testoutput::
    :options: +MOCK

    {'answer': 'Snowflakes gently fall\nBlanketing the winter scene\nFrozen peaceful hush'}
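
The same processor can also run over prompts loaded from storage rather than an in-memory list. The following sketch assumes a hypothetical text file with one prompt per line; `ray.data.read_text` produces rows with a `text` column, so the preprocess function reads `row["text"]` instead of `row["item"]`.

.. testcode::
    :skipif: True

    # Hypothetical input file: one prompt per line.
    ds = ray.data.read_text("s3://my-bucket/prompts.txt")

    processor = build_llm_processor(
        config,
        preprocess=lambda row: dict(
            messages=[
                {"role": "system", "content": "You are a bot that responds with haikus."},
                # read_text yields a "text" column rather than "item".
                {"role": "user", "content": row["text"]},
            ],
            sampling_params=dict(
                temperature=0.3,
                max_tokens=250,
            ),
        ),
        postprocess=lambda row: dict(
            answer=row["generated_text"],
        ),
    )

    ds = processor(ds)
    ds.show(limit=1)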

.. _vllm_llm:

Configure vLLM for LLM inference
--------------------------------

Use the `vLLMEngineProcessorConfig` to configure the vLLM engine.

.. testcode::

    from ray.data.llm import vLLMEngineProcessorConfig

    processor_config = vLLMEngineProcessorConfig(
        model="unsloth/Llama-3.1-8B-Instruct",
        engine_kwargs={"max_model_len": 20000},
        concurrency=1,
        batch_size=64,
    )

For handling larger models, specify model parallelism.

.. testcode::

    processor_config = vLLMEngineProcessorConfig(
        model="unsloth/Llama-3.1-8B-Instruct",
        engine_kwargs={
            "max_model_len": 16384,
            "tensor_parallel_size": 2,
            "pipeline_parallel_size": 2,
            "enable_chunked_prefill": True,
            "max_num_batched_tokens": 2048,
        },
        concurrency=1,
        batch_size=64,
    )

The underlying `Processor` object instantiates replicas of the vLLM engine and automatically
configures parallel workers to handle model parallelism (tensor parallelism and pipeline
parallelism, if specified).
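
For example, with the configuration above, each engine replica spans `tensor_parallel_size * pipeline_parallel_size` GPU workers, and `concurrency` controls how many replicas Ray Data launches. The following is a rough sketch of that accounting (an illustration only, assuming one GPU per vLLM worker; the processor handles the actual scheduling):

.. testcode::

    # Rough GPU accounting for the model-parallel configuration above.
    tensor_parallel_size = 2
    pipeline_parallel_size = 2
    concurrency = 1  # number of vLLM engine replicas

    gpus_per_replica = tensor_parallel_size * pipeline_parallel_size
    total_gpus = concurrency * gpus_per_replica
    print(f"{gpus_per_replica} GPUs per replica, {total_gpus} GPUs in total")

.. testoutput::

    4 GPUs per replica, 4 GPUs in total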


.. _openai_compatible_api_endpoint:

OpenAI-Compatible API Endpoint
------------------------------

You can also make calls to deployed models that expose an OpenAI-compatible API endpoint.

.. testcode::

    import ray
    import os
    from ray.data.llm import HttpRequestProcessorConfig, build_llm_processor

    OPENAI_KEY = os.environ["OPENAI_API_KEY"]
    ds = ray.data.from_items(["Hand me a haiku."])

    config = HttpRequestProcessorConfig(
        url="https://api.openai.com/v1/chat/completions",
        headers={"Authorization": f"Bearer {OPENAI_KEY}"},
        qps=1,
    )

    processor = build_llm_processor(
        config,
        preprocess=lambda row: dict(
            payload=dict(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a bot that responds with haikus."},
                    {"role": "user", "content": row["item"]},
                ],
                temperature=0.0,
                max_tokens=150,
            ),
        ),
        postprocess=lambda row: dict(
            response=row["http_response"]["choices"][0]["message"]["content"],
        ),
    )

    ds = processor(ds)
    print(ds.take_all())
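
Because the processor only needs an OpenAI-compatible endpoint, the same pattern works against self-hosted servers. The following sketch assumes a hypothetical local server (for example, a vLLM server running in OpenAI-compatible mode) at `http://localhost:8000/v1/chat/completions`; only the URL, headers, and model name change.

.. testcode::
    :skipif: True

    # Hypothetical self-hosted OpenAI-compatible endpoint.
    config = HttpRequestProcessorConfig(
        url="http://localhost:8000/v1/chat/completions",
        headers={"Authorization": "Bearer dummy-key"},  # many local servers accept any key
        qps=2,
    )

    processor = build_llm_processor(
        config,
        preprocess=lambda row: dict(
            payload=dict(
                # The model name that the local server is serving.
                model="unsloth/Llama-3.1-8B-Instruct",
                messages=[{"role": "user", "content": row["item"]}],
                max_tokens=150,
            ),
        ),
        postprocess=lambda row: dict(
            response=row["http_response"]["choices"][0]["message"]["content"],
        ),
    )

    ds = processor(ray.data.from_items(["Hand me a haiku."]))
    print(ds.take_all())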
2 changes: 2 additions & 0 deletions doc/source/data/working-with-text.rst
@@ -167,6 +167,8 @@ that sets up and invokes a model. Then, call
{'text': 'Beautiful is better than ugly.', 'label': 'POSITIVE'}
{'text': 'Explicit is better than implicit.', 'label': 'POSITIVE'}

For more information on handling large language models, see :ref:`Working with LLMs <working-with-llms>`.

For more information on performing inference, see
:ref:`End-to-end: Offline Batch Inference <batch_inference_home>`
and :ref:`Stateful Transforms <stateful_transforms>`.
69 changes: 68 additions & 1 deletion python/ray/data/llm.py
@@ -20,6 +20,12 @@ class ProcessorConfig(_ProcessorConfig):
class HttpRequestProcessorConfig(_HttpRequestProcessorConfig):
"""The configuration for the HTTP request processor.
Args:
batch_size: The batch size to send to the HTTP request.
url: The URL to send the HTTP request to.
headers: The headers to send with the HTTP request.
concurrency: The number of concurrent requests to send.
Examples:
.. testcode::
:skipif: True
@@ -61,7 +67,28 @@ class HttpRequestProcessorConfig(_HttpRequestProcessorConfig):
class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
"""The configuration for the vLLM engine processor.
Args:
model: The model to use for the vLLM engine.
engine_kwargs: The kwargs to pass to the vLLM engine.
task_type: The task type to use. If not specified, will use 'generate' by default.
runtime_env: The runtime environment to use for the vLLM engine.
max_pending_requests: The maximum number of pending requests. If not specified,
will use the default value from the vLLM engine.
max_concurrent_batches: The maximum number of concurrent batches in the engine.
This is to overlap the batch processing to avoid the tail latency of
each batch. The default value may not be optimal when the batch size
or the batch processing latency is too small, but it should be good
enough for batch size >= 64.
apply_chat_template: Whether to apply chat template.
chat_template: The chat template to use. This is usually not needed if the
model checkpoint already contains the chat template.
tokenize: Whether to tokenize the input before passing it to the vLLM engine.
If not, vLLM will tokenize the prompt in the engine.
detokenize: Whether to detokenize the output.
has_image: Whether the input messages have images.
Examples:
.. testcode::
:skipif: True
@@ -75,7 +102,6 @@ class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
enable_chunked_prefill=True,
max_num_batched_tokens=4096,
),
accelerator_type="L4",
concurrency=1,
batch_size=64,
)
@@ -124,6 +150,47 @@ def build_llm_processor(
Returns:
The built processor.
Example:
.. testcode::
:skipif: True
import ray
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor
config = vLLMEngineProcessorConfig(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
engine_kwargs=dict(
enable_prefix_caching=True,
enable_chunked_prefill=True,
max_num_batched_tokens=4096,
),
concurrency=1,
batch_size=64,
)
processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[
{"role": "system", "content": "You are a calculator"},
{"role": "user", "content": f"{row['id']} ** 3 = ?"},
],
sampling_params=dict(
temperature=0.3,
max_tokens=20,
detokenize=False,
),
),
postprocess=lambda row: dict(
resp=row["generated_text"],
),
)
ds = ray.data.range(300)
ds = processor(ds)
for row in ds.take_all():
print(row)
"""
from ray.llm._internal.batch.processor import ProcessorBuilder

