[Feat] flashcomm_v2 optim solution (#3232)

### What this PR does / why we need it?
Supports generalized FlashComm2 optimization, which reduces
communication overhead, decreases RmsNorm computation, and saves one
AllGather step by replacing Allreduce operations in the Attention module
with pre-AlltoAll and post-AllGather operations (used in combination
with FlashComm1). This feature is enabled during the Prefill phase and
is recommended to be used together with FlashComm1, delivering broad
performance improvements, especially in long sequence scenarios with
large tensor parallelism (TP) configurations. Benchmark tests show that
under TP16DP1 configuration, it can improve the prefill performance of
the DeepSeek model by 8% on top of FlashComm1.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: zzhxx <2783294813@qq.com>
This commit is contained in:
Levi
2025-11-10 11:01:45 +08:00
committed by GitHub
parent b1a00e0512
commit 0a62e671fb
12 changed files with 380 additions and 24 deletions

View File

@@ -132,6 +132,12 @@ env_variables: Dict[str, Callable[[], Any]] = {
# This feature will get better performance when concurrency is large.
"VLLM_ASCEND_ENABLE_FLASHCOMM1":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
# Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it.
# The specific value set will be used as the O-matrix TP group size for flashcomm2.
# For a detailed introduction to the parameters and the differences and applicable scenarios
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
@@ -185,4 +191,4 @@ def __getattr__(name: str):
def __dir__():
return list(env_variables.keys())
return list(env_variables.keys())