[Feat] Flashcomm2 use o_shared linear (#4188)

### What this PR does / why we need it?

It is mentioned in the [flashcomm2 technical
report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
that FC2 will introduce full redundant storage of the o_proj matrix,
which will put pressure on the memory. Therefore, the technical report
proposed a compromise solution using otp2, but it will introduce
additional reduce-scatter communication.

We propose a shared linear feature (#2931 ) that supports distributing
weights layer by layer to each card, avoiding the need for TP splitting,
and can solve the memory issue.

This PR depends on #3232 and #2931

### Flashcomm2 flowchart
<img width="1142" height="878" alt="PixPin_2025-11-14_13-37-39"
src="https://github.com/user-attachments/assets/d45ea8db-d8ef-4d45-8e18-abd4d82ce3e0"
/>

### Does this PR introduce _any_ user-facing change?

Use environment variables
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: zzh02232027 <zzh02232027@antgroup.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
zzhxxx
2025-12-11 12:43:04 +08:00
committed by GitHub
parent bb76f7962c
commit eac72f5f23
8 changed files with 86 additions and 25 deletions

View File

@@ -9,7 +9,8 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import enable_sp, flashcomm2_enable
from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
flashcomm2_o_shared_enabled)
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -77,6 +78,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
backend = torch.distributed.get_backend(get_world_group().device_group)
vllm_config = get_current_vllm_config()
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
@@ -182,6 +184,29 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="lmheadtp")
def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
#This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference
group_ranks = []
for pp_idx in range(global_pp_size):
group = []
for dp_idx in range(global_dp_size):
base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
for i in range(global_tp_size):
global_rank = base + i
group.append(global_rank)
group_ranks.append(group)
return init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name=group_name)
global _SHARED_WEIGHT
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
if enable_sp() and is_ds_v32:
_SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
# TODO: Extract and unify the logic across different communication group.
if flashcomm2_enable():
flashcomm2_otp_size = get_ascend_config(
@@ -234,17 +259,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="flashcomm2_odp")
vllm_config = get_current_vllm_config()
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
if enable_sp() and is_ds_v32:
global _SHARED_WEIGHT
group_ranks = [list(range(torch.distributed.get_world_size()))]
_SHARED_WEIGHT = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
group_name="CP_shared_weight")
# Create shared weight group for flashcomm2 oproj
if flashcomm2_o_shared_enabled():
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
_SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
def get_mlp_tensor_model_parallel_world_size():