[Feat] flashcomm_v2 optim solution (#3232)
### What this PR does / why we need it?
Supports generalized FlashComm2 optimization, which reduces
communication overhead, decreases RmsNorm computation, and saves one
AllGather step by replacing Allreduce operations in the Attention module
with pre-AlltoAll and post-AllGather operations (used in combination
with FlashComm1). This feature is enabled during the Prefill phase and
is recommended to be used together with FlashComm1, delivering broad
performance improvements, especially in long sequence scenarios with
large tensor parallelism (TP) configurations. Benchmark tests show that
under TP16DP1 configuration, it can improve the prefill performance of
the DeepSeek model by 8% on top of FlashComm1.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: zzhxx <2783294813@qq.com>
This commit is contained in:
@@ -11,7 +11,8 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
||||
set_forward_context)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.utils import enable_sp, has_layer_idx, is_moe_model
|
||||
from vllm_ascend.utils import (enable_sp, flashcomm2_enable, has_layer_idx,
|
||||
is_moe_model)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||||
@@ -121,13 +122,17 @@ def set_ascend_forward_context(
|
||||
sp_enabled = enable_sp(vllm_config) and \
|
||||
num_tokens is not None and num_tokens > 1000
|
||||
forward_context.mmrs_fusion = mmrs_fusion
|
||||
forward_context.num_tokens = num_tokens
|
||||
forward_context.sp_enabled = sp_enabled
|
||||
#TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
|
||||
forward_context.flashcomm_v2_enabled = flashcomm2_enable(
|
||||
) and tp_world_size > 1 and num_tokens is not None
|
||||
|
||||
if sp_enabled:
|
||||
if (forward_context.sp_enabled
|
||||
or forward_context.flashcomm_v2_enabled):
|
||||
pad_size = (tp_world_size -
|
||||
(num_tokens % tp_world_size)) % tp_world_size
|
||||
forward_context.pad_size = pad_size
|
||||
forward_context.sp_enabled = sp_enabled
|
||||
forward_context.num_tokens = num_tokens
|
||||
|
||||
# set this for rope forward_oot using
|
||||
forward_context.is_first_layer = True
|
||||
@@ -179,7 +184,8 @@ def set_ascend_forward_context(
|
||||
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
||||
max_tokens_across_dp = \
|
||||
forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
|
||||
if sp_enabled:
|
||||
if (forward_context.sp_enabled
|
||||
or forward_context.flashcomm_v2_enabled):
|
||||
padded_length = (max_tokens_across_dp + tp_world_size -
|
||||
1) // tp_world_size * tp_world_size
|
||||
pad_size = padded_length - num_tokens
|
||||
|
||||
Reference in New Issue
Block a user