diff --git a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py index 3a9733b..5a00592 100644 --- a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py +++ b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py @@ -191,13 +191,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - # Mock the gate function for rm_router_logits=False case - mock_gate = MagicMock() - mock_gate.return_value = (router_logits.repeat(2, 1), None) - - h_out, r_out, _ = layer.prepare(hidden_states, - router_logits, - gate=mock_gate) + h_out, r_out, _ = layer.prepare(hidden_states, router_logits) # After all-gather with DP=2, should double the batch size self.assertEqual(h_out.shape[0], 12) @@ -258,14 +252,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - # Mock gate for router logits recomputation - mock_gate = MagicMock() - mock_gate.return_value = (torch.randn(7, 2), None) - # Run prepare - h_out, r_out, _ = layer.prepare(hidden_states, - router_logits, - gate=mock_gate) + h_out, r_out, _ = layer.prepare(hidden_states, router_logits) # Should be global tensor: [7, 8] and [7, 2] self.assertEqual(h_out.shape, (7, 8)) diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py index 3826a19..76dcd50 100644 --- a/tests/ut/ops/test_moe_comm_method.py +++ b/tests/ut/ops/test_moe_comm_method.py @@ -63,7 +63,7 @@ class TestMoECommMethod(TestBase): # Verify prepare was called with correct arguments mock_pf_instance.prepare.assert_called_once_with( - hidden_states, router_logits, False, False, None) + hidden_states, router_logits, False, False) # Test finalize method comm_impl.finalize(h_out, reduce_results=True) @@ -108,7 +108,7 @@ class TestMoECommMethod(TestBase): # Verify prepare was called with correct arguments mock_pf_instance.prepare.assert_called_once_with( - hidden_states, router_logits, False, False, None) + hidden_states, router_logits, False, False) # Test finalize method comm_impl.finalize(h_out, reduce_results=True) @@ -153,7 +153,7 @@ class TestMoECommMethod(TestBase): # Verify prepare was called with correct arguments mock_pf_instance.prepare.assert_called_once_with( - hidden_states, router_logits, False, False, None) + hidden_states, router_logits, False, False) @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") diff --git a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py index 19e4989..6159523 100644 --- a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py +++ b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py @@ -26,7 +26,7 @@ from vllm.distributed.parallel_state import ( from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe import FusedMoEConfig -from vllm_ascend.utils import enable_sp, get_rm_router_logits_state +from vllm_ascend.utils import enable_sp class FusedMoEPrepareAndFinalize(ABC): @@ -43,31 +43,26 @@ class FusedMoEPrepareAndFinalize(ABC): def __init__(self, moe_config: FusedMoEConfig): self.moe_config = moe_config - is_deepseek_v3_r1 = self.moe_config.original_num_experts == 256 - self.rm_router_logits = get_rm_router_logits_state( - self.moe_config.ep_size, self.moe_config.dp_size, - is_deepseek_v3_r1) @abstractmethod - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Prepare tensors before MoE computation. May involve: - Padding to align communication boundaries - Slicing across tensor-parallel ranks - Broadcasting across data-parallel ranks - - Recomputing router logits if needed Args: hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size] router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts] enable_shared_expert_dp (bool): Skip DP communication for shared experts replace_allreduce (bool): Bypass default all-reduce behavior - gate (nn.Module, optional): Gate network to recompute router_logits if needed Returns: Tuple of: @@ -116,12 +111,13 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preparation steps: 1. Fetch `mc2_mask` and target padding length from forward context. @@ -214,12 +210,13 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preparation steps: 1. Pad hidden_states and router_logits to next multiple of TP size. @@ -307,12 +304,13 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): TP AG → Attn → TP RS → EP AG → MoE → EP RS """ - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preparation steps: AllGather hidden_states and router_logits to form global tensors. @@ -325,7 +323,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, - replace_allreduce, gate) + replace_allreduce) def _prepare_with_ep_group( self, @@ -340,12 +338,12 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): return hidden_states, router_logits, None def _prepare_with_dp_group( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preparation steps: 1. Fetch max token count across DP group from forward context. @@ -365,18 +363,14 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): if pad_size > 0: hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size)) - if not self.rm_router_logits: - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) # All-gather across DP group hidden_states = self.moe_config.dp_group.all_gather( hidden_states, 0) - if self.rm_router_logits: - router_logits, _ = gate(hidden_states) # Recompute globally - else: - router_logits = self.moe_config.dp_group.all_gather( - router_logits, 0) + router_logits = self.moe_config.dp_group.all_gather( + router_logits, 0) return hidden_states, router_logits, None @@ -472,12 +466,13 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize): get_dp_group().broadcast(buffer[start:end, :], idx) return buffer - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preparation steps: 1. Fetch cumulative token boundaries from forward context. @@ -493,11 +488,8 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize): ).dp_metadata.cu_tokens_across_sp(1) hidden_states = self._naive_multicast(hidden_states, self.cu_tokens_across_dp_cpu) - if self.rm_router_logits: - router_logits, _ = gate(hidden_states) - else: - router_logits = self._naive_multicast( - router_logits, self.cu_tokens_across_dp_cpu) + router_logits = self._naive_multicast(router_logits, + self.cu_tokens_across_dp_cpu) return hidden_states, router_logits, None diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/moe/moe_comm_method.py index 29a5819..8f49841 100644 --- a/vllm_ascend/ops/moe/moe_comm_method.py +++ b/vllm_ascend/ops/moe/moe_comm_method.py @@ -63,15 +63,16 @@ class MoECommMethod(ABC): self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize( ) - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor]: + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare( hidden_states, router_logits, enable_shared_expert_dp, - replace_allreduce, gate) + replace_allreduce) self.mc2_mask = mc2_mask return hidden_states, router_logits diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 2e9e8fa..18b9ceb 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -48,12 +48,12 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding -from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor, +from vllm_ascend.torchair.utils import (get_all_reduce_merge_state, + get_rm_router_logits_state, + npu_stream_switch, npu_wait_tensor, super_kernel) from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, - get_all_reduce_merge_state, - get_ascend_soc_version, - get_rm_router_logits_state, is_310p, + get_ascend_soc_version, is_310p, is_hierarchical_communication_enabled) diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 97fc3b1..1936703 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -15,6 +15,8 @@ try: except ImportError: from torchair.ops import NpuStreamSwitch as _npu_stream_switch from torchair.ops import npu_wait_tensor as _npu_wait_tensor + +import vllm_ascend.envs as envs_ascend from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes" @@ -241,3 +243,33 @@ def torchair_ops_patch(): def super_kernel(prefix: str, option: str, enabled: bool = True): return _super_kernel(prefix, option) if enabled else nullcontext() + + +# TODO(ttanzhiqiang): rm_router_logits +# dp>1 will trigger +# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors. +def get_rm_router_logits_state(ep_size: int, dp_size: int, + is_deepseek_v3_r1: bool): + # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep + # only supports deepseek v3/r1 + if dp_size > 1: + if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 + and is_deepseek_v3_r1): + return True + elif ep_size == 1 and is_deepseek_v3_r1: + return True + return False + + +# TODO(ttanzhiqiang): all_reduce merge +# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce +# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model. +def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool): + # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep + # only supports deepseek v3/r1 + if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 + and is_deepseek_v3_r1): + return True + elif ep_size == 1 and is_deepseek_v3_r1: + return True + return False diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 34b98af..0184bea 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -520,36 +520,6 @@ class ProfileExecuteDuration: return durations -# TODO(ttanzhiqiang): rm_router_logits -# dp>1 will trigger -# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors. -def get_rm_router_logits_state(ep_size: int, dp_size: int, - is_deepseek_v3_r1: bool): - # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep - # only supports deepseek v3/r1 - if dp_size > 1: - if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 - and is_deepseek_v3_r1): - return True - elif ep_size == 1 and is_deepseek_v3_r1: - return True - return False - - -# TODO(ttanzhiqiang): all_reduce merge -# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce -# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model. -def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool): - # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep - # only supports deepseek v3/r1 - if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 - and is_deepseek_v3_r1): - return True - elif ep_size == 1 and is_deepseek_v3_r1: - return True - return False - - def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): """Register Ascend CustomOP