### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- The main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
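For context, a minimal sketch of the request-object call shape on the MLP stage, distilled from the new UTs below. The class and field names match the new tests; the tensor shapes are illustrative only, and no claim is made about which fields have defaults:

```python
import torch

from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    MoEMlpComputeInput,
    MoEQuantParams,
    MoEWeights,
)
from vllm_ascend.quantization.quant_type import QuantType

# Every field the MLP stage consumes is named on one typed request object,
# instead of being fished out of a business **kwargs downstream.
mlp_compute_input = MoEMlpComputeInput(
    hidden_states=torch.randn(2, 8),  # illustrative shapes throughout
    group_list=torch.tensor([2, 2], dtype=torch.int64),
    group_list_type=1,
    dynamic_scale=None,
    topk_scales=None,
    weights=MoEWeights(
        w1=torch.randn(1, 16, 8),
        w2=torch.randn(1, 8, 8),
        w1_bias=torch.randn(1, 16),
        w2_bias=torch.randn(1, 8),
    ),
    quant=MoEQuantParams(quant_type=QuantType.NONE),
    fusion=False,
    activation="silu",
    need_trans=False,
    dynamic_eplb=False,
)
output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
```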
import unittest
from typing import ClassVar
from unittest.mock import patch

import torch

from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    MoEMlpComputeInput,
    MoEQuantParams,
    MoEWeights,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
from vllm_ascend.quantization.quant_type import QuantType
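

# `group_list` describes how tokens are distributed across experts. Three
# encodings are exercised below (the numbering follows the group_list_type
# argument; the meanings are inferred from these fixtures):
#   0: cumulative token counts per expert, e.g. [0, 2, 3, 3]
#   1: per-expert token counts, e.g. [0, 2, 1, 0]
#   2: (expert_id, token_count) pairs, e.g. [[1, 2], [2, 1], [0, 0], [0, 0]]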
class TestCumsumGroupList(unittest.TestCase):
    glist_dict: ClassVar[dict[int, torch.Tensor]]

    @classmethod
    def setUpClass(cls):
        cls.glist_dict = {
            0: torch.tensor([0, 2, 3, 3]),
            1: torch.tensor([0, 2, 1, 0]),
            2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]]),
        }

    support_combine = [(0, 0), (1, 0), (0, 1)]
    unsupported_combine = [(0, 2), (2, 1), (1, 2)]

    def test_cumsum_group_list_supported_conversion(self):
        for src_list_type, dst_list_type in self.support_combine:
            with self.subTest(src=src_list_type, dst=dst_list_type):
                result = cumsum_group_list(
                    self.glist_dict[src_list_type], src_list_type, dst_list_type, expert_num=4
                )
                self.assertTrue(torch.equal(result, self.glist_dict[dst_list_type]))

    def test_cumsum_group_list_invalid_type_valueerror(self):
        with self.assertRaises(ValueError) as excinfo:
            cumsum_group_list(self.glist_dict[0], 4, 0)
        self.assertIn("group_list_type should be in [0, 1, 2], but received", str(excinfo.exception))

    def test_cumsum_group_list_unsupported_conversion_notimplementederror(self):
        for src_list_type, dst_list_type in self.unsupported_combine:
            with self.subTest(src=src_list_type, dst=dst_list_type):
                with self.assertRaises(NotImplementedError) as excinfo:
                    cumsum_group_list(self.glist_dict[0], src_list_type, dst_list_type)
                self.assertIn("This feature is under development.", str(excinfo.exception))
class TestUnifiedApplyMlpRequest(unittest.TestCase):
    def test_request_unquant_path(self):
        hidden_states = torch.randn(2, 8)
        expected = torch.randn(2, 8)
        mlp_compute_input = MoEMlpComputeInput(
            hidden_states=hidden_states,
            group_list=torch.tensor([2, 2], dtype=torch.int64),
            group_list_type=1,
            dynamic_scale=None,
            topk_scales=None,
            weights=MoEWeights(
                w1=torch.randn(1, 16, 8),
                w2=torch.randn(1, 8, 8),
                w1_bias=torch.randn(1, 16),
                w2_bias=torch.randn(1, 8),
            ),
            quant=MoEQuantParams(quant_type=QuantType.NONE),
            fusion=False,
            activation="silu",
            need_trans=False,
            dynamic_eplb=False,
        )

        with (
            patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp", return_value=expected) as mock_unquant,
            patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp") as mock_quant,
        ):
            output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)

        self.assertIs(output, expected)
        mock_unquant.assert_called_once()
        self.assertEqual(mock_unquant.call_args.kwargs["activation"], "silu")
        self.assertFalse(mock_unquant.call_args.kwargs["need_trans"])
        mock_quant.assert_not_called()
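
    # The quantized path mirrors the test above: QuantType.MXFP8 with
    # MoEMxfpParams should route to quant_apply_mlp and forward the
    # use_mxfp_quant / fusion / dynamic_eplb / use_bf16 flags.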
    def test_request_quant_path(self):
        hidden_states = torch.randn(2, 8)
        expected = torch.randn(2, 8)
        mlp_compute_input = MoEMlpComputeInput(
            hidden_states=hidden_states,
            group_list=torch.tensor([2, 2], dtype=torch.int64),
            group_list_type=1,
            dynamic_scale=torch.randn(2, 1),
            topk_scales=None,
            weights=MoEWeights(
                w1=torch.randn(1, 16, 8),
                w2=torch.randn(1, 8, 8),
                w1_scale=[torch.randn(1)],
                w2_scale=[torch.randn(1)],
            ),
            quant=MoEQuantParams(
                quant_type=QuantType.MXFP8,
                mxfp=MoEMxfpParams(
                    act_quant_type=torch.float8_e4m3fn,
                    weight_quant_type=torch.float8_e4m3fn,
                    use_bf16=False,
                ),
            ),
            fusion=True,
            activation="silu",
            need_trans=False,
            dynamic_eplb=True,
        )

        with (
            patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp", return_value=expected) as mock_quant,
            patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp") as mock_unquant,
        ):
            output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)

        self.assertIs(output, expected)
        mock_quant.assert_called_once()
        quant_kwargs = mock_quant.call_args.kwargs
        self.assertTrue(quant_kwargs["use_mxfp_quant"])
        self.assertTrue(quant_kwargs["fusion"])
        self.assertTrue(quant_kwargs["dynamic_eplb"])
        self.assertFalse(quant_kwargs["use_bf16"])
        mock_unquant.assert_not_called()


if __name__ == "__main__":
    unittest.main(verbosity=2)