[Pangu][MoE] Remove PanguProMoEV1 related code (#5088)
### What this PR does / why we need it?
PanguProMoEV1 is no longer supported in vllm-ascend, remove related
code.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -72,9 +72,6 @@ def setup_vllm_config_mock(mocker: MockerFixture):
|
|||||||
|
|
||||||
mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
|
mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
|
||||||
return_value=mock_vllm_config)
|
return_value=mock_vllm_config)
|
||||||
mocker.patch(
|
|
||||||
'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
|
|
||||||
return_value=mock_vllm_config)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ class TestMoECommMethod(TestBase):
|
|||||||
self.moe_config.dp_group = MagicMock()
|
self.moe_config.dp_group = MagicMock()
|
||||||
self.moe_config.num_global_redundant_experts = 0
|
self.moe_config.num_global_redundant_experts = 0
|
||||||
|
|
||||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
|
||||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
|
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
|
||||||
@@ -36,11 +35,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
)
|
)
|
||||||
def test_all_gather_comm_impl(self, mock_token_dispatcher,
|
def test_all_gather_comm_impl(self, mock_token_dispatcher,
|
||||||
mock_prepare_finalize,
|
mock_prepare_finalize,
|
||||||
mock_get_forward_context,
|
mock_get_forward_context):
|
||||||
mock_get_current_vllm_config):
|
|
||||||
# Mock vLLM config
|
|
||||||
mock_get_current_vllm_config.return_value = MagicMock()
|
|
||||||
|
|
||||||
# Mock forward context
|
# Mock forward context
|
||||||
mock_context = MagicMock()
|
mock_context = MagicMock()
|
||||||
mock_context.moe_comm_method = "all_gather"
|
mock_context.moe_comm_method = "all_gather"
|
||||||
@@ -76,17 +71,12 @@ class TestMoECommMethod(TestBase):
|
|||||||
context_metadata=context_metadata)
|
context_metadata=context_metadata)
|
||||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||||
|
|
||||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
|
||||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
|
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
|
||||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
|
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
|
||||||
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
|
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
|
||||||
mock_get_forward_context,
|
mock_get_forward_context):
|
||||||
mock_get_current_vllm_config):
|
|
||||||
# Mock vLLM config
|
|
||||||
mock_get_current_vllm_config.return_value = MagicMock()
|
|
||||||
|
|
||||||
# Mock forward context
|
# Mock forward context
|
||||||
mock_context = MagicMock()
|
mock_context = MagicMock()
|
||||||
mock_context.moe_comm_method = "mc2"
|
mock_context.moe_comm_method = "mc2"
|
||||||
@@ -124,7 +114,6 @@ class TestMoECommMethod(TestBase):
|
|||||||
context_metadata=context_metadata)
|
context_metadata=context_metadata)
|
||||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||||
|
|
||||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
|
||||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
|
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
|
||||||
@@ -134,11 +123,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
)
|
)
|
||||||
def test_alltoall_comm_impl(self, mock_token_dispatcher,
|
def test_alltoall_comm_impl(self, mock_token_dispatcher,
|
||||||
mock_prepare_finalize,
|
mock_prepare_finalize,
|
||||||
mock_get_forward_context,
|
mock_get_forward_context):
|
||||||
mock_get_current_vllm_config):
|
|
||||||
# Mock vLLM config
|
|
||||||
mock_get_current_vllm_config.return_value = MagicMock()
|
|
||||||
|
|
||||||
# Mock forward context
|
# Mock forward context
|
||||||
mock_context = MagicMock()
|
mock_context = MagicMock()
|
||||||
mock_context.moe_comm_method = "alltoall"
|
mock_context.moe_comm_method = "alltoall"
|
||||||
@@ -168,7 +153,6 @@ class TestMoECommMethod(TestBase):
|
|||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
hidden_states, router_logits, False, False, QuantType.NONE)
|
hidden_states, router_logits, False, False, QuantType.NONE)
|
||||||
|
|
||||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
|
||||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
|
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
|
||||||
@@ -179,11 +163,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
|
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
|
||||||
def test_fused_experts_method(self, mock_unified_apply_mlp,
|
def test_fused_experts_method(self, mock_unified_apply_mlp,
|
||||||
mock_token_dispatcher, mock_prepare_finalize,
|
mock_token_dispatcher, mock_prepare_finalize,
|
||||||
mock_get_forward_context,
|
mock_get_forward_context):
|
||||||
mock_get_current_vllm_config):
|
|
||||||
# Mock vLLM config
|
|
||||||
mock_get_current_vllm_config.return_value = MagicMock()
|
|
||||||
|
|
||||||
# Mock forward context
|
# Mock forward context
|
||||||
mock_context = MagicMock()
|
mock_context = MagicMock()
|
||||||
mock_context.moe_comm_method = "all_gather"
|
mock_context.moe_comm_method = "all_gather"
|
||||||
|
|||||||
@@ -240,7 +240,6 @@ def select_moe_comm_method(num_tokens: int,
|
|||||||
quant_type = getattr(
|
quant_type = getattr(
|
||||||
vllm_config.model_config.hf_config, 'moe_quantize',
|
vllm_config.model_config.hf_config, 'moe_quantize',
|
||||||
getattr(vllm_config.model_config.hf_config, 'quantize', None))
|
getattr(vllm_config.model_config.hf_config, 'quantize', None))
|
||||||
model_type = vllm_config.model_config.hf_config.model_type
|
|
||||||
|
|
||||||
if not vllm_config.parallel_config.enable_expert_parallel:
|
if not vllm_config.parallel_config.enable_expert_parallel:
|
||||||
moe_comm_type = MoECommType.ALLGATHER
|
moe_comm_type = MoECommType.ALLGATHER
|
||||||
@@ -267,7 +266,4 @@ def select_moe_comm_method(num_tokens: int,
|
|||||||
if fused_all2all_enable else MoECommType.ALLTOALL)
|
if fused_all2all_enable else MoECommType.ALLTOALL)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||||
# PanguProMoE only supports allgather
|
|
||||||
if model_type == "PanguProMoE":
|
|
||||||
moe_comm_type = MoECommType.ALLGATHER
|
|
||||||
return moe_comm_type
|
return moe_comm_type
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.config import get_current_vllm_config
|
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||||
|
|
||||||
@@ -30,7 +29,7 @@ from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
|||||||
PrepareAndFinalizeWithMC2, QuantType)
|
PrepareAndFinalizeWithMC2, QuantType)
|
||||||
from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
||||||
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
|
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
|
||||||
TokenDispatcherWithMC2, TokenDispatcherWithMoge)
|
TokenDispatcherWithMC2)
|
||||||
|
|
||||||
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
|
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
|
||||||
|
|
||||||
@@ -52,8 +51,6 @@ class MoECommMethod(ABC):
|
|||||||
"""Base class for MoE communication methods."""
|
"""Base class for MoE communication methods."""
|
||||||
|
|
||||||
def __init__(self, moe_config: FusedMoEConfig):
|
def __init__(self, moe_config: FusedMoEConfig):
|
||||||
self.model_type = get_current_vllm_config(
|
|
||||||
).model_config.hf_config.model_type
|
|
||||||
self.moe_config = moe_config
|
self.moe_config = moe_config
|
||||||
|
|
||||||
self.token_dispatcher = self._get_token_dispatcher()
|
self.token_dispatcher = self._get_token_dispatcher()
|
||||||
@@ -198,12 +195,6 @@ class AllGatherCommImpl(MoECommMethod):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_token_dispatcher(self):
|
def _get_token_dispatcher(self):
|
||||||
if self.model_type == "PanguProMoE":
|
|
||||||
return TokenDispatcherWithMoge(
|
|
||||||
top_k=self.moe_config.experts_per_token,
|
|
||||||
num_experts=self.moe_config.num_experts,
|
|
||||||
num_local_experts=self.moe_config.num_local_experts)
|
|
||||||
else:
|
|
||||||
return TokenDispatcherWithAllGather(
|
return TokenDispatcherWithAllGather(
|
||||||
top_k=self.moe_config.experts_per_token,
|
top_k=self.moe_config.experts_per_token,
|
||||||
num_experts=self.moe_config.num_experts,
|
num_experts=self.moe_config.num_experts,
|
||||||
|
|||||||
@@ -422,69 +422,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
|||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
# mypy: disable-error-code="override"
|
|
||||||
class TokenDispatcherWithMoge(MoETokenDispatcher):
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.apply_router_weight_on_input = False
|
|
||||||
self.local_num_experts = self.num_experts // self.ep_size
|
|
||||||
self.local_num_group = self.top_k // self.ep_size
|
|
||||||
self.bsz = None
|
|
||||||
|
|
||||||
def token_dispatch(self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
|
||||||
log2phy: Optional[torch.Tensor] = None,
|
|
||||||
global_redundant_expert_num: int = 0,
|
|
||||||
shared_experts: Optional[Any] = None,
|
|
||||||
quantized_x_for_share: Optional[Any] = None,
|
|
||||||
dynamic_scale_for_share: Optional[Any] = None,
|
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
with_quant: bool = False,
|
|
||||||
dynamic_eplb: bool = False,
|
|
||||||
pertoken_scale: Optional[torch.Tensor] = None):
|
|
||||||
self.bsz, _ = hidden_states.shape
|
|
||||||
flatten_topk_ids = topk_ids.view(-1)
|
|
||||||
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
|
||||||
self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
|
|
||||||
sorted_hidden_states = hidden_states.index_select(
|
|
||||||
0, self.sorted_topk_ids // self.local_num_group)
|
|
||||||
|
|
||||||
experts_id = torch.arange(0,
|
|
||||||
self.local_num_experts,
|
|
||||||
dtype=topk_ids.dtype,
|
|
||||||
device=topk_ids.device)
|
|
||||||
num_tokens_per_expert = (
|
|
||||||
flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
|
||||||
torch.float32).sum(0)
|
|
||||||
topk_scales = topk_weights.view(-1).index_select(
|
|
||||||
0, self.sorted_topk_ids).unsqueeze(-1)
|
|
||||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
|
||||||
group_list_type = 0
|
|
||||||
return {
|
|
||||||
"group_list_type": group_list_type,
|
|
||||||
"hidden_states": sorted_hidden_states,
|
|
||||||
"group_list": group_list,
|
|
||||||
"topk_scales": topk_scales
|
|
||||||
}
|
|
||||||
|
|
||||||
def token_combine(self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
context_metadata: dict,
|
|
||||||
bias: torch.Tensor = None):
|
|
||||||
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
|
|
||||||
torch.int32)
|
|
||||||
unsorted_hidden_states = hidden_states.index_select(
|
|
||||||
0, unsorted_topk_ids)
|
|
||||||
final_hidden_states = unsorted_hidden_states.reshape(
|
|
||||||
self.bsz, self.top_k // self.ep_size, -1).sum(1)
|
|
||||||
return final_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||||
"""
|
"""
|
||||||
The implementation of the AlltoAll-based token dispatcher, which handles token
|
The implementation of the AlltoAll-based token dispatcher, which handles token
|
||||||
|
|||||||
Reference in New Issue
Block a user