Docs: add torch compile cache (#4151)
Co-authored-by: ybyang <ybyang7@iflytek.com>
@@ -1,6 +1,7 @@

# Hyperparameter Tuning

## Achieving Peak Throughput

Achieving a large batch size is the most important factor in attaining high throughput.

When the server is running at full load, look for the following in the log:
@@ -8,11 +9,15 @@ When the server is running at full load, look for the following in the log:

```Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 317```

### Tune Your Request Submission Speed

`#queue-req` indicates the number of requests in the queue. If you frequently see `#queue-req == 0`, it suggests you are bottlenecked by the request submission speed.

A healthy range for `#queue-req` is `50 - 500`.

On the other hand, do not make `#queue-req` too large because it will also increase the scheduling overhead on the server, especially when using the default longest-prefix-match schedule policy (`--schedule-policy lpm`).
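To act on this signal programmatically, the decode log line shown above can be parsed with a small script, e.g. to alert when the queue drains. A minimal sketch (the regex assumes the exact log format shown, which may change between versions):

```python
import re

# Parse a "Decode batch" log line (format as shown above) to monitor load.
LOG_RE = re.compile(
    r"#running-req: (\d+), #token: (\d+), token usage: ([\d.]+), "
    r"gen throughput \(token/s\): ([\d.]+), #queue-req: (\d+)"
)

def parse_decode_log(line: str) -> dict:
    m = LOG_RE.search(line)
    if m is None:
        raise ValueError("not a decode batch log line")
    running, tokens, usage, tput, queued = m.groups()
    return {
        "running_req": int(running),
        "num_token": int(tokens),
        "token_usage": float(usage),
        "throughput": float(tput),
        "queue_req": int(queued),
    }

stats = parse_decode_log(
    "Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, "
    "gen throughput (token/s): 4594.01, #queue-req: 317"
)
if stats["queue_req"] == 0:
    print("warning: queue empty; submission speed may be the bottleneck")
```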
### Tune `--schedule-conservativeness`

`token usage` indicates the KV cache memory utilization of the server. `token usage > 0.9` means good utilization.
If you frequently see `token usage < 0.9` and `#queue-req > 0`, it means the server is too conservative about taking in new requests. You can decrease `--schedule-conservativeness` to a value like 0.3.
The server can become too conservative when users send many requests with a large `max_new_tokens` but the requests stop very early due to EOS or stop strings.
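For example, a launch that accepts new requests more aggressively (model name illustrative; the default value is 1.0):

```shell
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --schedule-conservativeness 0.3
```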
@@ -22,18 +27,37 @@ On the other hand, if you see `token usage` very high and you frequently see war

If you see `decode out of memory happened` occasionally but not frequently, it is okay.

### Tune `--dp-size` and `--tp-size`

Data parallelism is better for throughput. When there is enough GPU memory, always favor data parallelism for throughput. Refer to the [sglang router](../backend/sglang_router.md) for a better data parallelism setup rather than using the `dp_size` parameter.

## Avoid out-of-memory by Tuning `--chunked-prefill-size`, `--mem-fraction-static`, `--max-running-requests`

If you see out of memory (OOM) errors, you can try to tune the following parameters.
- If OOM happens during prefill, try to decrease `--chunked-prefill-size` to `4096` or `2048`.
- If OOM happens during decoding, try to decrease `--max-running-requests`.
- You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.
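For example, a launch that applies all three knobs at once (the values here are illustrative starting points, not recommendations; tune for your GPU and workload):

```shell
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct \
    --chunked-prefill-size 4096 --max-running-requests 128 --mem-fraction-static 0.8
```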
### Try Advanced Options

- To enable `torch.compile` acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently.

## Enabling cache for `torch.compile`

By default, `torch.compile` automatically caches the FX graph and Triton kernels in `/tmp/torchinductor_root`, which might be cleared according to the [system policy](https://serverfault.com/questions/377348/when-does-tmp-get-cleared). You can export the environment variable `TORCHINDUCTOR_CACHE_DIR` to save the compilation cache in a directory of your choice and avoid unwanted deletion. You can also share the cache with other machines to reduce the compilation time.

SGLang uses the `max-autotune-no-cudagraphs` mode of `torch.compile`. The auto-tuning can be slow.
If you want to deploy a model on many different machines, you can ship the `torch.compile` cache to these machines and skip the compilation steps. This is based on the [PyTorch official documentation](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html).

*Examples*:

1. Generate the cache by setting `TORCHINDUCTOR_CACHE_DIR` and running the model once.

```bash
TORCHINDUCTOR_CACHE_DIR=/root/inductor_root_cache python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile
```

2. Copy the cache folder to other machines and launch the server with `TORCHINDUCTOR_CACHE_DIR`.
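Step 2 can be as simple as archiving the cache directory and unpacking it on the target machine. A sketch (paths are illustrative):

```shell
# Bundle the inductor cache on the machine that compiled the model.
CACHE=${TORCHINDUCTOR_CACHE_DIR:-/tmp/torchinductor_root}
mkdir -p "$CACHE"   # no-op if the cache already exists
tar czf /tmp/inductor_cache.tgz -C "$(dirname "$CACHE")" "$(basename "$CACHE")"
# Ship /tmp/inductor_cache.tgz to the target machine, unpack it, and launch
# the server with TORCHINDUCTOR_CACHE_DIR pointing at the unpacked directory.
```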
## Tune `--schedule-policy`

If the workload has many shared prefixes, use the default `--schedule-policy lpm`. `lpm` stands for longest prefix match.

When you have no shared prefixes at all, or you always send requests with shared prefixes together,
you can try `--schedule-policy fcfs`. `fcfs` stands for first come first serve. `fcfs` has a lower scheduling overhead.
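For workloads without shared prefixes, a launch using `fcfs` might look like (model name illustrative):

```shell
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --schedule-policy fcfs
```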
@@ -3,31 +3,39 @@

## Common launch commands

- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.

```bash
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2
```

- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. We recommend the [SGLang Router](../router/router.md) for data parallelism.

```bash
python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2
```

- If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.

```bash
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7
```

- See [hyperparameter tuning](hyperparameter_tuning.md) on tuning hyperparameters for better performance.
- If you see out-of-memory errors during prefill for long prompts, try to set a smaller chunked prefill size.

```bash
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
```

- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently. You can refer to [`Enabling cache for torch.compile`](https://docs.sglang.ai/backend/hyperparameter_tuning.html#enabling-cache-for-torch-compile) for more details.
- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports other [quantization strategies (INT8/FP8)](https://github.com/sgl-project/sglang/blob/v0.3.6/python/sglang/srt/server_args.py#L671) as well.
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](custom_chat_template.md).

- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs each and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port, and use the following commands. If you meet a deadlock, please try to add `--disable-cuda-graph`.

```bash
# Node 0
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 0
```
@@ -49,15 +57,13 @@ Please consult the documentation below to learn more about the parameters you ma

* `kv_cache_dtype`: Dtype of the kv cache, defaults to the `dtype`.
* `context_length`: The number of tokens our model can process *including the input*. Note that extending the default might lead to strange behavior.
* `device`: The device we put the model on, defaults to `cuda`.
* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template).
* `is_embedding`: Set to true to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks.
* `revision`: Adjust if a specific version of the model should be used.
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
* `json_model_override_args`: Override model config with the provided JSON.
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.

> [!IMPORTANT]
> **Make sure the correct `chat_template` is passed, or performance degradation may occur.**
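Since `json_model_override_args` takes a JSON string, it can be convenient to build it programmatically and quote it for the shell. A sketch (the overridden field name is illustrative, not a recommendation):

```python
import json
import shlex

# Build the override as a dict, serialize it, and quote it for the shell.
override = {"max_position_embeddings": 16384}  # illustrative field name
arg = json.dumps(override)
cmd = (
    "python -m sglang.launch_server "
    "--model-path meta-llama/Meta-Llama-3-8B-Instruct "
    f"--json-model-override-args {shlex.quote(arg)}"
)
print(cmd)
```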

## Serving: HTTP & API

@@ -178,7 +184,7 @@ Please consult the documentation below to learn more about the parameters you ma

* `enable_mixed_chunk`: Enables mixing prefill and decode, see [this discussion](https://github.com/sgl-project/sglang/discussions/1163).
* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models. Note that you need to choose `dp_size = tp_size` for this.
* `enable_torch_compile`: Torch compile the model. Note that compiling a model takes a long time but gives a great performance boost. The compiled model can also be [cached for future use](https://docs.sglang.ai/backend/hyperparameter_tuning.html#enabling-cache-for-torch-compile).
* `torch_compile_max_bs`: The maximum batch size when using `torch_compile`.
* `cuda_graph_max_bs`: Adjust the maximum batch size when using CUDA graph. By default this is chosen for you based on GPU specifics.
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
@@ -4,4 +4,4 @@ Multi-Node Deployment

:maxdepth: 1

multi_node.md
deploy_on_k8s.md
@@ -61,8 +61,7 @@ If you encounter errors when starting the server, ensure the weights have finish
|
||||
|
||||
### Caching `torch.compile`
|
||||
|
||||
The DeepSeek series have huge model weights, it takes some time to compile the model with `torch.compile` for the first time if you have added the flag `--enable-torch-compile`. By default, `torch.compile` will automatically cache the FX graph and Triton in `/tmp/torchinductor_root`, which might be cleared according to the [system policy](https://serverfault.com/questions/377348/when-does-tmp-get-cleared). You can export the environment variable `TORCHINDUCTOR_CACHE_DIR` to save compilation cache in your desired directory to avoid unwanted deletion. You can also share the cache with other machines to reduce the compilation time. You may refer to the [PyTorch official documentation](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) and [SGLang Documentation](./torch_compile_cache.md) for more details.
|
||||
|
||||
The DeepSeek series have huge model weights, it takes some time to compile the model with `torch.compile` for the first time if you have added the flag `--enable-torch-compile`. You can refer [here](https://docs.sglang.ai/backend/hyperparameter_tuning.html#try-advanced-options) to optimize the caching of compilation results, so that the cache can be used to speed up the next startup.
|
||||
### Launch with One node of 8 H200
|
||||
|
||||
Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#using-docker-recommended). **Note that Deepseek V3 is already in FP8. So we should not run it with any quantization arguments like `--quantization fp8 --kv-cache-dtype fp8_e5m2`.** Also, `--enable-dp-attention` can be useful to improve for Deepseek V3/R1's throughput. Please refer to [Data Parallelism Attention](https://docs.sglang.ai/references/deepseek.html#multi-head-latent-attention-mla-throughput-optimizations) for detail.
|
||||
|
||||
@@ -1,4 +1,4 @@

# Deploy On Kubernetes

This doc describes deploying a RoCE network-based SGLang two-node inference service on a Kubernetes (K8s) cluster.
@@ -1,13 +0,0 @@

# Enabling cache for torch.compile

SGLang uses `max-autotune-no-cudagraphs` mode of torch.compile. The auto-tuning can be slow.
If you want to deploy a model on many different machines, you can ship the torch.compile cache to these machines and skip the compilation steps.

This is based on https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html

1. Generate the cache by setting TORCHINDUCTOR_CACHE_DIR and running the model once.
```
TORCHINDUCTOR_CACHE_DIR=/root/inductor_root_cache python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile
```
2. Copy the cache folder to other machines and launch the server with `TORCHINDUCTOR_CACHE_DIR`.