[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.
- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
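For context before the diff, here is a minimal sketch of the pattern being applied. The class names mirror the new `moe_runtime_args` module in the diff below, but the dataclass layout and field sets here are abbreviated assumptions for illustration, not the exact upstream definitions:

```python
# Sketch only: names follow moe_runtime_args in the diff below, but the
# dataclass layout and field sets are abbreviated assumptions.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MoEWeights:
    # Weight tensors plus optional quantization scales, owned in one place.
    w1: torch.Tensor
    w2: torch.Tensor
    w1_scale: Optional[torch.Tensor] = None
    w2_scale: Optional[torch.Tensor] = None


@dataclass
class MoEMlpComputeInput:
    # One typed request object for the MLP stage instead of loose kwargs.
    hidden_states: torch.Tensor
    group_list: torch.Tensor
    weights: MoEWeights
    group_list_type: int = 1
    topk_scales: Optional[torch.Tensor] = None


# Before: callers threaded a dozen loosely typed kwargs through the stack.
def apply_mlp_before(hidden_states, w1, w2, group_list, **kwargs):
    w1_scale = kwargs.get("w1_scale")  # implicit contract, easy to drift
    ...


# After: the stage consumes a single typed request, so ownership of each
# field is explicit and tests can assert on it directly.
def apply_mlp_after(mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
    ...
```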
```diff
@@ -12,7 +12,7 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from typing import List, TypedDict
+from typing import TypedDict
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -20,12 +20,19 @@ import torch
 import torch.nn as nn
 import torch_npu
 from pytest_mock import MockerFixture
 
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
-from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
-                                               unified_apply_mlp)
+from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    MoEMlpComputeInput,
+    MoEPrepareOutput,
+    MoEQuantParams,
+    MoEWeights,
+)
+from vllm_ascend.quantization.quant_type import QuantType
 from vllm_ascend.utils import AscendDeviceType, adapt_patch
 
+adapt_patch(True)
@@ -54,6 +61,51 @@ def mock_npu_format_cast(weight_data, format):
     return weight_data
 
 
+def build_mlp_compute_input_fixture(
+    *,
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor | list[torch.Tensor],
+    w2: torch.Tensor | list[torch.Tensor],
+    group_list: torch.Tensor,
+    with_quant: bool,
+    group_list_type: int = 1,
+    dynamic_scale: torch.Tensor | None = None,
+    topk_scales: torch.Tensor | None = None,
+    w1_scale: torch.Tensor | list[torch.Tensor] | None = None,
+    w2_scale: torch.Tensor | list[torch.Tensor] | None = None,
+    w1_scale_bias: torch.Tensor | None = None,
+    w2_scale_bias: torch.Tensor | None = None,
+    w1_offset: torch.Tensor | None = None,
+    w2_offset: torch.Tensor | None = None,
+    fusion: bool = False,
+    activation: str = "silu",
+    need_trans: bool = True,
+    dynamic_eplb: bool = False,
+) -> MoEMlpComputeInput:
+    return MoEMlpComputeInput(
+        hidden_states=hidden_states,
+        group_list=group_list,
+        group_list_type=group_list_type,
+        dynamic_scale=dynamic_scale,
+        topk_scales=topk_scales,
+        weights=MoEWeights(
+            w1=w1,
+            w2=w2,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            w1_offset=w1_offset,
+            w2_offset=w2_offset,
+        ),
+        quant=MoEQuantParams(quant_type=QuantType.W8A8 if with_quant else QuantType.NONE),
+        fusion=fusion,
+        activation=activation,
+        need_trans=need_trans,
+        dynamic_eplb=dynamic_eplb,
+    )
+
+
 @pytest.fixture(autouse=True)
 def setup_vllm_config_mock(mocker: MockerFixture):
     mock_hf_config = MagicMock()
@@ -77,7 +129,13 @@ def mock_dist_env(mocker: MockerFixture):
     mock_moe_comm_method = MagicMock()
 
     def mock_prepare(hidden_states, router_logits, **kwargs):
-        return hidden_states, router_logits
+        return MoEPrepareOutput(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            mc2_mask=kwargs.get("mc2_mask"),
+            padded_hidden_states_shape=None,
+            pertoken_scale=None,
+        )
 
     mock_moe_comm_method.prepare.side_effect = mock_prepare
 
@@ -204,18 +262,18 @@ def moe_method(mock_dist_env):
 
 class Device(TypedDict):
     device_id: int
-    device_expert: List[int]
+    device_expert: list[int]
 
 
 class Layer(TypedDict):
     layer_id: int
     device_count: int
-    device_list: List[Device]
+    device_list: list[Device]
 
 
 class MockData(TypedDict):
     moe_layer_count: int
-    layer_list: List[Layer]
+    layer_list: list[Layer]
 
 
 class MockQuantMethod(nn.Module):
@@ -338,18 +396,15 @@ class TestUnifiedApplyMLP(TestBase):
         w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
 
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=w1_scale,
-                                   w2=w2,
-                                   w2_scale=w2_scale,
-                                   group_list=group_list,
-                                   dynamic_scale=None,
-                                   group_list_type=1,
-                                   w1_scale_bias=None,
-                                   w2_scale_bias=None,
-                                   topk_scales=None,
-                                   with_quant=True)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=True,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+        ))
 
         mock_get_forward_context.assert_called()
 
@@ -383,18 +438,14 @@ class TestUnifiedApplyMLP(TestBase):
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         topk_scales = torch.randn(10, 1, dtype=torch.float16)
 
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=None,
-                                   w2=w2,
-                                   w2_scale=None,
-                                   group_list=group_list,
-                                   dynamic_scale=None,
-                                   group_list_type=1,
-                                   w1_scale_bias=None,
-                                   w2_scale_bias=None,
-                                   topk_scales=topk_scales,
-                                   with_quant=False)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=False,
+            topk_scales=topk_scales,
+        ))
 
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
@@ -445,18 +496,18 @@ class TestUnifiedApplyMLP(TestBase):
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
 
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=w1_scale,
-                                   w2=w2,
-                                   w2_scale=w2_scale,
-                                   group_list=group_list,
-                                   dynamic_scale=provided_dynamic_scale,
-                                   group_list_type=1,
-                                   w1_scale_bias=w1_scale_bias,
-                                   w2_scale_bias=w2_scale_bias,
-                                   topk_scales=None,
-                                   with_quant=True)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=True,
+            dynamic_scale=provided_dynamic_scale,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+        ))
 
         mock_get_forward_context.assert_called()
 
@@ -490,18 +541,14 @@ class TestUnifiedApplyMLP(TestBase):
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         topk_scales = torch.randn(10, 1, dtype=torch.float16)
 
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=None,
-                                   w2=w2,
-                                   w2_scale=None,
-                                   group_list=group_list,
-                                   dynamic_scale=None,
-                                   group_list_type=1,
-                                   w1_scale_bias=None,
-                                   w2_scale_bias=None,
-                                   topk_scales=topk_scales,
-                                   with_quant=False)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=False,
+            topk_scales=topk_scales,
+        ))
 
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
@@ -556,19 +603,19 @@ class TestUnifiedApplyMLP(TestBase):
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
 
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=w1_scale,
-                                   w2=w2,
-                                   w2_scale=w2_scale,
-                                   group_list=group_list,
-                                   dynamic_scale=provided_dynamic_scale,
-                                   group_list_type=1,
-                                   w1_scale_bias=w1_scale_bias,
-                                   w2_scale_bias=w2_scale_bias,
-                                   topk_scales=None,
-                                   with_quant=True,
-                                   fusion=True)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=True,
+            dynamic_scale=provided_dynamic_scale,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            fusion=True,
+        ))
 
         mock_get_forward_context.assert_called()
         mock_npu_grouped_matmul.assert_called_once()
```
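On the claimed testability win ("UTs can assert stage-level fields directly"), a minimal sketch of what that looks like, assuming the request objects expose dataclass-style attribute access; the builder is the fixture added in the diff above, while the test name and assertions are illustrative, not part of this change:

```python
# Illustrative only: assumes MoEMlpComputeInput / MoEQuantParams expose
# dataclass-style attributes; build_mlp_compute_input_fixture is the
# helper added in the diff above.
import torch


def test_mlp_compute_input_carries_quant_fields():
    mlp_input = build_mlp_compute_input_fixture(
        hidden_states=torch.randn(10, 128, dtype=torch.bfloat16),
        w1=torch.randn(5, 128, 256, dtype=torch.bfloat16),
        w2=torch.randn(5, 256, 128, dtype=torch.bfloat16),
        group_list=torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64),
        with_quant=True,
    )
    # Stage-level fields are plain attributes, not kwargs to re-derive.
    assert mlp_input.quant.quant_type == QuantType.W8A8
    assert mlp_input.weights.w1_scale is None
    assert mlp_input.group_list_type == 1
```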