[Feat] Flash comm allgher ep (#3334)

Support flash comm v1(Sequence Parallelism) for Allgather EP.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
This commit is contained in:
realliujiaxu
2025-10-15 19:36:32 +08:00
committed by GitHub
parent 8abe517870
commit f69a83b7ba
15 changed files with 283 additions and 78 deletions

View File

@@ -166,7 +166,7 @@ def test_sp_for_qwen3_moe() -> None:
@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS) @pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"}) @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"}) @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager): def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
example_prompts = [ example_prompts = [
"Hello, my name is", "Hello, my name is",

View File

@@ -500,9 +500,12 @@ class TestAscendMLAImpl(TestBase):
mock_up_proj.assert_called_once() mock_up_proj.assert_called_once()
mock_npu_fused_infer_attention_score.assert_called_once() mock_npu_fused_infer_attention_score.assert_called_once()
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch") @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess(self, magic_npu_fetch): def test_mla_preprocess(self, magic_npu_fetch,
mock_maybe_all_gather_and_maybe_unpad):
magic_npu_fetch.return_value = MagicMock() magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
batch_size = 4 batch_size = 4
seq_len = 8 seq_len = 8
hidden_size = 1024 hidden_size = 1024

View File

@@ -42,9 +42,11 @@ def test_row_parallel_linear(cls, mock_distributed):
assert output[0].shape == (2, 4, 64) assert output[0].shape == (2, 4, 64)
@patch("vllm_ascend.models.layers.mla.get_forward_context")
@patch("torch.ops.vllm.mla_forward") @patch("torch.ops.vllm.mla_forward")
@patch("torch_npu.npu_rms_norm") @patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward, def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
mock_forward_context,
mock_distributed, base_config): mock_distributed, base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
# Make a fake ascend config because of the AscendLinearBase # Make a fake ascend config because of the AscendLinearBase
@@ -54,6 +56,9 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
vllm_config.parallel_config.tensor_parallel_size = 1 vllm_config.parallel_config.tensor_parallel_size = 1
vllm_config.kv_transfer_config = None vllm_config.kv_transfer_config = None
ascend_config.init_ascend_config(vllm_config) ascend_config.init_ascend_config(vllm_config)
dummy_forward_context = MagicMock()
dummy_forward_context.sp_enabled = False
mock_forward_context.return_value = dummy_forward_context
attn = CustomDeepseekV2MLAAttention(config=base_config, attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128, hidden_size=128,

View File

@@ -11,7 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
set_forward_context) set_forward_context)
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp from vllm_ascend.utils import enable_sp, is_moe_model
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
@@ -112,6 +112,10 @@ def set_ascend_forward_context(
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
# the performance may degrade due to the switching of communication methods. # the performance may degrade due to the switching of communication methods.
if is_moe_model(vllm_config):
sp_enabled = enable_sp(vllm_config) and \
tp_world_size > 1
else:
sp_enabled = enable_sp(vllm_config) and \ sp_enabled = enable_sp(vllm_config) and \
tp_world_size > 1 and \ tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000 num_tokens is not None and num_tokens > 1000
@@ -121,6 +125,7 @@ def set_ascend_forward_context(
(num_tokens % tp_world_size)) % tp_world_size (num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size forward_context.pad_size = pad_size
forward_context.sp_enabled = sp_enabled forward_context.sp_enabled = sp_enabled
forward_context.num_tokens = num_tokens
# set this for rope forward_oot using # set this for rope forward_oot using
forward_context.is_first_layer = True forward_context.is_first_layer = True
@@ -169,8 +174,14 @@ def set_ascend_forward_context(
dp_world_size = get_dp_group().world_size dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None: if dp_world_size > 1 and forward_context.dp_metadata is not None:
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item( max_tokens_across_dp = \
) forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
if sp_enabled:
padded_length = (max_tokens_across_dp + tp_world_size -
1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
forward_context.padded_length = padded_length
forward_context.pad_size = pad_size
else: else:
max_tokens_across_dp = num_tokens max_tokens_across_dp = num_tokens

View File

@@ -9,7 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionMetadata,
MLAAttentionImpl) MLAAttentionImpl)
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
@@ -1128,10 +1128,11 @@ class AscendMLAImpl(MLAAttentionImpl):
q_c = hidden_states q_c = hidden_states
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
# Process for shared_expert_dp # Process for Flash Comm V1
if need_gather_q_kv: q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c = get_tp_group().all_gather(q_c, 0) q_c, need_gather_q_kv)
kv_no_split = get_tp_group().all_gather(kv_no_split, 0) kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split, need_gather_q_kv)
decode_preprocess_res = None decode_preprocess_res = None
prefill_preprocess_res = None prefill_preprocess_res = None
if has_prefill: if has_prefill:
@@ -1200,8 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
# Inputs and outputs may be padded for CUDA graphs # Inputs and outputs may be padded for CUDA graphs
output_padded = output output_padded = output
output = output[:num_actual_tokens, ...] o_proj_input_shape = (get_forward_context().num_tokens,
o_proj_input_shape = (num_actual_tokens,
self.num_heads * self.v_head_dim) self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape, o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
@@ -1248,7 +1248,8 @@ class AscendMLAImpl(MLAAttentionImpl):
o_proj_input[num_decode_tokens:] = output_prefill o_proj_input[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record() current_ms_metadata.after_comm_event.record()
else: else:
o_proj_input[num_decode_tokens:] = output_prefill o_proj_input[
num_decode_tokens:num_actual_tokens] = output_prefill
# O proj # O proj
current_ms_metadata = get_multistream_comm_context() current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
@@ -1258,20 +1259,14 @@ class AscendMLAImpl(MLAAttentionImpl):
max_size=MAX_O_PROJ_PREFETCH_SIZE, max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch) enabled=self.enable_prefetch)
output[...] = self.o_proj( output[...] = self.o_proj(o_proj_input)[0]
o_proj_input,
is_prefill=prefill_preprocess_res is not None,
is_force_scatter=self.enable_shared_expert_dp)[0]
else: else:
with torch.npu.stream(current_ms_metadata.comm_stream): with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(inputs=self.o_proj.weight, maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input, dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE, max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch) enabled=self.enable_prefetch)
output[...] = self.o_proj( output[...] = self.o_proj(o_proj_input)[0]
o_proj_input,
is_prefill=prefill_preprocess_res is not None,
is_force_scatter=self.enable_shared_expert_dp)[0]
current_ms_metadata.after_comm_event.record() current_ms_metadata.after_comm_event.record()
del o_proj_input del o_proj_input

