[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)

### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring
behavior from side effects (see the call-shape sketch below).
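
For illustration, here is the call-shape change at the MLP stage, as exercised by the UT diff below. This is a minimal sketch, not the PR's exact code: `h`, `w1`, `w2`, and `gl` are placeholder tensors, and the field set mirrors the test fixture added in this commit.

```python
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    MoEMlpComputeInput, MoEQuantParams, MoEWeights)
from vllm_ascend.quantization.quant_type import QuantType

# Before: one flat call, behavior threaded through scattered business kwargs.
result = unified_apply_mlp(hidden_states=h, w1=w1, w1_scale=None,
                           w2=w2, w2_scale=None, group_list=gl,
                           dynamic_scale=None, group_list_type=1,
                           w1_scale_bias=None, w2_scale_bias=None,
                           topk_scales=None, with_quant=False)

# After: a single typed request object with explicit stage ownership.
result = unified_apply_mlp(mlp_compute_input=MoEMlpComputeInput(
    hidden_states=h,
    group_list=gl,
    group_list_type=1,
    dynamic_scale=None,
    topk_scales=None,
    weights=MoEWeights(w1=w1, w2=w2, w1_scale=None, w2_scale=None,
                       w1_scale_bias=None, w2_scale_bias=None,
                       w1_offset=None, w2_offset=None),
    quant=MoEQuantParams(quant_type=QuantType.NONE),
    fusion=False,
    activation="silu",
    need_trans=True,
    dynamic_eplb=False,
))
```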

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Author: linfeng-yuan (committed via GitHub)
Date: 2026-03-20 23:23:57 +08:00
Commit: 88d03a783f (parent: c860535246)
33 changed files with 2146 additions and 947 deletions


@@ -12,7 +12,7 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from typing import List, TypedDict
+from typing import TypedDict
 from unittest.mock import MagicMock, patch
 import pytest
@@ -20,12 +20,19 @@ import torch
 import torch.nn as nn
 import torch_npu
 from pytest_mock import MockerFixture
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
-from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
-                                               unified_apply_mlp)
+from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    MoEMlpComputeInput,
+    MoEPrepareOutput,
+    MoEQuantParams,
+    MoEWeights,
+)
+from vllm_ascend.quantization.quant_type import QuantType
 from vllm_ascend.utils import AscendDeviceType, adapt_patch
 adapt_patch(True)
@@ -54,6 +61,51 @@ def mock_npu_format_cast(weight_data, format):
     return weight_data
+def build_mlp_compute_input_fixture(
+    *,
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor | list[torch.Tensor],
+    w2: torch.Tensor | list[torch.Tensor],
+    group_list: torch.Tensor,
+    with_quant: bool,
+    group_list_type: int = 1,
+    dynamic_scale: torch.Tensor | None = None,
+    topk_scales: torch.Tensor | None = None,
+    w1_scale: torch.Tensor | list[torch.Tensor] | None = None,
+    w2_scale: torch.Tensor | list[torch.Tensor] | None = None,
+    w1_scale_bias: torch.Tensor | None = None,
+    w2_scale_bias: torch.Tensor | None = None,
+    w1_offset: torch.Tensor | None = None,
+    w2_offset: torch.Tensor | None = None,
+    fusion: bool = False,
+    activation: str = "silu",
+    need_trans: bool = True,
+    dynamic_eplb: bool = False,
+) -> MoEMlpComputeInput:
+    return MoEMlpComputeInput(
+        hidden_states=hidden_states,
+        group_list=group_list,
+        group_list_type=group_list_type,
+        dynamic_scale=dynamic_scale,
+        topk_scales=topk_scales,
+        weights=MoEWeights(
+            w1=w1,
+            w2=w2,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            w1_offset=w1_offset,
+            w2_offset=w2_offset,
+        ),
+        quant=MoEQuantParams(quant_type=QuantType.W8A8 if with_quant else QuantType.NONE),
+        fusion=fusion,
+        activation=activation,
+        need_trans=need_trans,
+        dynamic_eplb=dynamic_eplb,
+    )
 @pytest.fixture(autouse=True)
 def setup_vllm_config_mock(mocker: MockerFixture):
     mock_hf_config = MagicMock()
@@ -77,7 +129,13 @@ def mock_dist_env(mocker: MockerFixture):
     mock_moe_comm_method = MagicMock()
     def mock_prepare(hidden_states, router_logits, **kwargs):
-        return hidden_states, router_logits
+        return MoEPrepareOutput(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            mc2_mask=kwargs.get("mc2_mask"),
+            padded_hidden_states_shape=None,
+            pertoken_scale=None,
+        )
     mock_moe_comm_method.prepare.side_effect = mock_prepare
@@ -204,18 +262,18 @@ def moe_method(mock_dist_env):
 class Device(TypedDict):
     device_id: int
-    device_expert: List[int]
+    device_expert: list[int]
 class Layer(TypedDict):
     layer_id: int
     device_count: int
-    device_list: List[Device]
+    device_list: list[Device]
 class MockData(TypedDict):
     moe_layer_count: int
-    layer_list: List[Layer]
+    layer_list: list[Layer]
 class MockQuantMethod(nn.Module):
@@ -338,18 +396,15 @@ class TestUnifiedApplyMLP(TestBase):
         w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=w1_scale,
-                                   w2=w2,
-                                   w2_scale=w2_scale,
-                                   group_list=group_list,
-                                   dynamic_scale=None,
-                                   group_list_type=1,
-                                   w1_scale_bias=None,
-                                   w2_scale_bias=None,
-                                   topk_scales=None,
-                                   with_quant=True)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=True,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+        ))
         mock_get_forward_context.assert_called()
@@ -383,18 +438,14 @@
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         topk_scales = torch.randn(10, 1, dtype=torch.float16)
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=None,
-                                   w2=w2,
-                                   w2_scale=None,
-                                   group_list=group_list,
-                                   dynamic_scale=None,
-                                   group_list_type=1,
-                                   w1_scale_bias=None,
-                                   w2_scale_bias=None,
-                                   topk_scales=topk_scales,
-                                   with_quant=False)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=False,
+            topk_scales=topk_scales,
+        ))
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
@@ -445,18 +496,18 @@
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=w1_scale,
-                                   w2=w2,
-                                   w2_scale=w2_scale,
-                                   group_list=group_list,
-                                   dynamic_scale=provided_dynamic_scale,
-                                   group_list_type=1,
-                                   w1_scale_bias=w1_scale_bias,
-                                   w2_scale_bias=w2_scale_bias,
-                                   topk_scales=None,
-                                   with_quant=True)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=True,
+            dynamic_scale=provided_dynamic_scale,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+        ))
         mock_get_forward_context.assert_called()
@@ -490,18 +541,14 @@
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         topk_scales = torch.randn(10, 1, dtype=torch.float16)
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=None,
-                                   w2=w2,
-                                   w2_scale=None,
-                                   group_list=group_list,
-                                   dynamic_scale=None,
-                                   group_list_type=1,
-                                   w1_scale_bias=None,
-                                   w2_scale_bias=None,
-                                   topk_scales=topk_scales,
-                                   with_quant=False)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=False,
+            topk_scales=topk_scales,
+        ))
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
@@ -556,19 +603,19 @@
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=w1_scale,
-                                   w2=w2,
-                                   w2_scale=w2_scale,
-                                   group_list=group_list,
-                                   dynamic_scale=provided_dynamic_scale,
-                                   group_list_type=1,
-                                   w1_scale_bias=w1_scale_bias,
-                                   w2_scale_bias=w2_scale_bias,
-                                   topk_scales=None,
-                                   with_quant=True,
-                                   fusion=True)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=True,
+            dynamic_scale=provided_dynamic_scale,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            fusion=True,
+        ))
         mock_get_forward_context.assert_called()
         mock_npu_grouped_matmul.assert_called_once()
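
For reference, a minimal sketch of the stage-level assertion style the new request objects enable in UTs, using the `build_mlp_compute_input_fixture` helper added in this diff. It assumes the request objects expose their fields as plain attributes (as the mock `prepare` above suggests); tensor shapes are arbitrary.

```python
import torch

from vllm_ascend.quantization.quant_type import QuantType

# Build a request object through the new test helper (defined above).
hidden_states = torch.randn(10, 128, dtype=torch.bfloat16)
w1 = torch.randn(5, 256, 128, dtype=torch.bfloat16)
w2 = torch.randn(5, 128, 128, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)

inp = build_mlp_compute_input_fixture(
    hidden_states=hidden_states,
    w1=w1,
    w2=w2,
    group_list=group_list,
    with_quant=False,
)

# Stage-level fields are plain attributes, so a UT asserts them directly
# instead of digging values out of mock call kwargs.
assert inp.quant.quant_type == QuantType.NONE
assert inp.weights.w1 is w1 and inp.weights.w1_scale is None
assert inp.group_list_type == 1  # fixture default
assert inp.activation == "silu"  # fixture default
```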