[Feat] Flashcomm2 use o_shared linear (#4188)

### What this PR does / why we need it?

The [flashcomm2 technical
report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
notes that FC2 introduces fully redundant storage of the o_proj matrix,
which puts pressure on device memory. The report therefore proposes a
compromise, otp2, but that in turn introduces additional reduce-scatter
communication.

We instead propose a shared linear feature (#2931) that distributes the
full per-layer weights across cards, one layer per card, avoiding the
need for TP splitting of o_proj and resolving the memory issue.
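
A minimal sketch of the idea, assuming a round-robin layer-to-rank
assignment (class and parameter names here are illustrative, not the
actual #2931 implementation):

```python
import torch
import torch.distributed as dist

class OSharedLinear(torch.nn.Module):
    """Illustrative sketch (not the actual #2931 code): each layer's full
    o_proj weight lives on exactly one rank, assigned round-robin by
    layer index, instead of being TP-split across every rank."""

    def __init__(self, layer_idx: int, in_features: int, out_features: int):
        super().__init__()
        self.owner = layer_idx % dist.get_world_size()
        self.shape = (out_features, in_features)
        # Only the owning rank keeps this layer's weight resident.
        if dist.get_rank() == self.owner:
            self.weight = torch.nn.Parameter(torch.empty(self.shape))
        else:
            self.weight = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Non-owners receive the weight on the fly; one broadcast replaces
        # the extra reduce-scatter that the otp2 compromise would add.
        if dist.get_rank() == self.owner:
            w = self.weight.to(x.dtype)
        else:
            w = torch.empty(self.shape, dtype=x.dtype, device=x.device)
        dist.broadcast(w, src=self.owner)
        return torch.nn.functional.linear(x, w)
```

With N cards, each card then stores roughly 1/N of all o_proj weights at
full width, trading a per-layer broadcast for the memory cost of full
redundancy.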

This PR depends on #3232 and #2931.

### Flashcomm2 flowchart
<img width="1142" height="878" alt="PixPin_2025-11-14_13-37-39"
src="https://github.com/user-attachments/assets/d45ea8db-d8ef-4d45-8e18-abd4d82ce3e0"
/>

### Does this PR introduce _any_ user-facing change?

Yes. The feature is opt-in via two environment variables:
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```
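
Here `VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE` sets the tensor-parallel size
used for o_proj under FlashComm2 (1 here, since the shared path keeps
full per-layer weights), and `VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1`
switches o_proj to the shared linear described above; this reading is
inferred from the config fields in the diff below.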


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: zzh02232027 <zzh02232027@antgroup.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Author: zzhxxx
Date: 2025-12-11 12:43:04 +08:00
Committed by: GitHub
Parent: bb76f7962c
Commit: eac72f5f23
8 changed files with 86 additions and 25 deletions


```diff
@@ -165,9 +165,8 @@ class AscendConfig:
                 "Only support P node tp size larger than D node tp size")
         self.SLO_limits_for_dynamic_batch = additional_config.get(
             "SLO_limits_for_dynamic_batch", -1)
-        from vllm_ascend.utils import \
-            get_flashcomm2_oproj_tp_size_and_validate_config
-        self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
+        from vllm_ascend.utils import get_flashcomm2_config_and_validate
+        self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
             self, vllm_config)
         self.enable_npugraph_ex = additional_config.get(
             "enable_npugraph_ex", False)
```