[Feat] Flashcomm2 use o_shared linear (#4188)

### What this PR does / why we need it?

The [FlashComm2 technical
report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
notes that FC2 introduces fully redundant storage of the o_proj matrix,
which puts pressure on device memory. The report therefore proposes a
compromise scheme, otp2, but that scheme introduces additional
reduce-scatter communication.
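
To gauge the scale of that memory pressure, here is a rough back-of-the-envelope estimate. It assumes DeepSeek-V3-like shapes (61 layers, 128 heads, v_head_dim 128, hidden size 7168) and a 16-card deployment; these figures are illustrative and not taken from this PR:

```python
# Rough, illustrative estimate of o_proj memory under full redundancy
# versus layer-wise sharing. Shapes are DeepSeek-V3-like assumptions.
num_layers = 61
num_heads, v_head_dim, hidden_size = 128, 128, 7168
bytes_per_param = 2  # BF16

per_layer = num_heads * v_head_dim * hidden_size  # ~117M params per layer
fully_redundant_gib = num_layers * per_layer * bytes_per_param / 1024**3
shared_gib = fully_redundant_gib / 16  # one copy spread across 16 cards

print(f"full copy per card: {fully_redundant_gib:.1f} GiB")  # ~13.3 GiB
print(f"layer-wise shared:  {shared_gib:.2f} GiB per card")  # ~0.83 GiB
```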

We propose a shared linear feature (#2931) that distributes the weights
layer by layer across the cards instead of TP-splitting them, which
resolves the memory issue.
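
For intuition, here is a minimal sketch of that layer-wise ownership idea: each layer's full weight lives on exactly one card and is broadcast just in time. The `owner_rank` policy and `fetch_weight` helper below are illustrative stand-ins, not the actual #2931 API:

```python
from typing import Optional

import torch
import torch.distributed as dist


def owner_rank(layer_idx: int, world_size: int) -> int:
    # Illustrative policy: each layer's full o_proj weight lives on exactly
    # one card (round-robin), instead of being TP-split across every card.
    return layer_idx % world_size


def fetch_weight(layer_idx: int, local_weight: Optional[torch.Tensor],
                 shape: tuple, dtype: torch.dtype,
                 world_size: int) -> torch.Tensor:
    # Non-owners allocate a scratch buffer; the owner broadcasts its full
    # weight just before the layer runs, and the buffer is freed afterwards.
    src = owner_rank(layer_idx, world_size)
    buf = (local_weight if dist.get_rank() == src
           else torch.empty(shape, dtype=dtype))
    dist.broadcast(buf, src=src)
    return buf
```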

This PR depends on #3232 and #2931.

### Flashcomm2 flowchart
<img width="1142" height="878" alt="PixPin_2025-11-14_13-37-39"
src="https://github.com/user-attachments/assets/d45ea8db-d8ef-4d45-8e18-abd4d82ce3e0"
/>

### Does this PR introduce _any_ user-facing change?

Yes. Enable the feature via environment variables:
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```
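
The diff below gates all new code paths on `flashcomm2_o_shared_enabled()` from `vllm_ascend.utils`. A minimal sketch of how such a gating helper plausibly reads this flag (the real implementation may add further checks, e.g. that FlashComm2 itself is enabled):

```python
import os


def flashcomm2_o_shared_enabled() -> bool:
    # Sketch only: mirrors the env-var convention shown above; the actual
    # helper in vllm_ascend.utils may differ in detail.
    return os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0") == "1"
```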


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: zzh02232027 <zzh02232027@antgroup.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>

```diff
@@ -34,10 +34,15 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                get_mtp_graph_params,
                                                update_graph_params_workspaces)
+from vllm_ascend.ops.shared_weight_layer import (
+    is_hidden_layer, post_process_after_loading_for_shared_weight_series,
+    reach_layer_for_shared_weight_series,
+    register_layer_to_shared_weight_series)
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz, weak_ref_tensors)
+                               flashcomm2_o_shared_enabled, is_enable_nz,
+                               weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -848,6 +853,19 @@ class AscendMLAImpl(MLAAttentionImpl):
             'q_b_proj']
         self.kv_b_proj = kwargs['kv_b_proj']
         self.o_proj = kwargs['o_proj']
+        self.vllm_config = get_current_vllm_config()
+        self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            from vllm_ascend.distributed.parallel_state import \
+                get_shared_weight_group
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_shared_weight_group(),
+                layer=self.o_proj,
+                prefetch_step=1)
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
         self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
@@ -858,10 +876,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
-        vllm_config = get_current_vllm_config()
         self.ring_mla_mask_size = 512
-        self.speculative_config = vllm_config.speculative_config
+        self.speculative_config = self.vllm_config.speculative_config
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
         self.pcp_size = get_pcp_group().world_size
@@ -995,6 +1012,10 @@
         if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            post_process_after_loading_for_shared_weight_series(self.o_proj)
+
     def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
             ..., self.q_lora_rank:].contiguous()
@@ -1515,6 +1536,10 @@
         kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             kv_no_split.contiguous(), need_gather_q_kv)
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            reach_layer_for_shared_weight_series(self.o_proj)
+
         decode_preprocess_res = None
         prefill_preprocess_res = None
         if has_prefill:
@@ -1633,6 +1658,9 @@
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
+            if self.fc2_o_shared_enable and is_hidden_layer(
+                    self.vllm_config, self.o_proj):
+                reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
         if self.pcp_size > 1:
             num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
```
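
The three hook sites above follow a register / finalize / reach protocol: enroll the layer at construction, finalize the series once weights are loaded, and signal from the forward path (including the profiling run, so the prefetch pipeline stays in step on dummy passes). A self-contained toy that mirrors the shape of that protocol; class and method names here are hypothetical, the real helpers live in `vllm_ascend.ops.shared_weight_layer`:

```python
import torch


class ToySharedWeightSeries:
    """Toy stand-in for the shared-weight-series helpers used above: one
    full copy of each layer's weight exists in the cluster, and ranks are
    handed it just in time. Hypothetical API, not the #2931 implementation."""

    def __init__(self, prefetch_step: int = 1):
        self.layers: list = []
        self.prefetch_step = prefetch_step

    def register(self, layer: torch.nn.Linear) -> None:
        # Mirrors register_layer_to_shared_weight_series() in __init__.
        self.layers.append(layer)

    def post_process_after_loading(self) -> None:
        # Mirrors post_process_after_loading_for_shared_weight_series():
        # weights are final; a real implementation frees non-owned copies.
        pass

    def reach(self, layer: torch.nn.Linear) -> None:
        # Mirrors reach_layer_for_shared_weight_series(): execution reached
        # `layer`, so start staging the weight prefetch_step layers ahead.
        idx = self.layers.index(layer)
        nxt = min(idx + self.prefetch_step, len(self.layers) - 1)
        _ = self.layers[nxt]  # real code would launch an async transfer here
```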