Skip to content

Commit

Permalink
[Doc]: Deepseek reference docs (#2787)
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaotongJiang authored Jan 9, 2025
1 parent 4f077c0 commit 11fffbc
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,4 @@ The core features include:
references/troubleshooting.md
references/faq.md
references/learn_more.md
references/deepseek.md
34 changes: 34 additions & 0 deletions docs/references/deepseek.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# DeepSeek Model Optimizations in SGLang

SGLang provides several optimizations specifically designed for the DeepSeek model to boost its inference speed. This document outlines current optimizations for DeepSeek. Additionally, the SGLang team is actively developing enhancements for [DeepSeek-V3](https://github.com/sgl-project/sglang/issues/2591).


## Multi-head Latent Attention (MLA) Throughput Optimizations

**Description**: [MLA](https://arxiv.org/pdf/2405.04434) is an innovative attention mechanism introduced by the DeepSeek team, aimed at improving inference efficiency. SGLang has implemented specific optimizations for this, including:

- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
- **Triton Decoding Kernel Optimization**: In the MLA decoding kernel, there is only one KV head. This optimization reduces memory access to the KV cache by processing multiple query heads within one block, accelerating the decoding process.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
- **CUDA Graph & Torch.compile**: Both MLA and Mixture of Experts (MoE) are compatible with CUDA Graph and Torch.compile, which reduces latency and accelerates decoding speed for small batch sizes.

Overall, with these optimizations, we have achieved up to a 7x acceleration in output throughput compared to the previous version.
![Data Parallelism Attention for DeepSeek Series Models](https://lmsys.org/images/blog/sglang_v0_3/deepseek_mla.svg)

**Usage**: MLA optimization is enabled by defalut, to disable, use `--disable-mla`.

**Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details.

## Data Parallelism Attention

**Description**: This optimization involves data parallelism (DP) for the MLA attention mechanism of DeepSeek Series Models, which allows for a significant reduction in the KV cache size, enabling larger batch sizes. Each DP worker independently handles different types of batches (prefill, decode, idle), which are then synchronized before and after processing through the Mixture-of-Experts (MoE) layer.
![Data Parallelism Attention for DeepSeek Series Models](https://lmsys.org/images/blog/sglang_v0_4/dp_attention.svg).

**Usage**: This optimization is aimed at improving throughput and should be used for scenarios with high QPS (Queries Per Second). Data Parallelism Attention optimization can be enabeld by `--enable-dp-attention` for DeepSeek Series Models.

**Reference**: Check [Blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models).

## Multi Node Tensor Parallelism
**Description**: For users with limited memory on a single node, SGLang supports serving DeepSeek Series Models, including DeepSeek V3, across multiple nodes using tensor parallelism. This approach partitions the model parameters across multiple GPUs or nodes to handle models that are too large for one node's memory.

**Usage**: Check [here](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-2-h208) for usage examples.
4 changes: 2 additions & 2 deletions docs/references/modelscope.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ To use a model from [ModelScope](https://www.modelscope.cn), set the environment
export SGLANG_USE_MODELSCOPE=true
```

We take [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) as an example. Launch the Server:
---
We take [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) as an example.

Launch the Server:
```bash
python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000
```
Expand Down

0 comments on commit 11fffbc

Please sign in to comment.