[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)

### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- Unit tests can assert stage-level fields directly instead of inferring
behavior indirectly.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-03-20 23:23:57 +08:00
committed by GitHub
parent c860535246
commit 88d03a783f
33 changed files with 2146 additions and 947 deletions

View File

@@ -19,6 +19,38 @@ import torch
from tests.ut.base import TestBase
from vllm_ascend._310p.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEMlpComputeInput,
MoEQuantParams,
MoEWeights,
)
from vllm_ascend.quantization.quant_type import QuantType
def build_mlp_compute_input_fixture(
    *,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    group_list: torch.Tensor,
    with_quant: bool,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    group_list_type: int = 1,
) -> MoEMlpComputeInput:
    """Assemble a ``MoEMlpComputeInput`` request object for MLP-stage tests.

    Bundles the expert weights (and optional per-weight scales) into a
    ``MoEWeights`` record, maps ``with_quant`` onto the W8A8/NONE quant
    type, and fixes the remaining stage fields to the defaults these 310P
    tests exercise (silu activation, no fusion, no transpose, no dynamic
    EPLB, no dynamic scale or top-k scales).
    """
    # with_quant is the only quant knob the tests toggle; map it to the enum.
    quant_type = QuantType.W8A8 if with_quant else QuantType.NONE
    weight_bundle = MoEWeights(
        w1=w1,
        w2=w2,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )
    return MoEMlpComputeInput(
        hidden_states=hidden_states,
        group_list=group_list,
        group_list_type=group_list_type,
        dynamic_scale=None,
        topk_scales=None,
        weights=weight_bundle,
        quant=MoEQuantParams(quant_type=quant_type),
        fusion=False,
        activation="silu",
        need_trans=False,
        dynamic_eplb=False,
    )
class TestUnifiedApplyMLP310(TestBase):
@@ -38,14 +70,13 @@ class TestUnifiedApplyMLP310(TestBase):
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
result = unified_apply_mlp(
hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
group_list_type=1,
with_quant=False,
mlp_compute_input=build_mlp_compute_input_fixture(
hidden_states=hidden_states,
w1=w1,
w2=w2,
group_list=group_list,
with_quant=False,
)
)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
@@ -94,14 +125,15 @@ class TestUnifiedApplyMLP310(TestBase):
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
result = unified_apply_mlp(
hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
group_list_type=1,
with_quant=True,
mlp_compute_input=build_mlp_compute_input_fixture(
hidden_states=hidden_states,
w1=w1,
w2=w2,
group_list=group_list,
with_quant=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
)
mock_cumsum.assert_called_once()