[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)

### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce layer_sharding as a configurable feature in
additional_config
- Revise the term "shared weight" to "shard weight."
Configuration : The feature is opt-in via the additional_config
argument:
```
--additional-config '{
  "layer_sharding": ["o_proj", "q_b_proj"]
}'
```

This is orthogonal to standard tensor parallelism and weight replication
strategies. It is treated as a separate, explicit feature.It can be used
in any scenario, combined with the
flashcomm2https://github.com/vllm-project/vllm-ascend/pull/3232 feature
or the ShardedCP #4702 feature, to achieve significant performance.



- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
zzhxxx
2026-01-08 09:05:02 +08:00
committed by GitHub
parent 20a8cf061b
commit f7db812ed7
13 changed files with 288 additions and 169 deletions

View File

@@ -25,11 +25,11 @@ from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
register_layer_to_shared_weight_series)
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
@@ -374,22 +374,17 @@ class AscendSFAImpl(MLAAttentionImpl):
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size
# TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
self._replace_linear_class_for_sfa_cp()
from vllm_ascend.distributed.parallel_state import \
get_shared_weight_group
if is_hidden_layer(self.vllm_config, self.q_proj):
register_layer_to_shared_weight_series(
series_name="q_proj",
group=get_shared_weight_group(),
layer=self.q_proj,
prefetch_step=1)
if is_hidden_layer(self.vllm_config, self.o_proj):
register_layer_to_shared_weight_series(
series_name="o_proj",
group=get_shared_weight_group(),
layer=self.o_proj,
prefetch_step=1)
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)
# indexer param
self.n_head: int = self.indexer.n_head # 64
@@ -434,14 +429,10 @@ class AscendSFAImpl(MLAAttentionImpl):
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
dispose_layer(self.kv_b_proj)
if self.enable_sfa_cp:
if is_hidden_layer(self.vllm_config, self.q_proj):
post_process_after_loading_for_shared_weight_series(
self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
post_process_after_loading_for_shared_weight_series(
self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)
if self.enable_mlapo:
quant_method = getattr(
@@ -751,10 +742,9 @@ class AscendSFAImpl(MLAAttentionImpl):
if attn_metadata is None:
# Profiling run.
if self.enable_sfa_cp and not forward_context.in_profile_run:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)
has_prefill = attn_metadata.has_prefill
cos = attn_metadata.cos
@@ -809,10 +799,9 @@ class AscendSFAImpl(MLAAttentionImpl):
slot_mapping_cp)
if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)