[Pangu][MoE] Remove PanguProMoEV1 related code (#5088)
### What this PR does / why we need it?
PanguProMoEV1 is no longer supported in vllm-ascend, so this PR removes the related code.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
@@ -72,9 +72,6 @@ def setup_vllm_config_mock(mocker: MockerFixture):
     mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
                  return_value=mock_vllm_config)
-    mocker.patch(
-        'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
-        return_value=mock_vllm_config)


 @pytest.fixture
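With the `moe_comm_method` patch dropped, the shared test fixture only needs to stub the config lookup that `fused_moe` still resolves at call time. A minimal sketch of such a fixture, assuming pytest-mock's `MockerFixture` as used in the hunk above (the real fixture configures more attributes on the mock config):

```python
import pytest
from unittest.mock import MagicMock
from pytest_mock import MockerFixture


@pytest.fixture
def setup_vllm_config_mock(mocker: MockerFixture):
    # Only the fused_moe lookup is patched; moe_comm_method no longer
    # reads the global vLLM config, so the second patch is not needed.
    mock_vllm_config = MagicMock()
    mocker.patch(
        'vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
        return_value=mock_vllm_config)
    return mock_vllm_config
```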
@@ -26,7 +26,6 @@ class TestMoECommMethod(TestBase):
         self.moe_config.dp_group = MagicMock()
         self.moe_config.num_global_redundant_experts = 0

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
@@ -36,11 +35,7 @@ class TestMoECommMethod(TestBase):
     )
     def test_all_gather_comm_impl(self, mock_token_dispatcher,
                                   mock_prepare_finalize,
-                                  mock_get_forward_context,
-                                  mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                                  mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "all_gather"
@@ -76,17 +71,12 @@ class TestMoECommMethod(TestBase):
             context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
     def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
-                           mock_get_forward_context,
-                           mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                           mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "mc2"
@@ -124,7 +114,6 @@ class TestMoECommMethod(TestBase):
             context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
@@ -134,11 +123,7 @@ class TestMoECommMethod(TestBase):
     )
     def test_alltoall_comm_impl(self, mock_token_dispatcher,
                                 mock_prepare_finalize,
-                                mock_get_forward_context,
-                                mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                                mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "alltoall"
@@ -168,7 +153,6 @@ class TestMoECommMethod(TestBase):
         mock_pf_instance.prepare.assert_called_once_with(
             hidden_states, router_logits, False, False, QuantType.NONE)

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
@@ -179,11 +163,7 @@ class TestMoECommMethod(TestBase):
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
     def test_fused_experts_method(self, mock_unified_apply_mlp,
                                   mock_token_dispatcher, mock_prepare_finalize,
-                                  mock_get_forward_context,
-                                  mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                                  mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "all_gather"
@@ -240,7 +240,6 @@ def select_moe_comm_method(num_tokens: int,
     quant_type = getattr(
         vllm_config.model_config.hf_config, 'moe_quantize',
        getattr(vllm_config.model_config.hf_config, 'quantize', None))
-    model_type = vllm_config.model_config.hf_config.model_type

     if not vllm_config.parallel_config.enable_expert_parallel:
         moe_comm_type = MoECommType.ALLGATHER
@@ -267,7 +266,4 @@ def select_moe_comm_method(num_tokens: int,
                          if fused_all2all_enable else MoECommType.ALLTOALL)
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")
-    # PanguProMoE only supports allgather
-    if model_type == "PanguProMoE":
-        moe_comm_type = MoECommType.ALLGATHER
     return moe_comm_type
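The deleted tail of `select_moe_comm_method` forced ALLGATHER for PanguProMoE after the SoC-specific choice had already been made; with the model gone, that override is dead code. A tiny, self-contained illustration of the pattern being removed (stand-in enum and function names, not the repo's full implementation):

```python
from enum import Enum, auto


class CommType(Enum):
    # Stand-in for vllm_ascend's MoECommType.
    ALLGATHER = auto()
    MC2 = auto()
    ALLTOALL = auto()


def apply_model_override(base_choice: CommType, model_type: str) -> CommType:
    # Old behaviour removed by this PR: PanguProMoE always fell back to
    # ALLGATHER, regardless of the choice made for the current SoC.
    if model_type == "PanguProMoE":
        return CommType.ALLGATHER
    return base_choice


assert apply_model_override(CommType.MC2, "PanguProMoE") is CommType.ALLGATHER
assert apply_model_override(CommType.MC2, "deepseek_v3") is CommType.MC2
```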
@@ -19,7 +19,6 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, Optional

 import torch
-from vllm.config import get_current_vllm_config
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
@@ -30,7 +29,7 @@ from vllm_ascend.ops.fused_moe.prepare_finalize import (
     PrepareAndFinalizeWithMC2, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
-    TokenDispatcherWithMC2, TokenDispatcherWithMoge)
+    TokenDispatcherWithMC2)

 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
@@ -52,8 +51,6 @@ class MoECommMethod(ABC):
     """Base class for MoE communication methods."""

     def __init__(self, moe_config: FusedMoEConfig):
-        self.model_type = get_current_vllm_config(
-        ).model_config.hf_config.model_type
         self.moe_config = moe_config

         self.token_dispatcher = self._get_token_dispatcher()
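After this hunk, `MoECommMethod.__init__` no longer calls `get_current_vllm_config()`; everything a comm method needs comes from the `FusedMoEConfig` it is given. A minimal sketch of that construction pattern with stand-in types (hypothetical names, not the repo's classes):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class MoEConfigSketch:
    # Stand-in for FusedMoEConfig: only the fields the dispatchers read.
    experts_per_token: int
    num_experts: int


class CommMethodSketch(ABC):
    def __init__(self, moe_config: MoEConfigSketch):
        # No global-config lookup (and no cached model_type) any more;
        # the dispatcher choice depends only on the subclass.
        self.moe_config = moe_config
        self.token_dispatcher = self._get_token_dispatcher()

    @abstractmethod
    def _get_token_dispatcher(self):
        ...


class AllGatherSketch(CommMethodSketch):
    def _get_token_dispatcher(self):
        # Always the all-gather dispatcher now; the PanguProMoE branch is gone.
        return ("allgather", self.moe_config.experts_per_token,
                self.moe_config.num_experts)
```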
@@ -198,12 +195,6 @@ class AllGatherCommImpl(MoECommMethod):
     """

     def _get_token_dispatcher(self):
-        if self.model_type == "PanguProMoE":
-            return TokenDispatcherWithMoge(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
-        else:
         return TokenDispatcherWithAllGather(
             top_k=self.moe_config.experts_per_token,
             num_experts=self.moe_config.num_experts,
@@ -422,69 +422,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         return final_hidden_states


-# mypy: disable-error-code="override"
-class TokenDispatcherWithMoge(MoETokenDispatcher):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.apply_router_weight_on_input = False
-        self.local_num_experts = self.num_experts // self.ep_size
-        self.local_num_group = self.top_k // self.ep_size
-        self.bsz = None
-
-    def token_dispatch(self,
-                       hidden_states: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
-                       global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
-                       mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False,
-                       dynamic_eplb: bool = False,
-                       pertoken_scale: Optional[torch.Tensor] = None):
-        self.bsz, _ = hidden_states.shape
-        flatten_topk_ids = topk_ids.view(-1)
-        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
-        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
-        sorted_hidden_states = hidden_states.index_select(
-            0, self.sorted_topk_ids // self.local_num_group)
-
-        experts_id = torch.arange(0,
-                                  self.local_num_experts,
-                                  dtype=topk_ids.dtype,
-                                  device=topk_ids.device)
-        num_tokens_per_expert = (
-            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
-                torch.float32).sum(0)
-        topk_scales = topk_weights.view(-1).index_select(
-            0, self.sorted_topk_ids).unsqueeze(-1)
-        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
-        group_list_type = 0
-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": sorted_hidden_states,
-            "group_list": group_list,
-            "topk_scales": topk_scales
-        }
-
-    def token_combine(self,
-                      hidden_states: torch.Tensor,
-                      context_metadata: dict,
-                      bias: torch.Tensor = None):
-        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
-            torch.int32)
-        unsorted_hidden_states = hidden_states.index_select(
-            0, unsorted_topk_ids)
-        final_hidden_states = unsorted_hidden_states.reshape(
-            self.bsz, self.top_k // self.ep_size, -1).sum(1)
-        return final_hidden_states
-
-
 class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
     """
     The implementation of the AlltoAll-based token dispatcher, which handles token
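For reference, the removed dispatcher implemented a sort-based grouping: every token is replicated `top_k` times, the copies are ordered by expert id so each expert sees a contiguous slice (with `group_list` holding the cumulative slice boundaries), and the combine step unsorts the expert outputs and sums the `top_k` copies per token. A minimal standalone sketch of that logic, assuming a single EP rank (`ep_size == 1`) and illustrative function names; scale application is shown at combine time for clarity, whereas the real code passed `topk_scales` on to the expert MLP:

```python
import torch


def moge_dispatch(hidden_states, topk_weights, topk_ids, num_local_experts):
    """Sort token copies by expert id; return contiguous per-expert groups."""
    num_tokens, top_k = topk_ids.shape
    flat_ids = topk_ids.view(-1)
    sorted_idx = torch.argsort(flat_ids.float()).to(torch.int32)
    # Each flattened slot i belongs to token i // top_k.
    sorted_hidden = hidden_states.index_select(0, sorted_idx // top_k)
    experts = torch.arange(num_local_experts, dtype=flat_ids.dtype)
    tokens_per_expert = (flat_ids.unsqueeze(-1) == experts).float().sum(0)
    group_list = tokens_per_expert.cumsum(dim=0).to(torch.int64)
    topk_scales = topk_weights.view(-1).index_select(0, sorted_idx).unsqueeze(-1)
    return sorted_hidden, group_list, topk_scales, sorted_idx


def moge_combine(expert_out, sorted_idx, num_tokens, top_k):
    """Undo the sort and sum the top_k expert outputs for each token."""
    unsort_idx = torch.argsort(sorted_idx.float()).to(torch.int32)
    return expert_out.index_select(0, unsort_idx).reshape(num_tokens, top_k, -1).sum(1)


# Example: 4 tokens, hidden size 8, 3 experts, top_k = 2, identity "experts".
x = torch.randn(4, 8)
ids = torch.randint(0, 3, (4, 2))
w = torch.full((4, 2), 0.5)
sorted_hidden, group_list, scales, sorted_idx = moge_dispatch(x, w, ids, 3)
out = moge_combine(sorted_hidden * scales, sorted_idx, num_tokens=4, top_k=2)
assert out.shape == (4, 8)
```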