[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)

### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce layer_sharding as a configurable feature in
additional_config
- Revise the term "shared weight" to "shard weight."
Configuration : The feature is opt-in via the additional_config
argument:
```
--additional-config '{
  "layer_sharding": ["o_proj", "q_b_proj"]
}'
```

This is orthogonal to standard tensor parallelism and weight replication
strategies. It is treated as a separate, explicit feature.It can be used
in any scenario, combined with the
flashcomm2https://github.com/vllm-project/vllm-ascend/pull/3232 feature
or the ShardedCP #4702 feature, to achieve significant performance.



- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
zzhxxx
2026-01-08 09:05:02 +08:00
committed by GitHub
parent 20a8cf061b
commit f7db812ed7
13 changed files with 288 additions and 169 deletions

View File

@@ -31,15 +31,14 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params, get_graph_params,
update_draft_graph_params_workspaces, update_graph_params_workspaces)
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
register_layer_to_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND,
flashcomm2_o_shared_enabled, maybe_trans_nz,
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz,
weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -734,18 +733,6 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
self.vllm_config = get_current_vllm_config()
self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
from vllm_ascend.distributed.parallel_state import \
get_shared_weight_group
register_layer_to_shared_weight_series(
series_name="o_proj",
group=get_shared_weight_group(),
layer=self.o_proj,
prefetch_step=1)
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
@@ -762,6 +749,15 @@ class AscendMLAImpl(MLAAttentionImpl):
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
def _v_up_proj(self, x):
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
@@ -833,9 +829,9 @@ class AscendMLAImpl(MLAAttentionImpl):
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
post_process_after_loading_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
@@ -1445,9 +1441,9 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
decode_preprocess_res = None
prefill_preprocess_res = None
@@ -1478,9 +1474,9 @@ class AscendMLAImpl(MLAAttentionImpl):
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)
forward_context = get_forward_context()