From ba28d54f359278c83ed0f55ba78de9b1996f2b12 Mon Sep 17 00:00:00 2001
From: AlvisGong
Date: Sun, 14 Dec 2025 09:34:13 +0800
Subject: [PATCH] [Perf] Enable prefill flashcommon3 (#4065)

### What this PR does / why we need it?
Overlap the MoE gate computation (expert routing and shared experts) with
the dispatch quantization/communication on separate streams to improve
prefill performance.

### How was this patch tested?
Tested by launching with the feature enabled:
--additional-config '{"multistream_overlap_gate": true}'

- vLLM version: v0.12.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: AlvisGong
Signed-off-by: chenxiao
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1
Co-authored-by: chenxiao
---
 tests/ut/ops/test_prepare_finalize.py         |   4 +
 vllm_ascend/ascend_config.py                  |   2 +
 vllm_ascend/distributed/parallel_state.py     |  22 +++-
 vllm_ascend/distributed/utils.py              |  33 +++++-
 vllm_ascend/flash_common3_context.py          |  42 +++++++
 vllm_ascend/ops/fused_moe/fused_moe.py        | 108 ++++++++++++++----
 vllm_ascend/ops/fused_moe/prepare_finalize.py |  34 +++++-
 vllm_ascend/quantization/w8a8_dynamic.py      |  34 ++++--
 8 files changed, 239 insertions(+), 40 deletions(-)
 create mode 100644 vllm_ascend/flash_common3_context.py

diff --git a/tests/ut/ops/test_prepare_finalize.py b/tests/ut/ops/test_prepare_finalize.py
index 35cb01a7..fe2932a9 100644
--- a/tests/ut/ops/test_prepare_finalize.py
+++ b/tests/ut/ops/test_prepare_finalize.py
@@ -13,6 +13,10 @@ class TestPrepareAndFinalize(unittest.TestCase):

     def setUp(self):
         # Mock FusedMoEConfig
+        fake_stream = MagicMock()
+        patcher = patch("torch.npu.Stream", return_value=fake_stream)
+        patcher.start()
+        self.addCleanup(patcher.stop)
         self.moe_config = MagicMock(spec=FusedMoEConfig)
         self.moe_config.tp_group = MagicMock()
         self.moe_config.tp_group.device_group = MagicMock()

diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 1f87f480..f7522320 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -106,6 +106,8 @@ class AscendConfig:
             enable_shared_expert_dp=True)
         self.multistream_overlap_shared_expert = additional_config.get(
             "multistream_overlap_shared_expert", False)
+        self.multistream_overlap_gate = additional_config.get(
+            "multistream_overlap_gate", False)
         self.recompute_scheduler_enable = additional_config.get(
             "recompute_scheduler_enable", False)
         self.enable_cpu_binding = additional_config.get(
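Reviewer note: the new flag rides on the existing `additional_config` plumbing. A minimal sketch of how downstream modules consume it, assuming `init_ascend_config()` has already run during platform setup:

```python
# Illustrative only (not part of the patch): reading the new flag.
from vllm_ascend.ascend_config import get_ascend_config

ascend_config = get_ascend_config()
if ascend_config.multistream_overlap_gate:
    # fused_moe, prepare_finalize and w8a8_dynamic all read this
    # attribute to decide whether to create their side streams.
    print("MoE gate/shared-expert overlap is enabled")
```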
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 7af091b2..ce932fbb 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -20,9 +20,10 @@
 _OTP: Optional[GroupCoordinator] = None
 _LMTP: Optional[GroupCoordinator] = None
 _EMBED_TP: Optional[GroupCoordinator] = None
-# flashcomm2 specific groups
+# flashcomm specific groups
 _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
 _FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
+_FC3_QUANT_X: Optional[GroupCoordinator] = None

 # shared_weight across rank groups
 _SHARED_WEIGHT: Optional[GroupCoordinator] = None
@@ -241,6 +242,15 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
         assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
         _SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")

+    if get_ascend_config().multistream_overlap_gate:
+        global _FC3_QUANT_X
+        group_ranks = all_ranks.unbind(0)
+        group_ranks = [x.tolist() for x in group_ranks]
+        _FC3_QUANT_X = init_model_parallel_group(group_ranks,
+                                                 get_world_group().local_rank,
+                                                 backend,
+                                                 group_name="fc3_quant_x")
+

 def model_parallel_initialized():
     return (_MC2 is not None)
@@ -296,6 +306,11 @@ def get_p_tp_group() -> GroupCoordinator:
     return _P_TP


+def get_fc3_quant_x_group() -> GroupCoordinator:
+    assert _FC3_QUANT_X is not None, ("fc3 quant x group is not initialized")
+    return _FC3_QUANT_X
+
+
 def destroy_ascend_model_parallel():
     global _MC2
     if _MC2:
@@ -343,3 +358,8 @@ def destroy_ascend_model_parallel():
     if _SHARED_WEIGHT:
         _SHARED_WEIGHT.destroy()
         _SHARED_WEIGHT = None
+
+    global _FC3_QUANT_X
+    if _FC3_QUANT_X:
+        _FC3_QUANT_X.destroy()
+        _FC3_QUANT_X = None

diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py
index c25c1f15..6b4b894e 100644
--- a/vllm_ascend/distributed/utils.py
+++ b/vllm_ascend/distributed/utils.py
@@ -2,8 +2,11 @@ import os

 import torch
 import torch.distributed as dist
+from vllm.forward_context import get_forward_context

-from vllm_ascend.distributed.parallel_state import get_p_tp_group
+from vllm_ascend.distributed.parallel_state import (get_dp_group,
+                                                    get_fc3_quant_x_group,
+                                                    get_p_tp_group)


 def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
@@ -59,3 +62,31 @@ def get_transfer_timeout_value():
                                           '7'))  # type: ignore
     return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
                3000)
+
+
+def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        return x
+    x = get_fc3_quant_x_group().all_gather(x, 0)
+    dp_metadata = forward_context.dp_metadata
+    if dp_metadata is None:
+        pad_size = forward_context.pad_size
+        if pad_size > 0:
+            x = x[:-pad_size]
+    else:
+        # unpad
+        num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
+        result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
+                             device=x.device,
+                             dtype=x.dtype)
+        dp_size = get_dp_group().world_size
+        x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
+        offset = 0
+        for idx in range(dp_size):
+            num_tokens_dp = num_tokens_across_dp_cpu[idx]
+            result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
+            offset += num_tokens_dp
+        x = result
+    return x
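The unpad path above is easiest to see with concrete shapes. A small, NPU-free sketch of the same loop with made-up sizes (dp_size=2, padded_length=4; every rank contributes `padded_length` rows, of which only the first `num_tokens` are real):

```python
import torch

dp_size, padded_length, hidden = 2, 4, 8
num_tokens_across_dp_cpu = torch.tensor([3, 2])  # real tokens per DP rank
gathered = torch.randn(dp_size * padded_length, hidden)  # all-gather output

x = gathered.view(dp_size, padded_length, hidden)
result = torch.empty((int(num_tokens_across_dp_cpu.sum()), hidden))
offset = 0
for idx in range(dp_size):
    n = int(num_tokens_across_dp_cpu[idx])
    result[offset:offset + n] = x[idx, :n]  # drop this rank's padding rows
    offset += n
# `result` now holds the 5 unpadded rows, in DP-rank order.
```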
diff --git a/vllm_ascend/flash_common3_context.py b/vllm_ascend/flash_common3_context.py
new file mode 100644
index 00000000..a579af90
--- /dev/null
+++ b/vllm_ascend/flash_common3_context.py
@@ -0,0 +1,42 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from vllm.model_executor.layers.linear import LinearBase
+
+
+@dataclass
+class FlashCommon3Context:
+    gate: Optional[LinearBase] = None
+    topk_weights: Optional[torch.Tensor] = None
+    topk_ids: Optional[torch.Tensor] = None
+    row_idx: Optional[torch.Tensor] = None
+    shared_experts: Optional[torch.nn.Module] = None
+    shared_out: Optional[torch.Tensor] = None
+
+
+_flash_common3_context: Optional[FlashCommon3Context] = None
+
+
+def get_flash_common3_context() -> Optional[FlashCommon3Context]:
+    return _flash_common3_context
+
+
+def set_flash_common3_context(
+    topk_weights: Optional[torch.Tensor] = None,
+    topk_ids: Optional[torch.Tensor] = None,
+    shared_experts: Optional[torch.nn.Module] = None,
+    shared_out: Optional[torch.Tensor] = None,
+):
+    global _flash_common3_context
+    if _flash_common3_context is None:
+        _flash_common3_context = FlashCommon3Context()
+
+    if topk_weights is not None:
+        _flash_common3_context.topk_weights = topk_weights
+    if topk_ids is not None:
+        _flash_common3_context.topk_ids = topk_ids
+    if shared_experts is not None:
+        _flash_common3_context.shared_experts = shared_experts
+    if shared_out is not None:
+        _flash_common3_context.shared_out = shared_out
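Minimal usage sketch of the new context module, roughly mirroring the producer/consumer pair in this PR (routing on the gate stream stores results, the quantized MoE method reads them back); the tensors here are stand-ins:

```python
import torch

from vllm_ascend.flash_common3_context import (get_flash_common3_context,
                                               set_flash_common3_context)

topk_weights = torch.rand(16, 8)
topk_ids = torch.randint(0, 64, (16, 8))
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)

ctx = get_flash_common3_context()
assert ctx is not None and ctx.topk_ids is topk_ids
```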
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index aa039173..4ce526e2 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -37,9 +37,12 @@ from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
 from vllm_ascend.eplb.utils import moe_load_async_stream
+from vllm_ascend.flash_common3_context import (get_flash_common3_context,
+                                               set_flash_common3_context)
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
-from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
+from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
+                                                       setup_moe_comm_method)
 from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
 from vllm_ascend.quantization.w4a8_dynamic import \
     AscendW4A8DynamicFusedMoEMethod
@@ -139,6 +142,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

 class AscendFusedMoE(FusedMoE):
     moe_counter = -1
+    gate_stream: Optional[torch.npu.Stream] = None

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -170,6 +174,10 @@ class AscendFusedMoE(FusedMoE):
         self.expert_map_path = ascend_config.expert_map_path
         self.global_redundant_expert_num = ascend_config.init_redundancy_expert
         self.global_num_experts = num_experts + self.global_redundant_expert_num
+        # flashcommon3 gate stream
+        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
+        if self.multistream_overlap_gate and AscendFusedMoE.gate_stream is None:
+            AscendFusedMoE.gate_stream = torch.npu.Stream()
         if self.custom_routing_function is None and self.e_score_correction_bias is not None:
             vllm_config = get_current_vllm_config()
             self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
@@ -332,6 +340,47 @@ class AscendFusedMoE(FusedMoE):
             enable_force_load_balance = forward_context.in_profile_run

         forward_context = get_forward_context()
+        if self.multistream_overlap_gate:
+            assert AscendFusedMoE.gate_stream is not None
+            fc3_context = get_flash_common3_context()
+            assert fc3_context is not None
+            AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream())
+            with npu_stream_switch(AscendFusedMoE.gate_stream,
+                                   enabled=self.multistream_overlap_gate):
+                # share_expert
+                assert fc3_context.shared_experts is not None
+                shared_out = fc3_context.shared_experts(hidden_states)
+                # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+                moe_comm_type = forward_context.moe_comm_type
+                if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                        and not shared_expert_dp_enabled():
+                    shared_out = tensor_model_parallel_all_reduce(shared_out)
+                set_flash_common3_context(shared_out=shared_out)
+
+                topk_weights, topk_ids = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=self.use_grouped_topk,
+                    renormalize=self.renormalize,
+                    topk_group=self.topk_group,
+                    num_expert_group=self.num_expert_group,
+                    custom_routing_function=self.custom_routing_function,
+                    scoring_func=self.scoring_func,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    e_score_correction_bias=self.e_score_correction_bias,
+                    global_num_experts=self.global_num_experts)
+
+                if isinstance(forward_context.moe_comm_method,
+                              AllGatherCommImpl):
+                    topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                        topk_weights, True, True)
+                    topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                        topk_ids, True, True)
+
+                set_flash_common3_context(topk_weights=topk_weights,
+                                          topk_ids=topk_ids)
+
         hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -339,6 +388,10 @@ class AscendFusedMoE(FusedMoE):
             enable_shared_expert_dp=self.enable_shared_expert_dp,
             quant_type=self.quant_type)

+        # Make sure the default stream waits for the gate stream to finish.
+        if self.multistream_overlap_gate:
+            torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)
+
         if isinstance(hidden_states, tuple):
             hidden_states, pertoken_scale = hidden_states
         else:
@@ -407,6 +460,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         self.shared_expert_stream = None
         ascend_config = get_ascend_config()
         self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
+        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -443,30 +497,42 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):

     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
-        # Make sure the shared experts stream begins after hidden_states are ready.
-        if self.multistream_overlap_shared_expert:
-            shared_experts_calculation_stream().wait_stream(  # type: ignore
-                torch.npu.current_stream())
-        with npu_stream_switch(shared_experts_calculation_stream(),
-                               enabled=self.multistream_overlap_shared_expert):
-            # Use a separate stream to run shared experts.
-            # Note that currently we only support calculations in separate streams with aclgraph.
-            # Communication operations in another stream might cause unknown errors.
-            shared_out = self._shared_experts(hidden_states)
+        shared_out = None
+        if not self.multistream_overlap_gate:
+            # Make sure the shared experts stream begins after hidden_states are ready.
+            if self.multistream_overlap_shared_expert:
+                shared_experts_calculation_stream(
+                ).wait_stream(  # type: ignore
+                    torch.npu.current_stream())
+            with npu_stream_switch(
+                    shared_experts_calculation_stream(),
+                    enabled=self.multistream_overlap_shared_expert):
+                # Use a separate stream to run shared experts.
+                shared_out = self._shared_experts(hidden_states)
+        else:
+            set_flash_common3_context(shared_experts=self._shared_experts)

         fused_output = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        # Make sure the default stream waits for the shared experts stream to finish.
-        if self.multistream_overlap_shared_expert:
-            torch.npu.current_stream().wait_stream(
-                shared_experts_calculation_stream())
-        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-        forward_context = get_forward_context()
-        moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
-                and not shared_expert_dp_enabled():
-            shared_out = tensor_model_parallel_all_reduce(shared_out)
+
+        if not self.multistream_overlap_gate:
+            # Make sure the default stream waits for the shared experts stream to finish.
+            if self.multistream_overlap_shared_expert:
+                torch.npu.current_stream().wait_stream(
+                    shared_experts_calculation_stream())
+
+            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+            forward_context = get_forward_context()
+            moe_comm_type = forward_context.moe_comm_type
+            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                    and not shared_expert_dp_enabled():
+                shared_out = tensor_model_parallel_all_reduce(shared_out)
+        else:
+            fc3_context = get_flash_common3_context()
+            assert fc3_context is not None
+            shared_out = fc3_context.shared_out
+
         return shared_out, fused_output
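The gate overlap in `forward_impl` follows a fork-join pattern. A distilled sketch under stated assumptions: it requires `torch_npu`, `npu_stream_switch` is the helper from `vllm_ascend.utils` that the patch already uses, and `gate_work`/`main_work` are placeholders for the routing and dispatch kernels:

```python
import torch

from vllm_ascend.utils import npu_stream_switch

gate_stream = torch.npu.Stream()

def overlap(gate_work, main_work):
    # Fork: the gate stream must not start before its inputs are ready
    # on the default stream.
    gate_stream.wait_stream(torch.npu.current_stream())
    with npu_stream_switch(gate_stream, enabled=True):
        gate_out = gate_work()   # runs concurrently with main_work
    main_out = main_work()       # default stream keeps going
    # Join: the default stream may consume gate_out only after the side
    # stream finishes, otherwise it could read unfinished results.
    torch.npu.current_stream().wait_stream(gate_stream)
    return gate_out, main_out
```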
diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py
index b3b907b0..2e7db621 100644
--- a/vllm_ascend/ops/fused_moe/prepare_finalize.py
+++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py
@@ -29,7 +29,10 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

-from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
+from vllm_ascend.utils import (enable_sp, npu_stream_switch,
+                               prefill_context_parallel_enable)


 class QuantType(Enum):
@@ -49,9 +52,14 @@ class PrepareAndFinalize(ABC):
         moe_config (FusedMoEConfig): Configuration object containing
             TP/DP/EP group info, sizes, ranks, and communication settings.
     """
+    quant_stream: Optional[torch.npu.Stream] = None

     def __init__(self, moe_config: FusedMoEConfig):
         self.moe_config = moe_config
+        ascend_config = get_ascend_config()
+        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
+        if self.multistream_overlap_gate and PrepareAndFinalize.quant_stream is None:
+            PrepareAndFinalize.quant_stream = torch.npu.Stream()

     @abstractmethod
     def prepare(
@@ -335,12 +343,28 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         if quant_type == QuantType.W8A8:
             hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
                 hidden_states)
+
+        if self.multistream_overlap_gate:
+            assert PrepareAndFinalize.quant_stream is not None
+            PrepareAndFinalize.quant_stream.wait_stream(
+                torch.npu.current_stream())
+            with npu_stream_switch(PrepareAndFinalize.quant_stream,
+                                   enabled=self.multistream_overlap_gate):
+                hidden_states = fc3_all_gather_and_maybe_unpad_impl(
+                    hidden_states)
+        else:
+            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                hidden_states, True, True)
+        router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            router_logits, True, True)
+
+        if pertoken_scale is not None:
             pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 pertoken_scale, True, True)
-        hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            hidden_states, True, True)
-        router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            router_logits, True, True)
+
+        if self.multistream_overlap_gate:
+            torch.npu.current_stream().wait_stream(
+                PrepareAndFinalize.quant_stream)

         if pertoken_scale is not None:
             return (hidden_states, pertoken_scale), router_logits, None, None
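For intuition, a plain-PyTorch approximation of the `(tensor, per-token scale)` pair that `torch_npu.npu_dynamic_quant` produces; the real op is a fused NPU kernel and this reference is not bit-exact:

```python
import torch

def dynamic_quant_reference(x: torch.Tensor):
    # One scale per token (row) so that each row's absolute max maps
    # onto the int8 range.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    quantized = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return quantized, scale.squeeze(-1)

x = torch.randn(4, 16)
q, pertoken_scale = dynamic_quant_reference(x)
assert q.dtype == torch.int8 and pertoken_scale.shape == (4,)
```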
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 7885ef14..72a3570e 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -26,6 +26,7 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
+from vllm_ascend.flash_common3_context import get_flash_common3_context
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz

@@ -114,6 +115,7 @@ class AscendW8A8DynamicFusedMoEMethod:
         self.use_aclgraph = (vllm_config.compilation_config.mode
                              == CompilationMode.VLLM_COMPILE
                              and not vllm_config.model_config.enforce_eager)
+        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
         self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
         self.in_dtype = vllm_config.model_config.dtype

@@ -198,18 +200,25 @@ class AscendW8A8DynamicFusedMoEMethod:
         assert router_logits.shape[
             1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            top_k=top_k,
-            use_grouped_topk=use_grouped_topk,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            global_num_experts=global_num_experts)
+        topk_weights, topk_ids = None, None
+        if self.multistream_overlap_gate:
+            fc3_context = get_flash_common3_context()
+            assert fc3_context is not None
+            topk_weights = fc3_context.topk_weights
+            topk_ids = fc3_context.topk_ids
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                top_k=top_k,
+                use_grouped_topk=use_grouped_topk,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias,
+                global_num_experts=global_num_experts)

         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
@@ -222,6 +231,7 @@ class AscendW8A8DynamicFusedMoEMethod:
             topk_ids = torch.argsort(
                 random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

+        assert topk_weights is not None
         topk_weights = topk_weights.to(self.in_dtype)

         moe_comm_method = get_forward_context().moe_comm_method
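End-to-end enablement matching the test command above, sketched via the offline API; this assumes the `additional_config` engine argument is available in your vLLM build (as used by vllm-ascend), and the model name is a placeholder:

```python
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder: any MoE model with shared experts
    additional_config={"multistream_overlap_gate": True},
)
```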