[Feat] flashcomm_v2 optim solution (#3232)

### What this PR does / why we need it?
Supports the generalized FlashComm2 optimization, which replaces the
AllReduce operations in the Attention module with a pre-AlltoAll and a
post-AllGather. This reduces communication overhead, decreases RMSNorm
computation, and, in combination with FlashComm1, saves one AllGather
step. The feature takes effect during the prefill phase and is
recommended to be used together with FlashComm1; it delivers broad
performance improvements, especially in long-sequence scenarios with
large tensor-parallel (TP) configurations. Benchmarks show that under a
TP16/DP1 configuration it improves DeepSeek prefill performance by 8%
on top of FlashComm1.
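
As a rough illustration of the collective decomposition described above, the sketch below rewrites an AllReduce of partial sums as AlltoAll + local reduction + AllGather using plain `torch.distributed` (a minimal sketch only; the function name and shapes are illustrative, and the actual implementation additionally reshards o_proj to `flashcomm2_oproj_tensor_parallel_size` ranks, which is where the communication savings come from):

```python
import torch
import torch.distributed as dist


def allreduce_via_alltoall_allgather(x: torch.Tensor) -> torch.Tensor:
    # Sketch only: assumes dist.init_process_group() has run (e.g. under
    # torchrun) and x has shape [world_size * chunk, hidden], where every
    # rank holds a different partial-sum term of the same logical tensor.
    world = dist.get_world_size()
    # pre-AlltoAll: rank r receives the r-th chunk of every rank's partials.
    recv = torch.empty_like(x)
    dist.all_to_all_single(recv, x)
    # Local reduction over source ranks (a reduce-scatter, in effect).
    reduced = recv.view(world, -1, x.shape[-1]).sum(dim=0)
    # post-AllGather: reassemble the fully reduced tensor on every rank.
    out = torch.empty_like(x)
    dist.all_gather_into_tensor(out, reduced)
    return out  # equals the AllReduce result, up to floating-point ordering
```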
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main: 83f478bb19

---------

Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: zzhxx <2783294813@qq.com>
Author: Levi
Date: 2025-11-10 11:01:45 +08:00 (committed by GitHub)
Parent: b1a00e0512
Commit: 0a62e671fb
12 changed files with 380 additions and 24 deletions


@@ -814,3 +814,68 @@ def has_layer_idx(model_instance: torch.nn.Module) -> bool:
    _HAS_LAYER_IDX = hasattr(model_instance, "model") and \
        hasattr(model_instance.model, "start_layer")
    return _HAS_LAYER_IDX


def flashcomm2_enable() -> bool:
    return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0


def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
                                                     vllm_config):
    flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
    global_tp_size = vllm_config.parallel_config.tensor_parallel_size

    if not flashcomm2_enable():
        logger.info("FLASHCOMM2 not enabled.")
        return flashcomm2_oproj_tp_size

    logger.info(
        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}"
    )
    if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
        logger.warning_once(
            "It is recommended to enable FLASHCOMM1 together with FLASHCOMM2 for optimal performance."
        )
    if ascend_config.oproj_tensor_parallel_size is not None:
        raise AssertionError(
            "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
        )
    if global_tp_size <= flashcomm2_oproj_tp_size:
        raise AssertionError(
            f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) must be smaller than the global tensor parallel size ({global_tp_size})"
        )
    if global_tp_size % flashcomm2_oproj_tp_size != 0:
        raise AssertionError(
            f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})"
        )
    if vllm_config.kv_transfer_config is None:
        logger.warning_once(
            "FLASHCOMM2 is recommended for P-scenario deployments; enabling it in a hybrid deployment may degrade decode performance."
        )
    if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
        raise AssertionError(
            "FLASHCOMM2 primarily targets P-scenario deployments, "
            "with additional support for hybrid deployment scenarios. "
            "It is not applicable in D-scenario environments.")

    return flashcomm2_oproj_tp_size
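
# Illustrative note (not part of the diff): with global_tp_size=16, the checks
# above accept flashcomm2_oproj_tensor_parallel_size values in {1, 2, 4, 8}
# (proper divisors of 16); 16 itself fails the strictly-smaller check, and 3
# fails the divisibility check, each raising AssertionError.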


def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
    # Reorganize batch_ids so that, after the all2all and reduce-scatter
    # operations, each batch_id corresponds to the rank_id within the DP
    # domain. For example, when DP = [0, 1, 2, ..., 15] and
    # flashcomm2_oproj_tensor_parallel_size = 2, the reorganized batch_ids
    # will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]].
    flashcomm2_otp_size = get_ascend_config(
    ).flashcomm2_oproj_tensor_parallel_size
    num_oproj_tensor_parallel_groups: int = (global_tp_size //
                                             flashcomm2_otp_size)
    reorgnized_batch_ids = []
    for i in range(num_oproj_tensor_parallel_groups):
        ranks = []
        for j in range(flashcomm2_otp_size):
            rank_idx = i + j * num_oproj_tensor_parallel_groups
            ranks.append(rank_idx)
        reorgnized_batch_ids.append(ranks)
    return reorgnized_batch_ids
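
For reference, a standalone sketch of the same grouping logic (the `get_ascend_config()` dependency is inlined as a plain parameter; the names here are illustrative, not part of the diff):

```python
def reorganize_batch_ids(global_tp_size: int, otp_size: int) -> list[list[int]]:
    # Group i holds ranks {i, i + num_groups, i + 2 * num_groups, ...},
    # mirroring get_flashcomm2_reorgnized_batch_ids above.
    num_groups = global_tp_size // otp_size
    return [[i + j * num_groups for j in range(otp_size)]
            for i in range(num_groups)]


# With global_tp_size=16 and flashcomm2_oproj_tensor_parallel_size=2:
print(reorganize_batch_ids(16, 2))
# [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]
```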