[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)
### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce layer_sharding as a configurable feature in
additional_config
- Revise the term "shared weight" to "shard weight."
Configuration: the feature is opt-in via the `additional_config`
argument:
```bash
--additional-config '{
  "layer_sharding": ["o_proj", "q_b_proj"]
}'
```
This is orthogonal to standard tensor parallelism and weight-replication
strategies, and it is treated as a separate, explicit feature. It can be used
in any scenario and combined with the FlashComm2
(https://github.com/vllm-project/vllm-ascend/pull/3232) feature
or the ShardedCP (#4702) feature to achieve significant performance gains.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
@@ -49,6 +49,7 @@ The following table lists additional configuration options available in vLLM Asc
| `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. |
| `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |
| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout. This option only takes effect on models using MLA (e.g., DeepSeek). |
| `layer_sharding` | dict | `{}` | Configuration options for layer-sharded linear layers. |

The details of each configuration option are as follows:
BIN docs/source/user_guide/feature_guide/images/layer_sharding.png (new file, 156 KiB; binary file not shown)
@@ -19,6 +19,7 @@ external_dp
large_scale_ep
ucm_deployment
Fine_grained_TP
layer_sharding
speculative_decoding
context_parallel
:::
docs/source/user_guide/feature_guide/layer_sharding.md (new file, 73 lines)
@@ -0,0 +1,73 @@
---
title: Layer Sharding Guide
---

# Overview

**Layer Shard Linear** is a memory-optimization feature designed for large language model (LLM) inference. It addresses the high memory pressure caused by **repeated linear operators across many layers** that share identical structure but have distinct weights.

Instead of replicating all weights on every device, **Layer Shard Linear shards the weights of a "series" of such operators across the NPU devices in a communication group** (a minimal sketch follows the list):

- The **i-th layer's linear weight** is stored **only on device `i % K`**, where `K` is the number of devices in the group.
- Other devices hold a lightweight **shared dummy tensor** during initialization and fetch the real weight **on-demand via asynchronous broadcast** during the forward pass.
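
A minimal sketch of this placement rule in PyTorch-style code. Everything here (`make_sharded_weight`, the per-shape dummy cache, the `"npu"` device string from `torch_npu`) is illustrative, not the actual vllm-ascend API:

```python
import torch
import torch.distributed as dist

_shared_dummies = {}  # one placeholder per weight shape, reused by all non-owned layers

def make_sharded_weight(layer_idx: int, shape: tuple) -> tuple:
    """Return (tensor, owner_rank): real storage on the owner, a shared dummy elsewhere."""
    k = dist.get_world_size()
    owner = layer_idx % k  # the i-th layer's weight lives only on device i % K
    if dist.get_rank() == owner:
        # The owner allocates real storage; the weight loader fills it from the checkpoint.
        return torch.empty(shape, device="npu"), owner
    # Non-owners share one lightweight buffer per shape; the real weight is
    # broadcast into it on demand during the forward pass.
    if shape not in _shared_dummies:
        _shared_dummies[shape] = torch.empty(shape, device="npu")
    return _shared_dummies[shape], owner
```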

As illustrated in the figure below, this design lets the weight broadcast run ahead of the compute that needs it: while the current layer (e.g., MLA or MoE) is being computed, the system **asynchronously broadcasts the next layer's weight** in the background. Because the attention computation in the MLA module is sufficiently latency-bound, the weight transfer for `o_proj` is **fully overlapped with computation**, making the communication **latency-free from the perspective of end-to-end inference**. A minimal sketch of this pipeline follows.
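
The overlap can be sketched as below. This assumes per-layer receive buffers (with a single shared dummy, the real implementation would need double-buffering so a prefetch does not overwrite the weight in use); `forward_all_layers` and the `owner` attribute are illustrative names, not vllm-ascend internals:

```python
import torch.distributed as dist

def forward_all_layers(layers, hidden):
    # Start fetching the first layer's weight before any compute runs.
    pending = dist.broadcast(layers[0].weight, src=layers[0].owner, async_op=True)
    for i, layer in enumerate(layers):
        # Kick off the prefetch of the next layer's weight...
        nxt = (dist.broadcast(layers[i + 1].weight, src=layers[i + 1].owner,
                              async_op=True) if i + 1 < len(layers) else None)
        pending.wait()          # ...and make sure this layer's weight has arrived.
        hidden = layer(hidden)  # MLA/MoE compute overlaps the broadcast issued above.
        pending = nxt
    return hidden
```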

This approach **preserves exact computational semantics** while **significantly reducing the NPU memory footprint** (a back-of-envelope estimate follows the list), which is especially critical for:

- Extremely deep architectures (e.g., DeepSeek-V3/R1 with 61 layers);
- Models using **[DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702)** or **[FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188)**, where the full `O` (output) projection matrix must reside in memory per layer;
- Scenarios where **attention computation latency fully overlaps** (hides) the communication cost of weight broadcasting.
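
As a rough estimate of the savings (the shape is an assumption based on DeepSeek-V3-like dimensions, 128 heads × 128 `v_head_dim` × 7168 hidden size in bf16; not measured data):

```python
layers, params = 61, 16384 * 7168              # assumed o_proj shape per layer
per_layer_mib = params * 2 / 2**20              # bf16: ~224 MiB per layer
replicated_gib = layers * per_layer_mib / 1024  # ~13.3 GiB on every device
k = 16                                          # devices in the sharding group (example)
sharded_gib = replicated_gib / k                # ~0.83 GiB per device, plus small buffers
print(f"replicated {replicated_gib:.1f} GiB -> sharded {sharded_gib:.2f} GiB per device")
```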
---

## Flowchart


> **Figure.** Layer Shard Linear workflow: weights are sharded by layer across devices (top), and during forward execution (bottom), asynchronous broadcast pre-fetches the next layer's weight while the current layer computes—enabling zero-overhead weight loading.

---

# Getting Started

To enable **Layer Shard Linear**, specify the target linear layers using the `--additional-config` argument when launching your inference job. For example, to shard the `o_proj` and `q_b_proj` layers, use:

```bash
--additional-config '{
  "layer_sharding": ["o_proj", "q_b_proj"]
}'
```
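
For offline inference, the same option can be passed through the engine's `additional_config` argument; a minimal sketch, assuming the standard vLLM `LLM` entry point (the model path is a placeholder):

```python
from vllm import LLM

# additional_config is forwarded to the vllm-ascend platform plugin.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # placeholder model path
    additional_config={"layer_sharding": ["o_proj", "q_b_proj"]},
)
```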

---

# Supported Scenarios

This feature can be enabled in any scenario, but delivers the greatest benefit in the following cases:

## FlashComm2-enabled

When using [FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188), the full output projection (`o_proj`) matrix must be resident in memory for each layer. Layer sharding significantly reduces memory pressure by distributing these weights across devices.

**Example configuration:**

```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
vllm serve DeepSeek-V3/R1 \
  --additional-config '{
    "layer_sharding": ["o_proj"]
  }'
```

## DSA-CP-enabled

With [DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702), both `q_b_proj` and `o_proj` layers require large weight matrices to be stored per layer. Sharding these layers across NPUs helps fit extremely deep models (e.g., 61-layer architectures) into limited device memory.

**Example configuration:**

```bash
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
vllm serve DeepSeek-V3.2 \
  --additional-config '{
    "layer_sharding": ["q_b_proj", "o_proj"]
  }'
```