diff --git a/tests/ut/ops/test_fused_moe.py b/tests/ut/ops/test_fused_moe.py
index 94a52b7c..96daaa25 100644
--- a/tests/ut/ops/test_fused_moe.py
+++ b/tests/ut/ops/test_fused_moe.py
@@ -72,9 +72,6 @@ def setup_vllm_config_mock(mocker: MockerFixture):
     mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
                  return_value=mock_vllm_config)
-    mocker.patch(
-        'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
-        return_value=mock_vllm_config)
 
 
 @pytest.fixture
diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index 8adde876..7620999a 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -26,7 +26,6 @@ class TestMoECommMethod(TestBase):
         self.moe_config.dp_group = MagicMock()
         self.moe_config.num_global_redundant_experts = 0
 
-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
@@ -36,11 +35,7 @@
     )
     def test_all_gather_comm_impl(self, mock_token_dispatcher,
                                   mock_prepare_finalize,
-                                  mock_get_forward_context,
-                                  mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                                  mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "all_gather"
@@ -76,17 +71,12 @@
             context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
 
-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
     def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
-                           mock_get_forward_context,
-                           mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                           mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "mc2"
@@ -124,7 +114,6 @@
             context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
 
-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
@@ -134,11 +123,7 @@
     )
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV"
     )
     def test_alltoall_comm_impl(self, mock_token_dispatcher,
                                 mock_prepare_finalize,
-                                mock_get_forward_context,
-                                mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                                mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "alltoall"
@@ -168,7 +153,6 @@
         mock_pf_instance.prepare.assert_called_once_with(
             hidden_states, router_logits, False, False, QuantType.NONE)
 
-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
     )
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
     )
@@ -179,11 +163,7 @@
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
     def test_fused_experts_method(self, mock_unified_apply_mlp,
                                   mock_token_dispatcher, mock_prepare_finalize,
-                                  mock_get_forward_context,
-                                  mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                                  mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "all_gather"
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index ebc22bd1..8106c935 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -240,7 +240,6 @@ def select_moe_comm_method(num_tokens: int,
     quant_type = getattr(
         vllm_config.model_config.hf_config, 'moe_quantize',
         getattr(vllm_config.model_config.hf_config, 'quantize', None))
-    model_type = vllm_config.model_config.hf_config.model_type
 
     if not vllm_config.parallel_config.enable_expert_parallel:
         moe_comm_type = MoECommType.ALLGATHER
@@ -267,7 +266,4 @@
                          if fused_all2all_enable else MoECommType.ALLTOALL)
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")
-    # PanguProMoE only supports allgather
-    if model_type == "PanguProMoE":
-        moe_comm_type = MoECommType.ALLGATHER
     return moe_comm_type
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index d0afa7bf..93b79242 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -19,7 +19,6 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, Optional
 
 import torch
-from vllm.config import get_current_vllm_config
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
@@ -30,7 +29,7 @@ from vllm_ascend.ops.fused_moe.prepare_finalize import (
     PrepareAndFinalizeWithMC2, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
-    TokenDispatcherWithMC2, TokenDispatcherWithMoge)
+    TokenDispatcherWithMC2)
 
 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
 
@@ -52,8 +51,6 @@ class MoECommMethod(ABC):
     """Base class for MoE communication methods."""
 
     def __init__(self, moe_config: FusedMoEConfig):
-        self.model_type = get_current_vllm_config(
-        ).model_config.hf_config.model_type
         self.moe_config = moe_config
         self.token_dispatcher = self._get_token_dispatcher()
@@ -198,16 +195,10 @@
     """
 
     def _get_token_dispatcher(self):
-        if self.model_type == "PanguProMoE":
-            return TokenDispatcherWithMoge(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
-        else:
-            return TokenDispatcherWithAllGather(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
+        return TokenDispatcherWithAllGather(
+            top_k=self.moe_config.experts_per_token,
+            num_experts=self.moe_config.num_experts,
+            num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
         return PrepareAndFinalizeWithAllGather(self.moe_config)
diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index 1246d648..1b18a488 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -422,69 +422,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         return final_hidden_states
 
 
-# mypy: disable-error-code="override"
-class TokenDispatcherWithMoge(MoETokenDispatcher):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.apply_router_weight_on_input = False
-        self.local_num_experts = self.num_experts // self.ep_size
-        self.local_num_group = self.top_k // self.ep_size
-        self.bsz = None
-
-    def token_dispatch(self,
-                       hidden_states: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
-                       global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
-                       mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False,
-                       dynamic_eplb: bool = False,
-                       pertoken_scale: Optional[torch.Tensor] = None):
-        self.bsz, _ = hidden_states.shape
-        flatten_topk_ids = topk_ids.view(-1)
-        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
-        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
-        sorted_hidden_states = hidden_states.index_select(
-            0, self.sorted_topk_ids // self.local_num_group)
-
-        experts_id = torch.arange(0,
-                                  self.local_num_experts,
-                                  dtype=topk_ids.dtype,
-                                  device=topk_ids.device)
-        num_tokens_per_expert = (
-            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
-                torch.float32).sum(0)
-        topk_scales = topk_weights.view(-1).index_select(
-            0, self.sorted_topk_ids).unsqueeze(-1)
-        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
-        group_list_type = 0
-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": sorted_hidden_states,
-            "group_list": group_list,
-            "topk_scales": topk_scales
-        }
-
-    def token_combine(self,
-                      hidden_states: torch.Tensor,
-                      context_metadata: dict,
-                      bias: torch.Tensor = None):
-        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
-            torch.int32)
-        unsorted_hidden_states = hidden_states.index_select(
-            0, unsorted_topk_ids)
-        final_hidden_states = unsorted_hidden_states.reshape(
-            self.bsz, self.top_k // self.ep_size, -1).sum(1)
-        return final_hidden_states
-
-
 class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
     """
     The implementation of the AlltoAll-based token dispatcher, which handles token
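
Note on the removed dispatcher (illustrative, not part of the patch): `TokenDispatcherWithMoge` implemented dispatch/combine as a pure permutation round-trip — replicate each token once per selected expert, `argsort` the flattened expert ids so rows routed to the same expert become contiguous for the grouped MLP, then invert that argsort in `token_combine` and sum the `top_k` copies of each token. The sketch below replays that logic on the CPU. It is a minimal sketch under stated assumptions: `ep_size == 1` (so `local_num_group == top_k`), an identity "MLP" standing in for the grouped experts, and made-up shapes; tensor names follow the removed code, everything else is hypothetical.

```python
import torch

bsz, top_k, num_experts, hidden = 4, 2, 8, 16
hidden_states = torch.randn(bsz, hidden)
topk_ids = torch.randint(0, num_experts, (bsz, top_k))
topk_weights = torch.rand(bsz, top_k)

# token_dispatch: sort the flattened (token, k) routing pairs by expert id so
# tokens for the same expert are contiguous; each token appears top_k times.
flatten_topk_ids = topk_ids.view(-1)
sorted_topk_ids = torch.argsort(flatten_topk_ids.float()).to(torch.int32)
sorted_hidden_states = hidden_states.index_select(0, sorted_topk_ids // top_k)
topk_scales = topk_weights.view(-1).index_select(
    0, sorted_topk_ids).unsqueeze(-1)

# The grouped expert MLP would run here on sorted_hidden_states; an identity
# "MLP" keeps the focus on the permutation bookkeeping.
expert_out = sorted_hidden_states * topk_scales

# token_combine: argsort of the argsort yields the inverse permutation; after
# unsorting, the top_k copies of every token are adjacent and can be summed.
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
combined = expert_out.index_select(0, unsorted_topk_ids).reshape(
    bsz, top_k, -1).sum(1)

# With the identity MLP the result is a weighted sum over the top_k routes.
expected = (hidden_states.unsqueeze(1) * topk_weights.unsqueeze(-1)).sum(1)
torch.testing.assert_close(combined, expected)
```

Because `token_combine` inverts the exact permutation produced in `token_dispatch`, the round-trip is lossless regardless of how `argsort` breaks ties between equal expert ids.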