[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it? Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries. - Prepare, dispatch, MLP, and quant stages now have clearer ownership. - The main MoE path no longer depends on business `kwargs.get(...)` lookups. - Comm and dispatcher interfaces are request-only on the main path. - Unit tests (UTs) can assert stage-level fields directly instead of inferring behavior indirectly. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed. --------- Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -19,6 +19,38 @@ import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend._310p.fused_moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||
MoEMlpComputeInput,
|
||||
MoEQuantParams,
|
||||
MoEWeights,
|
||||
)
|
||||
from vllm_ascend.quantization.quant_type import QuantType
|
||||
|
||||
|
||||
def build_mlp_compute_input_fixture(
    *,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    group_list: torch.Tensor,
    with_quant: bool,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    group_list_type: int = 1,
) -> MoEMlpComputeInput:
    """Assemble a ``MoEMlpComputeInput`` for ``unified_apply_mlp`` tests.

    Bundles the gate/up and down projection weights (and their optional
    quantization scales) into a ``MoEWeights``, selects W8A8 vs. no-quant
    based on ``with_quant``, and fills the remaining runtime fields with
    the fixed defaults the 310P tests exercise (silu activation, no
    fusion, no transpose, dynamic EPLB off).

    Args:
        hidden_states: Input activations routed to the expert MLP.
        w1: Gate/up projection weight.
        w2: Down projection weight.
        group_list: Per-expert token grouping tensor.
        with_quant: If True, mark the request as W8A8-quantized.
        w1_scale: Dequant scale for ``w1`` (quantized path only).
        w2_scale: Dequant scale for ``w2`` (quantized path only).
        group_list_type: Grouping-list layout selector; defaults to 1.

    Returns:
        A fully populated ``MoEMlpComputeInput`` request object.
    """
    weight_bundle = MoEWeights(
        w1=w1,
        w2=w2,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )
    # Quant type is the only field driven by the with_quant switch.
    selected_quant = QuantType.W8A8 if with_quant else QuantType.NONE

    return MoEMlpComputeInput(
        hidden_states=hidden_states,
        group_list=group_list,
        group_list_type=group_list_type,
        dynamic_scale=None,
        topk_scales=None,
        weights=weight_bundle,
        quant=MoEQuantParams(quant_type=selected_quant),
        fusion=False,
        activation="silu",
        need_trans=False,
        dynamic_eplb=False,
    )
|
||||
|
||||
|
||||
class TestUnifiedApplyMLP310(TestBase):
|
||||
@@ -38,14 +70,13 @@ class TestUnifiedApplyMLP310(TestBase):
|
||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||
|
||||
result = unified_apply_mlp(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=None,
|
||||
w2=w2,
|
||||
w2_scale=None,
|
||||
group_list=group_list,
|
||||
group_list_type=1,
|
||||
with_quant=False,
|
||||
mlp_compute_input=build_mlp_compute_input_fixture(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
with_quant=False,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||
@@ -94,14 +125,15 @@ class TestUnifiedApplyMLP310(TestBase):
|
||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||
|
||||
result = unified_apply_mlp(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
group_list_type=1,
|
||||
with_quant=True,
|
||||
mlp_compute_input=build_mlp_compute_input_fixture(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
with_quant=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
)
|
||||
|
||||
mock_cumsum.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user