### What this PR does / why we need it?
1. Shared Expert Sharding Strategy Update: switched from TP-aligned to pure DP for shared experts, enabling more efficient execution.
2. O_Proj AllReduce → ReduceScatter: reduced communication overhead by replacing the AllReduce with a ReduceScatter, made possible by the pure DP sharding (see the sketch below).
3. AllGather Postponed: delayed until after the QKV down projection to reduce synchronization impact during prefill.
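For illustration only (this is not code from the PR), a minimal `torch.distributed` sketch of the change in item 2: replacing the full AllReduce on the o_proj output with a ReduceScatter, so each rank receives only its token shard of the sum. The helper name and shapes are hypothetical.

```python
import torch
import torch.distributed as dist

def combine_o_proj_partials(partial_out: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: combine per-rank partial o_proj outputs.

    Assumes an initialized process group and a token count divisible by
    the world size; `partial_out` is [num_tokens, hidden_size].
    """
    world_size = dist.get_world_size()
    # Before: every rank received the full summed tensor.
    #   dist.all_reduce(partial_out)
    # After: each rank keeps only its num_tokens / world_size shard of
    # the sum. That suffices once the downstream shared experts run in
    # pure DP (token-parallel), and it cuts the received volume by a
    # factor of world_size.
    shard = torch.empty(
        partial_out.shape[0] // world_size,
        partial_out.shape[1],
        dtype=partial_out.dtype,
        device=partial_out.device,
    )
    dist.reduce_scatter_tensor(shard, partial_out)
    return shard
```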
### How was this patch tested?
Added a unit test case in `tests/ut/attention/test_mla_v1.py`.
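For reference, a minimal way to run just that test module (assuming a standard pytest setup in the repository):

```python
import pytest

# Equivalent to running `pytest tests/ut/attention/test_mla_v1.py -v`
# from the repository root.
raise SystemExit(pytest.main(["tests/ut/attention/test_mla_v1.py", "-v"]))
```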
#### How to run
Use the parameter `--additional_config='{"enable_shared_expert_dp": true}'`.
##### a. How to run eager mode
For example:

```bash
python -m vllm.entrypoints.openai.api_server --model=/model_path \
    --trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002 \
    --max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager \
    --disable-log-requests \
    --additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":true,"chunked_prefill_for_mla":true}'
```
##### b. How to run graph mode
For example:

```bash
python -m vllm.entrypoints.openai.api_server --model=/model_path \
    --trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002 \
    --max-model-len 5120 --max-num-batched-tokens 16384 \
    --disable-log-requests \
    --additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'
```
- vLLM version: v0.10.0
- vLLM main: 9edd1db02b
---------
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
# Additional Configuration

Additional configuration is a mechanism provided by vLLM that allows plugins to control inner behavior on their own. vLLM Ascend uses this mechanism to make the project more flexible.

## How to use

With either online mode or offline mode, users can use additional configuration. Take Qwen3 as an example:

**Online mode**:

```bash
vllm serve Qwen/Qwen3-8B --additional-config='{"config_key":"config_value"}'
```

**Offline mode**:

```python
from vllm import LLM

LLM(model="Qwen/Qwen3-8B", additional_config={"config_key":"config_value"})
```

### Configuration options

The following table lists the additional configuration options available in vLLM Ascend:

| Name | Type | Default | Description |
|------|------|---------|-------------|
| `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode |
| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
| `refresh` | bool | `False` | Whether to refresh the global ascend config content. This value is usually used by RLHF or by unit/e2e test cases. |
| `expert_map_path` | str | `None` | When using expert load balancing for a MoE model, an expert map path needs to be passed in. |
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused-operator-like chunked prefill for MLA. |
| `kv_cache_dtype` | str | `None` | When using kv cache quantization, the kv cache dtype needs to be set; currently only `int8` is supported. |
| `enable_shared_expert_dp` | bool | `True` | When the shared experts run in DP, performance improves but more memory is consumed. Turn this switch off manually when memory is tight. |
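
As a minimal offline-mode sketch of passing these top-level options (the model name is illustrative, not a recommendation):

```python
from vllm import LLM

# A minimal sketch: turn off the shared-expert DP optimization on a
# memory-tight deployment and opt in to int8 kv cache quantization.
# The model name is illustrative only.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    additional_config={
        "enable_shared_expert_dp": False,
        "kv_cache_dtype": "int8",
    },
)
```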
The details of each config option are as follows:

**torchair_graph_config**

| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE support torchair graph mode. |
| `enable_multistream_mla` | bool | `False` | Whether to put the vector ops of MLA on another stream. This option only takes effect on models using MLA (e.g., DeepSeek). |
| `enable_multistream_moe` | bool | `False` | Whether to enable the multistream shared expert. This option only takes effect on DeepSeek MoE models. |
| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization. |
| `use_cached_graph` | bool | `False` | Whether to use a cached graph. |
| `graph_batch_sizes` | list[int] | `[]` | The batch sizes for the torchair graph cache. |
| `graph_batch_sizes_init` | bool | `False` | Initialize the graph batch sizes dynamically if `graph_batch_sizes` is empty. |
| `enable_kv_nz` | bool | `False` | Whether to enable the kvcache NZ layout. This option only takes effect on models using MLA (e.g., DeepSeek). |

**ascend_scheduler_config**

| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable the ascend scheduler for the V1 engine. |

ascend_scheduler_config also supports options from the [vLLM scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.

### Example

An example of additional configuration is as follows:

```
{
    "torchair_graph_config": {
        "enabled": True,
        "use_cached_graph": True,
        "graph_batch_sizes": [1, 2, 4, 8],
        "graph_batch_sizes_init": False,
        "enable_multistream_moe": False,
        "enable_kv_nz": False
    },
    "ascend_scheduler_config": {
        "enabled": True,
        "enable_chunked_prefill": True
    },
    "refresh": False
}
```