View File

@@ -133,8 +133,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))), lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
# Whether to enable FlashComm optimization when tensor parallel is enabled. # Whether to enable FlashComm optimization when tensor parallel is enabled.
# This feature will get better performance when concurrency is large. # This feature will get better performance when concurrency is large.
"VLLM_ASCEND_ENABLE_FLASHCOMM": "VLLM_ASCEND_ENABLE_FLASHCOMM1":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))), lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
# Whether to enable MLP weight prefetch, only used in small concurrency. # Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))), lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),

View File

@@ -300,8 +300,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj") prefix=f"{prefix}.kv_b_proj")
self.o_proj = CustomDeepseekV2RowParallelLinear( self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.num_heads * self.v_head_dim,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,

View File

@@ -122,19 +122,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None, kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
num_tokens = hidden_states.shape[0] need_gather_q_kv = get_forward_context().sp_enabled
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape output_shape = hidden_states.shape
else:
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
# FIXME: This does not seem right, should make sure the buffer is fixed # FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape, output = torch.empty(output_shape,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,

View File

@@ -38,8 +38,9 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
npu_stream_switch) is_enable_nz, npu_stream_switch,
shared_expert_dp_enabled)
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -417,6 +418,10 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
if self.multistream_overlap_shared_expert: if self.multistream_overlap_shared_expert:
self.shared_expert_stream = torch.npu.Stream() self.shared_expert_stream = torch.npu.Stream()
if enable_sp():
logger.info_once(
"Sequence parallelism is enabled, shared experts are replicated for best performance."
)
def forward( def forward(
self, self,
@@ -444,7 +449,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context() forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out) shared_out = tensor_model_parallel_all_reduce(shared_out)
fused_output = AscendFusedMoE.forward_impl( fused_output = AscendFusedMoE.forward_impl(
self, self,

View File

@@ -49,7 +49,7 @@ from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group) get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, enable_sp, from vllm_ascend.utils import (dense_optim_enable, enable_sp,
matmul_allreduce_enable, mlp_tp_enable, matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable) oproj_tp_enable, shared_expert_dp_enabled)
class CustomLinearOp: class CustomLinearOp:
@@ -418,7 +418,8 @@ def _get_row_parallel_op(
def get_parallel_op(disable_tp, prefix, layer, direct): def get_parallel_op(disable_tp, prefix, layer, direct):
if disable_tp: if disable_tp or ("shared_experts" in prefix
and shared_expert_dp_enabled()):
return None, 0, 1 return None, 0, 1
custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp, custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
MLPRowParallelOp, OProjRowParallelOp, MLPRowParallelOp, OProjRowParallelOp,

View File

@@ -27,7 +27,7 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import get_rm_router_logits_state from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPrepareAndFinalize(ABC):
@@ -277,9 +277,24 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
""" """
MoE communication strategy using All-Gather + Reduce-Scatter. MoE communication strategy using All-Gather + Reduce-Scatter on EP group.
Designed for DP > 1: gather inputs across DP ranks before MoE, scatter outputs after. There are two sets of prepare and finalize:
Uses `max_tokens_across_dp` from forward_context for padding alignment. 1. _prepare_with_dp_group/_finalize_with_dp_group: When sequence parallelism is not enabled,
we gather inputs across DP ranks before MoE, scatter outputs after.
The communication and calculation process is as follows (AG, AR and RS
are abbreviations for All-Gather, All-Reduce and Reduce-Scatter, respectively):
Attn → TP AR → DP AG → MoE → DP RS → TP AR
2. _prepare_with_ep_group/_finalize_with_ep_group: When sequence parallelism is enabled,
the above process becomes:
TP AG → Attn → TP RS → TP AG → DP AG → MoE → DP RS → TP RS
This strategy further combines TP AG + DP AG into EP All-Gather and TP RS + DP RS
into EP Reduce-Scatter to improve communication performance. The optimized process is as follows:
TP AG → Attn → TP RS → EP AG → MoE → EP RS
""" """
def prepare( def prepare(
@@ -289,6 +304,42 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
gate=None gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
AllGather hidden_states and router_logits to form global tensors.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
"""
if enable_sp():
return self._prepare_with_ep_group(hidden_states, router_logits)
return self._prepare_with_dp_group(hidden_states, router_logits,
enable_shared_expert_dp,
replace_allreduce, gate)
def _prepare_with_ep_group(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
router_logits, True, True)
return hidden_states, router_logits, None, None
def _prepare_with_dp_group(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]: Optional[torch.Tensor]]:
""" """
@@ -301,7 +352,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
Tuple of (global_hidden_states, global_router_logits, None, None) Tuple of (global_hidden_states, global_router_logits, None, None)
""" """
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1: if self.moe_config.dp_size > 1:
forward_context = get_forward_context() forward_context = get_forward_context()
max_tokens_across_dp = forward_context.max_tokens_across_dp max_tokens_across_dp = forward_context.max_tokens_across_dp
@@ -323,7 +373,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
else: else:
router_logits = self.moe_config.dp_group.all_gather( router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0) router_logits, 0)
return hidden_states, router_logits, None, None return hidden_states, router_logits, None, None
def finalize(self, def finalize(self,
@@ -331,6 +380,36 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
reduce_results: bool, reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor: context_metadata: Optional[dict] = None) -> torch.Tensor:
""" """
Finalization steps:
Reduce Scatter hidden states.
Returns:
Tensor with shape [local_num_tokens, hidden_size]
"""
if enable_sp():
return self._finalize_with_ep_group(hidden_states)
return self._finalize_with_dp_group(hidden_states, reduce_results)
def _finalize_with_ep_group(self,
hidden_states: torch.Tensor) -> torch.Tensor:
"""
Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled:
1. Reduce_results is False usually happens when models have shared experts and need to
allreduce hidden states after results of shared experts and routed experts are added in FusedMoe.
We do reduce scatter for hidden states here, then skip allreudce in FusedMoe and add it to the
result of shared experts.
2 Reduce_results is True usually happens when model has no shared experts. We still do reduce scatter
here, then skip allreudce in FusedMoe.
"""
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
hidden_states, True)
return hidden_states
def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps: Finalization steps:
1. If DP > 1 and not shared expert, reduce-scatter output across DP group. 1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
2. Slice to original local token count. 2. Slice to original local token count.

View File

@@ -1,7 +1,9 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch_npu import torch_npu
from vllm.distributed import (tensor_model_parallel_all_gather, from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter) tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
@@ -13,8 +15,10 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream from vllm_ascend.utils import npu_stream_switch, prefetch_stream
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, def _maybe_all_gather_and_maybe_unpad_impl(
label: bool) -> torch.Tensor: x: torch.Tensor,
label: bool,
is_ep_comm: bool = False) -> torch.Tensor:
try: try:
forward_context = get_forward_context() forward_context = get_forward_context()
except AssertionError: except AssertionError:
@@ -22,27 +26,66 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
sp_enabled = forward_context.sp_enabled sp_enabled = forward_context.sp_enabled
if sp_enabled and label: if sp_enabled and label:
dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm:
x = tensor_model_parallel_all_gather(x, 0) x = tensor_model_parallel_all_gather(x, 0)
pad_size = forward_context.pad_size pad_size = forward_context.pad_size
if pad_size > 0: if pad_size > 0:
x = x[:-pad_size, :] x = x[:-pad_size, :]
else:
x = get_ep_group().all_gather(x, 0)
# unpad
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
result = torch.empty(
(num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
device=x.device,
dtype=x.dtype)
dp_size = get_dp_group().world_size
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
offset = 0
for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx]
result[offset:offset +
num_tokens_dp, :] = x[idx, :num_tokens_dp, :]
offset += num_tokens_dp
x = result
return x return x
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor: def _maybe_pad_and_reduce_impl(x: torch.Tensor,
is_ep_comm: bool = False) -> torch.Tensor:
try: try:
forward_context = get_forward_context() forward_context = get_forward_context()
except AssertionError: except AssertionError:
return tensor_model_parallel_all_reduce(x) return tensor_model_parallel_all_reduce(x)
sp_enabled = forward_context.sp_enabled if not forward_context.sp_enabled:
if sp_enabled: return tensor_model_parallel_all_reduce(x)
dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm:
pad_size = forward_context.pad_size pad_size = forward_context.pad_size
if pad_size > 0: if pad_size > 0:
x = F.pad(x, (0, 0, 0, pad_size)) x = F.pad(x, (0, 0, 0, pad_size))
return tensor_model_parallel_reduce_scatter(x, 0) return tensor_model_parallel_reduce_scatter(x, 0)
else: else:
return tensor_model_parallel_all_reduce(x) # padding
dp_size = get_dp_group().world_size
num_tokens_across_dp_cpu = \
get_forward_context().dp_metadata.num_tokens_across_dp_cpu
padded_x = torch.empty(
(dp_size, forward_context.padded_length, *x.shape[1:]),
device=x.device,
dtype=x.dtype)
offset = 0
for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx]
padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp]
offset += num_tokens_dp
return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]),
0)
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor, def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
@@ -71,6 +114,33 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
return return
def _maybe_all_gather_and_maybe_unpad_fake(
x: torch.Tensor,
label: bool,
is_ep_comm: bool = False) -> torch.Tensor:
if get_forward_context().sp_enabled and label:
return torch.empty(
(x.shape[0] * get_tensor_model_parallel_world_size(),
*x.shape[1:]),
device=x.device,
dtype=x.dtype)
return x
def _maybe_pad_and_reduce_fake(x: torch.Tensor,
is_ep_comm: bool = False) -> torch.Tensor:
if get_forward_context().sp_enabled:
return torch.empty(
(x.shape[0] // get_tensor_model_parallel_world_size(),
*x.shape[1:]),
device=x.device,
dtype=x.dtype)
return x
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor, def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
prefix: str) -> None: prefix: str) -> None:
return return
@@ -158,7 +228,8 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
final_hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states: torch.Tensor) -> torch.Tensor:
forward_context = get_forward_context() forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2
} or forward_context.sp_enabled:
return final_hidden_states return final_hidden_states
else: else:
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
@@ -166,13 +237,13 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad", direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
op_func=_maybe_all_gather_and_maybe_unpad_impl, op_func=_maybe_all_gather_and_maybe_unpad_impl,
fake_impl=lambda x, label: x, fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
mutates_args=[], mutates_args=[],
dispatch_key="PrivateUse1") dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_pad_and_reduce", direct_register_custom_op(op_name="maybe_pad_and_reduce",
op_func=_maybe_pad_and_reduce_impl, op_func=_maybe_pad_and_reduce_impl,
fake_impl=lambda x: x, fake_impl=_maybe_pad_and_reduce_fake,
mutates_args=[], mutates_args=[],
dispatch_key="PrivateUse1") dispatch_key="PrivateUse1")

