[Pangu][MoE] Remove PanguProMoEV1 related code (#5088)

### What this PR does / why we need it?
PanguProMoEV1 is no longer supported in vllm-ascend, remove related
code.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
e2e & ut

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: weichen <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-12-17 16:14:42 +08:00
committed by GitHub
parent 3f7a2fba70
commit f0060fc822
5 changed files with 9 additions and 108 deletions

View File

@@ -72,9 +72,6 @@ def setup_vllm_config_mock(mocker: MockerFixture):
mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
return_value=mock_vllm_config)
mocker.patch(
'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
return_value=mock_vllm_config)
@pytest.fixture

View File

@@ -26,7 +26,6 @@ class TestMoECommMethod(TestBase):
self.moe_config.dp_group = MagicMock()
self.moe_config.num_global_redundant_experts = 0
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
@@ -36,11 +35,7 @@ class TestMoECommMethod(TestBase):
)
def test_all_gather_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
mock_get_forward_context):
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"
@@ -76,17 +71,12 @@ class TestMoECommMethod(TestBase):
context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
mock_get_forward_context):
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "mc2"
@@ -124,7 +114,6 @@ class TestMoECommMethod(TestBase):
context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
@@ -134,11 +123,7 @@ class TestMoECommMethod(TestBase):
)
def test_alltoall_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
mock_get_forward_context):
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "alltoall"
@@ -168,7 +153,6 @@ class TestMoECommMethod(TestBase):
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, QuantType.NONE)
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
@@ -179,11 +163,7 @@ class TestMoECommMethod(TestBase):
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
mock_get_forward_context):
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"

View File

@@ -240,7 +240,6 @@ def select_moe_comm_method(num_tokens: int,
quant_type = getattr(
vllm_config.model_config.hf_config, 'moe_quantize',
getattr(vllm_config.model_config.hf_config, 'quantize', None))
model_type = vllm_config.model_config.hf_config.model_type
if not vllm_config.parallel_config.enable_expert_parallel:
moe_comm_type = MoECommType.ALLGATHER
@@ -267,7 +266,4 @@ def select_moe_comm_method(num_tokens: int,
if fused_all2all_enable else MoECommType.ALLTOALL)
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
# PanguProMoE only supports allgather
if model_type == "PanguProMoE":
moe_comm_type = MoECommType.ALLGATHER
return moe_comm_type

View File

@@ -19,7 +19,6 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
@@ -30,7 +29,7 @@ from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalizeWithMC2, QuantType)
from vllm_ascend.ops.fused_moe.token_dispatcher import (
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
TokenDispatcherWithMC2, TokenDispatcherWithMoge)
TokenDispatcherWithMC2)
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
@@ -52,8 +51,6 @@ class MoECommMethod(ABC):
"""Base class for MoE communication methods."""
def __init__(self, moe_config: FusedMoEConfig):
self.model_type = get_current_vllm_config(
).model_config.hf_config.model_type
self.moe_config = moe_config
self.token_dispatcher = self._get_token_dispatcher()
@@ -198,12 +195,6 @@ class AllGatherCommImpl(MoECommMethod):
"""
def _get_token_dispatcher(self):
if self.model_type == "PanguProMoE":
return TokenDispatcherWithMoge(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
else:
return TokenDispatcherWithAllGather(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,

View File

@@ -422,69 +422,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
return final_hidden_states
# mypy: disable-error-code="override"
class TokenDispatcherWithMoge(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.local_num_experts = self.num_experts // self.ep_size
self.local_num_group = self.top_k // self.ep_size
self.bsz = None
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None):
self.bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
sorted_hidden_states = hidden_states.index_select(
0, self.sorted_topk_ids // self.local_num_group)
experts_id = torch.arange(0,
self.local_num_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)
num_tokens_per_expert = (
flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
topk_scales = topk_weights.view(-1).index_select(
0, self.sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
group_list_type = 0
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": group_list,
"topk_scales": topk_scales
}
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None):
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
torch.int32)
unsorted_hidden_states = hidden_states.index_select(
0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
self.bsz, self.top_k // self.ep_size, -1).sum(1)
return final_hidden_states
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
"""
The implementation of the AlltoAll-based token dispatcher, which handles token