[Feat] Flashcomm2 use o_shared linear (#4188)

### What this PR does / why we need it?

The [flashcomm2 technical
report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
notes that FC2 introduces fully redundant storage of the o_proj matrix,
which puts pressure on device memory. The report therefore proposes a
compromise based on otp2, but that in turn introduces additional
reduce-scatter communication.

We propose a shared linear feature (#2931) that distributes weights
layer by layer across the cards, avoiding the need for TP splitting and
resolving the memory issue.

This PR depends on #3232 and #2931.
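To illustrate the layer-by-layer distribution idea, here is a minimal sketch. The function names (`owner_rank`, `local_layers`) and the round-robin assignment are invented for illustration and are not taken from this PR; the point is only that each rank stores the full, unsplit o_proj weight for a subset of layers instead of every rank holding a redundant full copy.

```python
# Hypothetical sketch of layer-by-layer weight placement (names invented
# for illustration; the actual assignment in #2931 may differ).

def owner_rank(layer_idx: int, world_size: int) -> int:
    # Round-robin: layer i's full o_proj weight lives on rank i % world_size.
    return layer_idx % world_size

def local_layers(rank: int, num_layers: int, world_size: int) -> list[int]:
    # Layers whose full (unsplit) o_proj weight this rank stores locally.
    return [i for i in range(num_layers)
            if owner_rank(i, world_size) == rank]

# Example: 8 layers over 4 ranks -> each rank holds 2 full o_proj matrices,
# i.e. num_layers / world_size matrices per rank rather than num_layers
# redundant copies per rank.
print(local_layers(0, num_layers=8, world_size=4))  # [0, 4]
```

Under this kind of placement no TP split of o_proj is needed, which is how the memory pressure described above is avoided.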

### Flashcomm2 flowchart
<img width="1142" height="878" alt="PixPin_2025-11-14_13-37-39"
src="https://github.com/user-attachments/assets/d45ea8db-d8ef-4d45-8e18-abd4d82ce3e0"
/>

### Does this PR introduce _any_ user-facing change?

Yes. The feature is controlled by two environment variables:
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: zzh02232027 <zzh02232027@antgroup.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
- Author: zzhxxx
- Date: 2025-12-11 12:43:04 +08:00 (committed by GitHub)
- Parent: bb76f7962c
- Commit: eac72f5f23
- 8 changed files with 86 additions and 25 deletions


```diff
@@ -953,17 +953,22 @@ def flashcomm2_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
 
 
-def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
-                                                     vllm_config):
+def flashcomm2_o_shared_enabled() -> bool:
+    return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
+
+
+def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
     flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
     global_tp_size = vllm_config.parallel_config.tensor_parallel_size
+    flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()
 
     if not flashcomm2_enable():
-        logger.debug("FLASHCOMM2 not enable.")
-        return flashcomm2_oproj_tp_size
+        flashcomm2_oproj_shared = False
+        logger.info("FLASHCOMM2 not enable.")
+        return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
 
     logger.info(
-        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}"
+        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
     )
     if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
         logger.warning_once(
@@ -990,8 +995,10 @@ def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
             "FLASHCOMM2 primarily targets P-scenario deployments, "
             "with additional support for hybrid deployment scenarios. "
             "It is not applicable in D-scenario environments.")
 
+    if flashcomm2_oproj_shared:
+        logger.info("Enable FLASHCOMM2 with oproj_shared.")
+
-    return flashcomm2_oproj_tp_size
+    return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
 
 
 def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
```
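The gating logic in the hunk above can be summarized with a standalone sketch. This is not the real function: the actual code reads `envs_ascend` and `vllm_config` and performs further validation; here the two environment-driven values are reduced to plain arguments to show just the enable/disable interaction.

```python
# Standalone sketch of the config gating above (real code reads
# envs_ascend and vllm_config; arguments are simplified here).

def get_flashcomm2_config(parallel_size: int, oshared: bool) -> tuple[int, bool]:
    # FLASHCOMM2 is enabled when VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0.
    if parallel_size <= 0:
        # Feature disabled: the shared o_proj linear is forced off too.
        return parallel_size, False
    # Feature enabled: the o_shared flag is honored as-is.
    return parallel_size, oshared

print(get_flashcomm2_config(0, True))   # (0, False) -- o_shared ignored
print(get_flashcomm2_config(1, True))   # (1, True)
```

In other words, `VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED` only takes effect when FLASHCOMM2 itself is enabled.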