[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)

### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce layer_sharding as a configurable feature in
additional_config
- Revise the term "shared weight" to "shard weight."
Configuration : The feature is opt-in via the additional_config
argument:
```
--additional-config '{
  "layer_sharding": ["o_proj", "q_b_proj"]
}'
```

This is orthogonal to standard tensor parallelism and weight replication
strategies. It is treated as a separate, explicit feature.It can be used
in any scenario, combined with the
flashcomm2https://github.com/vllm-project/vllm-ascend/pull/3232 feature
or the ShardedCP #4702 feature, to achieve significant performance.



- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
zzhxxx
2026-01-08 09:05:02 +08:00
committed by GitHub
parent 20a8cf061b
commit f7db812ed7
13 changed files with 288 additions and 169 deletions

View File

@@ -49,6 +49,7 @@ The following table lists additional configuration options available in vLLM Asc
| `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. |
| `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |
| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
| `layer_sharding` | dict | `{}` | Configuration options for layer sharding linear |
The details of each configuration option are as follows:

Binary file not shown.

After

Width:  |  Height:  |  Size: 156 KiB

View File

@@ -19,6 +19,7 @@ external_dp
large_scale_ep
ucm_deployment
Fine_grained_TP
layer_sharding
speculative_decoding
context_parallel
:::

View File