View File

@@ -31,7 +31,7 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
init_ascend_config) init_ascend_config)
from vllm_ascend.torchair.utils import (check_torchair_cache_exist, from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
delete_torchair_cache_file) delete_torchair_cache_file)
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p, from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p,
update_aclgraph_sizes) update_aclgraph_sizes)
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -211,6 +211,21 @@ class NPUPlatform(Platform):
# set cudaprah sizes before extending `compilation_config.splitting_ops` # set cudaprah sizes before extending `compilation_config.splitting_ops`
vllm_config._set_cudagraph_sizes() vllm_config._set_cudagraph_sizes()
# TODO delete graph size update here when compilation_config.pass_config.enable_sequence_parallelism
# is supported by vllm-ascend.
if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \
enable_sp(vllm_config):
original_sizes = compilation_config.cudagraph_capture_sizes
sp_aclgraph_sizes = \
vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
assert sp_aclgraph_sizes, (
f"cudagraph_capture_sizes {original_sizes} does not contain"
f"values that are multiples of tp_size "
f"{vllm_config.parallel_config.tensor_parallel_size}")
if len(sp_aclgraph_sizes) != len(original_sizes):
compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes
vllm_config.compilation_config.init_with_cudagraph_sizes(
sp_aclgraph_sizes)
# TODO: Full graph is fully supported later, and the default value will be set to full graph. # TODO: Full graph is fully supported later, and the default value will be set to full graph.
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:

