xc-llm-ascend/tests/ut/ops/test_moe_mlp.py
linfeng-yuan 88d03a783f [refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring
behavior indirectly.
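
A minimal sketch of the call-shape change on the unquantized path, built only from the request types exercised by the UTs below; tensor shapes and field values are illustrative, not a full signature:

```python
import torch

from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    MoEMlpComputeInput,
    MoEQuantParams,
    MoEWeights,
)
from vllm_ascend.quantization.quant_type import QuantType

# Before: business parameters traveled as scattered kwargs and were read
# back with kwargs.get(...) deep inside the MoE path.
# After: each stage receives one typed request object.
mlp_compute_input = MoEMlpComputeInput(
    hidden_states=torch.randn(2, 8),
    group_list=torch.tensor([2, 2], dtype=torch.int64),
    group_list_type=1,
    dynamic_scale=None,
    topk_scales=None,
    weights=MoEWeights(
        w1=torch.randn(1, 16, 8),
        w2=torch.randn(1, 8, 8),
        w1_bias=torch.randn(1, 16),
        w2_bias=torch.randn(1, 8),
    ),
    quant=MoEQuantParams(quant_type=QuantType.NONE),
    fusion=False,
    activation="silu",
    need_trans=False,
    dynamic_eplb=False,
)
output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
```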

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
2026-03-20 23:23:57 +08:00


import unittest
from typing import ClassVar
from unittest.mock import patch

import torch

from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    MoEMlpComputeInput,
    MoEQuantParams,
    MoEWeights,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
from vllm_ascend.quantization.quant_type import QuantType


class TestCumsumGroupList(unittest.TestCase):
    glist_dict: ClassVar[dict[int, torch.Tensor]]

    @classmethod
    def setUpClass(cls):
        # One per-expert token distribution expressed in every group_list
        # layout: type 0 holds cumulative token counts, type 1 per-expert
        # counts, and type 2 zero-padded (expert id, token count) pairs.
        cls.glist_dict = {
            0: torch.tensor([0, 2, 3, 3]),
            1: torch.tensor([0, 2, 1, 0]),
            2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]]),
        }

    # (src, dst) group_list_type pairs that cumsum_group_list supports,
    # and the conversions that are not implemented yet.
    support_combine = [(0, 0), (1, 0), (0, 1)]
    unsupported_combine = [(0, 2), (2, 1), (1, 2)]

    def test_cumsum_group_list_supported_conversion(self):
        for src_list_type, dst_list_type in self.support_combine:
            with self.subTest(src=src_list_type, dst=dst_list_type):
                result = cumsum_group_list(self.glist_dict[src_list_type],
                                           src_list_type,
                                           dst_list_type,
                                           expert_num=4)
                # Each supported conversion must reproduce the canonical
                # tensor for the destination layout.
                self.assertTrue(
                    torch.equal(result, self.glist_dict[dst_list_type]))

    def test_cumsum_group_list_invalid_type_valueerror(self):
        # 4 is outside the valid group_list_type range [0, 2].
        with self.assertRaises(ValueError) as excinfo:
            cumsum_group_list(self.glist_dict[0], 4, 0)
        self.assertIn("group_list_type should be in [0, 1, 2], but received",
                      str(excinfo.exception))

    def test_cumsum_group_list_unsupported_conversion_notimplementederror(self):
        for src_list_type, dst_list_type in self.unsupported_combine:
            with self.subTest(src=src_list_type, dst=dst_list_type):
                with self.assertRaises(NotImplementedError) as excinfo:
                    cumsum_group_list(self.glist_dict[0], src_list_type,
                                      dst_list_type)
                self.assertIn("This feature is under development.",
                              str(excinfo.exception))


class TestUnifiedApplyMlpRequest(unittest.TestCase):

    def test_request_unquant_path(self):
        hidden_states = torch.randn(2, 8)
        expected = torch.randn(2, 8)
        # QuantType.NONE must route the request to the unquantized MLP path.
        mlp_compute_input = MoEMlpComputeInput(
            hidden_states=hidden_states,
            group_list=torch.tensor([2, 2], dtype=torch.int64),
            group_list_type=1,
            dynamic_scale=None,
            topk_scales=None,
            weights=MoEWeights(
                w1=torch.randn(1, 16, 8),
                w2=torch.randn(1, 8, 8),
                w1_bias=torch.randn(1, 16),
                w2_bias=torch.randn(1, 8),
            ),
            quant=MoEQuantParams(quant_type=QuantType.NONE),
            fusion=False,
            activation="silu",
            need_trans=False,
            dynamic_eplb=False,
        )
        with (
            patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp",
                  return_value=expected) as mock_unquant,
            patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp") as mock_quant,
        ):
            output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
        self.assertIs(output, expected)
        mock_unquant.assert_called_once()
        # Stage-level fields arrive as explicit kwargs, so the UT asserts
        # them directly instead of inferring behavior from side effects.
        self.assertEqual(mock_unquant.call_args.kwargs["activation"], "silu")
        self.assertFalse(mock_unquant.call_args.kwargs["need_trans"])
        mock_quant.assert_not_called()

    def test_request_quant_path(self):
        hidden_states = torch.randn(2, 8)
        expected = torch.randn(2, 8)
        # An MXFP8-quantized request must route to the quantized MLP path
        # and carry its mxfp parameters through to the kernel call.
        mlp_compute_input = MoEMlpComputeInput(
            hidden_states=hidden_states,
            group_list=torch.tensor([2, 2], dtype=torch.int64),
            group_list_type=1,
            dynamic_scale=torch.randn(2, 1),
            topk_scales=None,
            weights=MoEWeights(
                w1=torch.randn(1, 16, 8),
                w2=torch.randn(1, 8, 8),
                w1_scale=[torch.randn(1)],
                w2_scale=[torch.randn(1)],
            ),
            quant=MoEQuantParams(
                quant_type=QuantType.MXFP8,
                mxfp=MoEMxfpParams(
                    act_quant_type=torch.float8_e4m3fn,
                    weight_quant_type=torch.float8_e4m3fn,
                    use_bf16=False,
                ),
            ),
            fusion=True,
            activation="silu",
            need_trans=False,
            dynamic_eplb=True,
        )
        with (
            patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp",
                  return_value=expected) as mock_quant,
            patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp") as mock_unquant,
        ):
            output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
        self.assertIs(output, expected)
        mock_quant.assert_called_once()
        quant_kwargs = mock_quant.call_args.kwargs
        self.assertTrue(quant_kwargs["use_mxfp_quant"])
        self.assertTrue(quant_kwargs["fusion"])
        self.assertTrue(quant_kwargs["dynamic_eplb"])
        self.assertFalse(quant_kwargs["use_bf16"])
        mock_unquant.assert_not_called()


if __name__ == "__main__":
    unittest.main(verbosity=2)