@@ -0,0 +1,73 @@
---
title: Layer Sharding Guide
---
# Overview
**Layer Shard Linear** is a memory-optimization feature designed for large language model (LLM) inference. It addresses the high memory pressure caused by **repeated linear operators across many layers** that share identical structure but have distinct weights.
Instead of replicating all weights on every device, **Layer Shard Linear shards the weights of a "series" of such operators across the NPU devices in a communication group**:
- The **i-th layer's linear weight** is stored **only on device `i % K`**, where `K` is the number of devices in the group.
- Other devices hold a lightweight **shared dummy tensor** during initialization and fetch the real weight **on-demand via asynchronous broadcast** during the forward pass.
As illustrated in the figure below, this design enables broadcast to reach weights: while the current layer (e.g., MLA or MOE) is being computed, the system **asynchronously broadcasts the next layer's weight** in the background. Because the attention computation in the MLA module is sufficiently latency-bound, the weight transfer for `o_proj` is **fully overlapped with computation**, making the communication **latency-free from the perspective of end-to-end inference**.
This approach **preserves exact computational semantics** while **significantly reducing NPU memory footprint**, especially critical for:
- Extremely deep architectures (e.g., DeepSeek-V3/R1 with 61 layers);
- Models using **[DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702)** or **[FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188)**, where the full `O` (output) projection matrix must reside in memory per layer;
- Scenarios where **attention computation latency fully overlaps** (hides) the communication cost of weight broadcasting.
---
## Flowchart
![layer shard](./images/layer_sharding.png)
> **Figure.** Layer Shard Linear workflow: weights are sharded by layer across devices (top), and during forward execution (bottom), asynchronous broadcast pre-fetches the next layer's weight while the current layer computes—enabling zero-overhead weight loading.
---
# Getting Started
To enable **Layer Shard Linear**, specify the target linear layers using the `--additional-config` argument when launching your inference job. For example, to shard the `o_proj` and `q_b_proj` layers, use:
```bash
--additional-config '{
"layer_sharding": ["o_proj", "q_b_proj"]
}'
```
---
# Supported Scenarios
This feature can be enabled in any scenario, but delivers the greatest benefit in the following cases:
## FlashComm2-enabled
When using [FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188), the full output projection (`o_proj`) matrix must be resident in memory for each layer. Layer sharding significantly reduces memory pressure by distributing these weights across devices.
**Example configuration:**
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
vllm serve \
--model DeepSeek-V3/R1 \
--additional-config '{
"layer_sharding": ["o_proj"]
}'
```
## DSA-CP-enabled
With [DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702), both `q_b_proj` and `o_proj` layers require large weight matrices to be stored per layer. Sharding these layers across NPUs helps fit extremely deep models (e.g., 61-layer architectures) into limited device memory.
**Example configuration:**
```bash
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
vllm serve \
--model DeepSeek-V3.2 \
--additional-config '{
"layer_sharding": ["q_b_proj", "o_proj"]
}'
```

View File

@@ -25,14 +25,10 @@ def mock_distributed():
patch('torch.distributed.get_world_size', return_value=16), \
patch('torch.distributed.get_backend', return_value='nccl'), \
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \
patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group:
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group:
mock_group.return_value.local_rank = 0
mock_group.return_value.device_group = MagicMock()
mock_tp_group.return_value.world_size = 4
mock_dp_group.return_value.world_size = 2
mock_pp_group.return_value.world_size = 2
yield
@@ -50,7 +46,6 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mock_vllm_config.kv_transfer_config.is_kv_producer = True
mock_envs_ascend = MagicMock()
mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
mock_envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED = 0
mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \

View File

@@ -51,6 +51,12 @@ class AscendConfig:
"weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(
weight_prefetch_config)
self.layer_sharding = additional_config.get("layer_sharding", None)
logger.info_once(
f"Linear layer sharding enabled with config: {self.layer_sharding}. "
"Note: This feature works optimally with FLASHCOMM2 and DSA-CP enabled; "
"using it without these features may result in significant performance degradation."
)
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
self.expert_map_path = additional_config.get("expert_map_path", None)
@@ -111,7 +117,7 @@ class AscendConfig:
self.SLO_limits_for_dynamic_batch = additional_config.get(
"SLO_limits_for_dynamic_batch", -1)
from vllm_ascend.utils import get_flashcomm2_config_and_validate
self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(
self, vllm_config)
self.enable_npugraph_ex = additional_config.get(
"enable_npugraph_ex", False)

View File

@@ -31,15 +31,14 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params, get_graph_params,
update_draft_graph_params_workspaces, update_graph_params_workspaces)
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
register_layer_to_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND,
flashcomm2_o_shared_enabled, maybe_trans_nz,
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz,
weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -734,18 +733,6 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
self.vllm_config = get_current_vllm_config()
self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
from vllm_ascend.distributed.parallel_state import \
get_shared_weight_group
register_layer_to_shared_weight_series(
series_name="o_proj",
group=get_shared_weight_group(),
layer=self.o_proj,
prefetch_step=1)
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
@@ -762,6 +749,15 @@ class AscendMLAImpl(MLAAttentionImpl):
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
def _v_up_proj(self, x):
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
@@ -833,9 +829,9 @@ class AscendMLAImpl(MLAAttentionImpl):
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
post_process_after_loading_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
@@ -1445,9 +1441,9 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
decode_preprocess_res = None
prefill_preprocess_res = None
@@ -1478,9 +1474,9 @@ class AscendMLAImpl(MLAAttentionImpl):
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)
forward_context = get_forward_context()

View File

@@ -25,11 +25,11 @@ from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
register_layer_to_shared_weight_series)
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
@@ -374,22 +374,17 @@ class AscendSFAImpl(MLAAttentionImpl):
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size
# TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
self._replace_linear_class_for_sfa_cp()
from vllm_ascend.distributed.parallel_state import \
get_shared_weight_group
if is_hidden_layer(self.vllm_config, self.q_proj):
register_layer_to_shared_weight_series(
series_name="q_proj",
group=get_shared_weight_group(),
layer=self.q_proj,
prefetch_step=1)
if is_hidden_layer(self.vllm_config, self.o_proj):
register_layer_to_shared_weight_series(
series_name="o_proj",
group=get_shared_weight_group(),
layer=self.o_proj,
prefetch_step=1)
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)
# indexer param
self.n_head: int = self.indexer.n_head # 64
@@ -434,14 +429,10 @@ class AscendSFAImpl(MLAAttentionImpl):
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
dispose_layer(self.kv_b_proj)
if self.enable_sfa_cp:
if is_hidden_layer(self.vllm_config, self.q_proj):
post_process_after_loading_for_shared_weight_series(
self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
post_process_after_loading_for_shared_weight_series(
self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)
if self.enable_mlapo:
quant_method = getattr(
@@ -751,10 +742,9 @@ class AscendSFAImpl(MLAAttentionImpl):
if attn_metadata is None:
# Profiling run.
if self.enable_sfa_cp and not forward_context.in_profile_run:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)
has_prefill = attn_metadata.has_prefill
cos = attn_metadata.cos
@@ -809,10 +799,9 @@ class AscendSFAImpl(MLAAttentionImpl):
slot_mapping_cp)
if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)

View File

@@ -2,14 +2,12 @@ from typing import Optional
import torch
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
get_pp_group, get_tp_group,
from vllm.distributed.parallel_state import (GroupCoordinator, get_tp_group,
get_world_group,
init_model_parallel_group)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
flashcomm2_o_shared_enabled)
from vllm_ascend.utils import enable_dsa_cp, flashcomm2_enable
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -25,8 +23,8 @@ _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
_FC3_QUANT_X: Optional[GroupCoordinator] = None
# shared_weight across rank groups
_SHARED_WEIGHT: Optional[GroupCoordinator] = None
# shard_weight across rank groups
_SHARD_WEIGHT: Optional[GroupCoordinator] = None
_P_TP: Optional[GroupCoordinator] = None
@@ -37,7 +35,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
backend = torch.distributed.get_backend(get_world_group().device_group)
vllm_config = get_current_vllm_config()
global_tp_size = parallel_config.tensor_parallel_size
global_dp_size = parallel_config.data_parallel_size
global_pp_size = parallel_config.pipeline_parallel_size
@@ -48,6 +45,14 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
all_ranks = torch.arange(world_size).reshape(
-1, global_dp_size * parallel_config.prefill_context_parallel_size *
global_tp_size)
#TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future.
vllm_all_ranks = torch.arange(world_size).reshape(
-1,
global_dp_size,
global_pp_size,
parallel_config.prefill_context_parallel_size,
global_tp_size,
)
pd_tp_ratio = get_ascend_config().pd_tp_ratio
pd_head_ratio = get_ascend_config().pd_head_ratio
@@ -148,38 +153,13 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
if mlp_tp_size > 0:
_MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp")
def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
#This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference
group_ranks = []
for pp_idx in range(global_pp_size):
group = []
for dp_idx in range(global_dp_size):
base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
for i in range(global_tp_size):
global_rank = base + i
group.append(global_rank)
group_ranks.append(group)
return init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name=group_name)
global _SHARED_WEIGHT
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
is_ds_v32 = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None:
_SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
# TODO: Extract and unify the logic across different communication group.
flashcomm2_otp_group_ranks = []
if flashcomm2_enable():
flashcomm2_otp_size = get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size
global_tp_size = get_tp_group().world_size
global_dp_size = get_dp_group().world_size
global_pp_size = get_pp_group().world_size
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
flashcomm2_otp_size)
global _FLASHCOMM2_OTP
global _FLASHCOMM2_ODP
@@ -187,7 +167,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
_FLASHCOMM2_ODP = get_tp_group()
if flashcomm2_otp_size > 1:
otp_group_ranks = []
odp_group_ranks: list[list[int]] = [
[] for _ in range(flashcomm2_otp_size * global_dp_size *
global_pp_size)
@@ -209,10 +188,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
odp_group_index = odp_base_index + j
odp_group_ranks[odp_group_index].append(
global_rank)
otp_group_ranks.append(ranks)
flashcomm2_otp_group_ranks.append(ranks)
_FLASHCOMM2_OTP = init_model_parallel_group(
otp_group_ranks,
flashcomm2_otp_group_ranks,
get_world_group().local_rank,
backend,
group_name="flashcomm2_otp")
@@ -222,12 +201,50 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="flashcomm2_odp")
# Create shared weight group for flashcomm2 oproj
if flashcomm2_o_shared_enabled():
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
if _SHARED_WEIGHT is None:
_SHARED_WEIGHT = _create_shared_weight_group(
"flashcomm2_o_shared")
def create_shard_weight_group(
module_tp_group_ranks: None) -> GroupCoordinator:
# Argument module_tp_group_ranks: The module specific tensor parallel group.
# There are three situations.
# 1. If it is None, then the TP_size of the specific module is 1 and is replicated linear layer.
# 2. If it is not None, and the module tp_group is same as the global tp_group.
# 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp)
group_ranks = []
pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape(
-1, global_pp_size)
if module_tp_group_ranks is None:
# If it is None, then the TP_size of this shard weight is 1.
shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0)
group_ranks = [x.tolist() for x in shard_weight_group_ranks]
else:
# combine standard tp group and non-standard tp group to build shard_weight comm_group
module_tp_tanspose_ranks = module_tp_group_ranks.transpose(0, 1)
G = world_size // (global_pp_size * module_tp_group_ranks.size(1))
shard_weight_group_ranks = torch.stack(
[t.view(global_pp_size, G) for t in module_tp_tanspose_ranks],
dim=1)
group_ranks = shard_weight_group_ranks.view(-1, G).tolist()
return init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="shard_weight")
# Create shard weight group if enabled
if get_ascend_config().layer_sharding is not None:
global _SHARD_WEIGHT
if flashcomm2_enable():
if len(flashcomm2_otp_group_ranks) == 0:
FC2_group_ranks = None
else:
FC2_group_ranks = torch.tensor(
flashcomm2_otp_group_ranks).squeeze(0)
_SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks)
elif enable_dsa_cp():
# For dsa_cp, all shard layers are replicated.
_SHARD_WEIGHT = create_shard_weight_group(None)
else:
# For standard tp, use global tp group_ranks
tp_group_ranks = vllm_all_ranks.view(-1, global_tp_size)
_SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks)
if get_ascend_config().multistream_overlap_gate:
global _FC3_QUANT_X
@@ -280,11 +297,10 @@ def get_flashcomm2_odp_group() -> GroupCoordinator:
return _FLASHCOMM2_ODP
def get_shared_weight_group() -> GroupCoordinator:
assert _SHARED_WEIGHT is not None, (
"output shared weight parallel group for flashcomm2 is not initialized"
)
return _SHARED_WEIGHT
def get_shard_weight_group() -> GroupCoordinator:
assert _SHARD_WEIGHT is not None, (
"output shard weight parallel group for flashcomm2 is not initialized")
return _SHARD_WEIGHT
def get_p_tp_group() -> GroupCoordinator:
@@ -341,10 +357,10 @@ def destroy_ascend_model_parallel():
_FLASHCOMM2_ODP.destroy()
_FLASHCOMM2_ODP = None
global _SHARED_WEIGHT
if _SHARED_WEIGHT:
_SHARED_WEIGHT.destroy()
_SHARED_WEIGHT = None
global _SHARD_WEIGHT
if _SHARD_WEIGHT:
_SHARD_WEIGHT.destroy()
_SHARD_WEIGHT = None
global _FC3_QUANT_X
if _FC3_QUANT_X:

View File

@@ -2,10 +2,10 @@ import os
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm_ascend.distributed.parallel_state import (get_dp_group,
get_fc3_quant_x_group,
from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
get_p_tp_group)

View File

@@ -92,11 +92,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
# This feature is bound to the previous VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE, and it adds the shared weight feature,
# which can eliminate redundant storage of weights. More detailed information can be found in PR#4188.
# We recommend that you enable it when Flashcomm2 is enabled and the VRAM capacity is limited.
"VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0"))),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),

View File

@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import Callable, Optional
from functools import lru_cache
from typing import Callable, List, Optional
import torch
import torch.distributed as dist
@@ -7,6 +8,8 @@ from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.models.utils import extract_layer_index
from vllm_ascend.distributed.parallel_state import get_shard_weight_group
def dispose_tensor(x: torch.Tensor):
x.set_(torch.empty([], device=x.device, dtype=x.dtype))
@@ -26,17 +29,17 @@ class LayerMetadata:
@dataclass
class SharedWindowMetadata:
"""Metadata for a shared window.
class ShardWindowMetadata:
"""Metadata for a shard window.
"""
weight: torch.Tensor # The weight tensor to be shared by layers.
weight: torch.Tensor # The weight tensor to be shard by layers.
data_layer_idx: int # The index of the layer this window's weight is equal to.
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.
@dataclass
class SeriesMetadata:
"""Metadata for a weight shared series.
"""Metadata for a weight shard series.
"""
group: GroupCoordinator
start_layer: int
@@ -45,8 +48,8 @@ class SeriesMetadata:
prefetch_step: int
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
layers: list[LayerMetadata]
shared_windows: list[
SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
shard_windows: list[
ShardWindowMetadata] # Shard windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
window_offset: int # The index of the window for the next coming layer.
def is_source(self, layer_idx) -> bool:
@@ -54,7 +57,7 @@ class SeriesMetadata:
def post_process_after_loading(self):
# This method only needs to be called once per series.
if self.shared_windows:
if self.shard_windows:
return
self.layers.sort(key=lambda x: x.layer_idx)
@@ -83,8 +86,8 @@ class SeriesMetadata:
step = layer_idx - self.start_layer
if step < self.prefetch_step:
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
self.shared_windows.append(
SharedWindowMetadata(
self.shard_windows.append(
ShardWindowMetadata(
weight=layer.weight.clone().detach(),
data_layer_idx=layer_idx,
work=None,
@@ -92,12 +95,12 @@ class SeriesMetadata:
layer.window_idx = step
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
if not is_source:
layer.weight.set_(self.shared_windows[-1].weight)
layer.weight.set_(self.shard_windows[-1].weight)
else:
# Build one more window for prefetch. The weight is useless, so just keep the shape.
if step == self.prefetch_step:
self.shared_windows.append(
SharedWindowMetadata(
self.shard_windows.append(
ShardWindowMetadata(
weight=torch.empty_like(layer.weight),
data_layer_idx=-1,
work=None,
@@ -115,7 +118,7 @@ class SeriesMetadata:
next_layer = self.layers[next_layer_idx - self.start_layer]
# The index of the window to store the weight for the coming layer.
next_layer.window_idx = self.window_offset
window = self.shared_windows[next_layer.window_idx]
window = self.shard_windows[next_layer.window_idx]
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
if not self.is_source(next_layer_idx):
next_layer.weight.set_(window.weight)
@@ -133,10 +136,10 @@ class SeriesMetadata:
def wait_weight(self, layer_idx: int):
# Find the asynchronous broadcast work and wait for it.
assert self.shared_windows
window = self.shared_windows[self.layers[layer_idx -
self.start_layer].window_idx]
# Make sure the data in the corresponding shared window is for the current layer.
assert self.shard_windows
window = self.shard_windows[self.layers[layer_idx -
self.start_layer].window_idx]
# Make sure the data in the corresponding shard window is for the current layer.
assert window.data_layer_idx == layer_idx
if window.work is not None:
window.work.wait()
@@ -168,13 +171,13 @@ def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
"""
Register linear layers into a shared storage series.
Register linear layers into a shard storage series.
In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.
After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization.
After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)` on any layer of this series to complete the initialization.
During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shard_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
- start_layer is the index of the first layer in the series (inclusive).
@@ -182,7 +185,7 @@ Note: The layers are managed as a circular buffer. The index of the layer to pre
- total_layers = end_layer - start_layer
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series.
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers will be created for this series.
Arguments:
series_name: This name identifies which series this layer belongs to.
@@ -192,7 +195,7 @@ Arguments:
"""
def register_layer_to_shared_weight_series(
def register_layer_to_shard_weight_series(
series_name: str,
group: GroupCoordinator,
layer: LinearBase,
@@ -208,7 +211,7 @@ def register_layer_to_shared_weight_series(
prefetch_step=prefetch_step,
dummy_weight=torch.empty_like(layer.weight),
layers=[],
shared_windows=[],
shard_windows=[],
window_offset=prefetch_step,
)
series = _series_dict[series_name]
@@ -236,17 +239,42 @@ def register_layer_to_shared_weight_series(
)
def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
def post_process_after_loading_for_shard_weight_series(layer: LinearBase):
ext = _layer_external_dict[id(layer)]
ext.series.post_process_after_loading()
def reach_layer_for_shared_weight_series(layer: LinearBase):
def reach_layer_for_shard_weight_series(layer: LinearBase):
ext = _layer_external_dict[id(layer)]
ext.series.reach_layer(ext.layer_idx)
def is_hidden_layer(vllm_config, layer: LinearBase) -> bool:
num_hidden_layers = vllm_config.model_config.hf_text_config.num_hidden_layers
def wait_layer_for_shard_weight_series(layer: LinearBase):
ext = _layer_external_dict[id(layer)]
ext.series.wait_weight(ext.layer_idx)
@lru_cache(maxsize=1)
def get_current_model_num_hidden_layers() -> int:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
return vllm_config.model_config.get_total_num_hidden_layers()
def is_hidden_layer(layer: LinearBase) -> bool:
num_hidden_layers = get_current_model_num_hidden_layers()
layer_idx = extract_layer_index(layer.prefix)
return layer_idx < num_hidden_layers
def register_all_layers_to_shard_weight_series(
layer_sharding: List[LinearBase], ):
for curr_layer in (layer_sharding or []):
if is_hidden_layer(curr_layer):
layer_name = curr_layer.prefix.split('.')[-1]
register_layer_to_shard_weight_series(
series_name=layer_name,
group=get_shard_weight_group(),
layer=curr_layer,
prefetch_step=1,
)

View File

@@ -23,6 +23,7 @@ import math
import os
from contextlib import contextmanager, nullcontext
from enum import Enum
from functools import lru_cache
from threading import Lock
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -1016,22 +1017,27 @@ def flashcomm2_enable() -> bool:
return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
def flashcomm2_o_shared_enabled() -> bool:
return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
global_tp_size = vllm_config.parallel_config.tensor_parallel_size
flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()
if not flashcomm2_enable():
flashcomm2_oproj_shared = False
return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
return 0
logger.info(
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}"
)
layer_sharding = ascend_config.layer_sharding or []
if layer_sharding:
if layer_sharding == ["o_proj"]:
logger.info_once(
"Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption."
)
else:
raise ValueError(
"FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! "
f"Found invalid layer_sharding: {layer_sharding}")
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
logger.warning_once(
"It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
@@ -1054,13 +1060,10 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
)
if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
raise AssertionError(
"FLASHCOMM2 primarily targets P-scenario deployments, "
"with additional support for hybrid deployment scenarios. "
"It is not applicable in D-scenario environments.")
if flashcomm2_oproj_shared:
logger.info("Enable FLASHCOMM2 with oproj_shared.")
"FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments."
)
return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
return flashcomm2_oproj_tp_size
def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
@@ -1160,4 +1163,20 @@ def singleton(cls):
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
return get_instance
@lru_cache(maxsize=1)
def get_current_model_config():
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
return vllm_config.model_config
#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1
@lru_cache(maxsize=1)
def enable_dsa_cp() -> bool:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
return is_ds_v32 and enable_sp()