View File

@@ -55,6 +55,7 @@ _PREFETCH_STREAM = None
_ASCEND_CUSTOMOP_IS_REIGISTERED = False _ASCEND_CUSTOMOP_IS_REIGISTERED = False
_DEFAULT_BUFFER_SIZE = 200 _DEFAULT_BUFFER_SIZE = 200
_MIN_DP_BUFFER_SIZE = 50 _MIN_DP_BUFFER_SIZE = 50
_IS_MOE_MODEL = None
def is_310p(): def is_310p():
@@ -609,12 +610,24 @@ def enable_sp(vllm_config=None) -> bool:
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
return ( return (
vllm_config.compilation_config.pass_config.enable_sequence_parallelism vllm_config.compilation_config.pass_config.enable_sequence_parallelism
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM) or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
# TODO remove it after vllm has this func
def shared_expert_dp_enabled() -> bool:
return get_ascend_config().enable_shared_expert_dp or enable_sp()
def is_moe_model(vllm_config: VllmConfig): def is_moe_model(vllm_config: VllmConfig):
global _IS_MOE_MODEL
if _IS_MOE_MODEL is None:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
return any('experts' in key.lower() for key in config.to_dict()) _IS_MOE_MODEL = any('experts' in key.lower()
for key in config.to_dict())
return _IS_MOE_MODEL
def weak_ref_tensor(tensor: Any) -> Any: def weak_ref_tensor(tensor: Any) -> Any:

View File

@@ -20,6 +20,7 @@
import copy import copy
import gc import gc
import itertools import itertools
import math
import re import re
import time import time
from collections import defaultdict from collections import defaultdict
@@ -128,8 +129,8 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration, AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, is_310p, is_enable_nz, enable_sp, get_ascend_soc_version, is_310p,
lmhead_tp_enable) is_enable_nz, lmhead_tp_enable)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -1210,6 +1211,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Add padding to the batch size. # Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph( num_input_tokens = self.vllm_config.pad_for_cudagraph(
total_num_scheduled_tokens) total_num_scheduled_tokens)
elif self.use_aclgraph and enable_sp(self.vllm_config):
# When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size,
# the model will fall back to running its FX graph in eager mode.
# In this case, when sequence parallelism is enabled, we need to pad tokens to align
# with tp_size because pad_size cannot be captured by the FX graph
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
num_input_tokens = math.ceil(
total_num_scheduled_tokens / tp_size) * tp_size
else: else:
# Eager mode. # Eager mode.
num_input_tokens = total_num_scheduled_tokens num_input_tokens = total_num_scheduled_tokens
@@ -1850,6 +1859,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
raise ValueError(f"Unsupported soc_version: {soc_version}") raise ValueError(f"Unsupported soc_version: {soc_version}")
if moe_comm_type == MoECommType.ALLGATHER and with_prefill: if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
if enable_sp():
moe_comm_type = MoECommType.ALLGATHER
else:
moe_comm_type = MoECommType.NAIVE_MULTICAST moe_comm_type = MoECommType.NAIVE_MULTICAST
# PanguProMoE only supports allgather # PanguProMoE only supports allgather
@@ -2314,6 +2326,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
} }
# In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
# If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size.
if self.use_aclgraph and enable_sp(self.vllm_config):
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
num_tokens = math.ceil(num_tokens / tp_size) * tp_size
# Padding for DP # Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill, (num_tokens, num_tokens_across_dp, with_prefill,
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False) _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)