[Feat] flashcomm_v2 optim solution (#3232)

### What this PR does / why we need it?
Supports generalized FlashComm2 optimization, which reduces
communication overhead, decreases RmsNorm computation, and saves one
AllGather step by replacing Allreduce operations in the Attention module
with pre-AlltoAll and post-AllGather operations (used in combination
with FlashComm1). This feature is enabled during the Prefill phase and
is recommended to be used together with FlashComm1, delivering broad
performance improvements, especially in long sequence scenarios with
large tensor parallelism (TP) configurations. Benchmark tests show that
under TP16DP1 configuration, it can improve the prefill performance of
the DeepSeek model by 8% on top of FlashComm1.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: zzhxx <2783294813@qq.com>
This commit is contained in:
Levi
2025-11-10 11:01:45 +08:00
committed by GitHub
parent b1a00e0512
commit 0a62e671fb
12 changed files with 380 additions and 24 deletions

View File

@@ -2,12 +2,14 @@ from typing import Optional
import torch
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
get_tp_group, get_world_group,
init_model_parallel_group)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import prefill_context_parallel_enable
from vllm_ascend.utils import (flashcomm2_enable,
prefill_context_parallel_enable)
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -15,6 +17,8 @@ _MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_P_TP: Optional[GroupCoordinator] = None
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
def get_mc2_group() -> GroupCoordinator:
@@ -34,6 +38,16 @@ def get_lmhead_tp_group() -> GroupCoordinator:
return _LMTP
def get_flashcomm2_otp_group() -> GroupCoordinator:
return _FLASHCOMM2_OTP
def get_flashcomm2_odp_group() -> GroupCoordinator:
assert _FLASHCOMM2_ODP is not None, (
"output data parallel group for flashcomm2 is not initialized")
return _FLASHCOMM2_ODP
def get_mlp_tp_group() -> GroupCoordinator:
assert _MLP_TP is not None, ("mlp group is not initialized")
return _MLP_TP
@@ -165,6 +179,48 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="lmheadtp")
# TODO: Extract and unify the logic across different communication group.
if flashcomm2_enable():
flashcomm2_otp_size = get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size
global_tp_size = get_tp_group().world_size
global_dp_size = get_dp_group().world_size
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
flashcomm2_otp_size)
global _FLASHCOMM2_OTP
global _FLASHCOMM2_ODP
_FLASHCOMM2_OTP = None
_FLASHCOMM2_ODP = get_tp_group()
if flashcomm2_otp_size > 1:
otp_group_ranks = []
odp_group_ranks: list[list[int]] = [
[] for _ in range(flashcomm2_otp_size * global_dp_size)
]
for dp_group_index in range(global_dp_size):
for i in range(num_fc2_oproj_tensor_parallel_groups):
ranks = []
for j in range(flashcomm2_otp_size):
rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups
ranks.append(rank_idx)
odp_group_index = dp_group_index * flashcomm2_otp_size + j
odp_group_ranks[odp_group_index].append(rank_idx)
otp_group_ranks.append(ranks)
_FLASHCOMM2_OTP = init_model_parallel_group(
otp_group_ranks,
get_world_group().local_rank,
backend,
group_name="flashcomm2_otp")
_FLASHCOMM2_ODP = init_model_parallel_group(
odp_group_ranks,
get_world_group().local_rank,
backend,
group_name="flashcomm2_odp")
def get_mlp_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
@@ -201,3 +257,15 @@ def destroy_ascend_model_parallel():
if _P_TP:
_P_TP.destroy()
_P_TP = None
global _FLASHCOMM2_OTP
if _FLASHCOMM2_OTP and get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size != 1:
_FLASHCOMM2_OTP.destroy()
_FLASHCOMM2_OTP = None
global _FLASHCOMM2_ODP
if _FLASHCOMM2_ODP and get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size != 1:
_FLASHCOMM2_ODP.destroy()
_FLASHCOMM2_ODP = None