[Feat] Flashcomm2 use o_shared linear (#4188)
### What this PR does / why we need it?
It is mentioned in the [flashcomm2 technical
report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
that FC2 introduces fully redundant storage of the o_proj matrix,
which puts pressure on memory. The technical report therefore
proposes a compromise solution, otp2, but that solution introduces
an additional reduce-scatter communication.
We propose a shared linear feature (#2931) that distributes the
weights layer by layer across the cards. This avoids the need for TP
splitting and solves the memory issue.
This PR depends on #3232 and #2931.
### Flashcomm2 flowchart
<img width="1142" height="878" alt="PixPin_2025-11-14_13-37-39"
src="https://github.com/user-attachments/assets/d45ea8db-d8ef-4d45-8e18-abd4d82ce3e0"
/>
### Does this PR introduce _any_ user-facing change?
Yes. The feature is enabled with the following environment variables:
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: zzh02232027 <zzh02232027@antgroup.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
@@ -953,17 +953,22 @@ def flashcomm2_enable() -> bool:
|
||||
return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
|
||||
|
||||
|
||||
def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
|
||||
vllm_config):
|
||||
def flashcomm2_o_shared_enabled() -> bool:
|
||||
return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
|
||||
|
||||
|
||||
def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
|
||||
flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
|
||||
global_tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()
|
||||
|
||||
if not flashcomm2_enable():
|
||||
logger.debug("FLASHCOMM2 not enable.")
|
||||
return flashcomm2_oproj_tp_size
|
||||
flashcomm2_oproj_shared = False
|
||||
logger.info("FLASHCOMM2 not enable.")
|
||||
return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
|
||||
|
||||
logger.info(
|
||||
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}"
|
||||
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
|
||||
)
|
||||
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
|
||||
logger.warning_once(
|
||||
@@ -990,8 +995,10 @@ def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
|
||||
"FLASHCOMM2 primarily targets P-scenario deployments, "
|
||||
"with additional support for hybrid deployment scenarios. "
|
||||
"It is not applicable in D-scenario environments.")
|
||||
if flashcomm2_oproj_shared:
|
||||
logger.info("Enable FLASHCOMM2 with oproj_shared.")
|
||||
|
||||
return flashcomm2_oproj_tp_size
|
||||
return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
|
||||
|
||||
|
||||
def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
|
||||
|
||||
Reference in New Issue
Block a user