Files
xc-llm-ascend/tests/ut/ops/test_moe_mlp.py

132 lines
5.0 KiB
Python
Raw Permalink Normal View History

import unittest
from typing import ClassVar
from unittest.mock import patch
import torch
from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEMlpComputeInput,
MoEQuantParams,
MoEWeights,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
from vllm_ascend.quantization.quant_type import QuantType
class TestCumsumGroupList(unittest.TestCase):
glist_dict: ClassVar[dict[int, torch.Tensor]]
@classmethod
def setUpClass(cls):
cls.glist_dict = {
0: torch.tensor([0, 2, 3, 3]),
1: torch.tensor([0, 2, 1, 0]),
2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]]),
}
support_combine = [(0, 0), (1, 0), (0, 1)]
unsupported_combine = [(0, 2), (2, 1), (1, 2)]
def test_cumsum_group_list_supported_conversion(self):
for src_list_type, dst_list_type in self.support_combine:
with self.subTest(src=src_list_type, dst=dst_list_type):
result = cumsum_group_list(self.glist_dict[src_list_type], src_list_type, dst_list_type, expert_num=4)
self.assertTrue(torch.equal(result, self.glist_dict[dst_list_type]))
def test_cumsum_group_list_invalid_type_valueerror(self):
with self.assertRaises(ValueError) as excinfo:
cumsum_group_list(self.glist_dict[0], 4, 0)
self.assertIn("group_list_type should be in [0, 1, 2], but received", str(excinfo.exception))
def test_cumsum_group_list_unsupported_conversion_notimplementederror(self):
for src_list_type, dst_list_type in self.unsupported_combine:
with self.subTest(src=src_list_type, dst=dst_list_type):
with self.assertRaises(NotImplementedError) as excinfo:
cumsum_group_list(self.glist_dict[0], src_list_type, dst_list_type)
self.assertIn("This feature is under development.", str(excinfo.exception))
class TestUnifiedApplyMlpRequest(unittest.TestCase):
def test_request_unquant_path(self):
hidden_states = torch.randn(2, 8)
expected = torch.randn(2, 8)
mlp_compute_input = MoEMlpComputeInput(
hidden_states=hidden_states,
group_list=torch.tensor([2, 2], dtype=torch.int64),
group_list_type=1,
dynamic_scale=None,
topk_scales=None,
weights=MoEWeights(
w1=torch.randn(1, 16, 8),
w2=torch.randn(1, 8, 8),
w1_bias=torch.randn(1, 16),
w2_bias=torch.randn(1, 8),
),
quant=MoEQuantParams(quant_type=QuantType.NONE),
fusion=False,
activation="silu",
need_trans=False,
dynamic_eplb=False,
)
with (
patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp", return_value=expected) as mock_unquant,
patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp") as mock_quant,
):
output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
self.assertTrue(output is expected)
mock_unquant.assert_called_once()
self.assertEqual(mock_unquant.call_args.kwargs["activation"], "silu")
self.assertFalse(mock_unquant.call_args.kwargs["need_trans"])
mock_quant.assert_not_called()
def test_request_quant_path(self):
hidden_states = torch.randn(2, 8)
expected = torch.randn(2, 8)
mlp_compute_input = MoEMlpComputeInput(
hidden_states=hidden_states,
group_list=torch.tensor([2, 2], dtype=torch.int64),
group_list_type=1,
dynamic_scale=torch.randn(2, 1),
topk_scales=None,
weights=MoEWeights(
w1=torch.randn(1, 16, 8),
w2=torch.randn(1, 8, 8),
w1_scale=[torch.randn(1)],
w2_scale=[torch.randn(1)],
),
quant=MoEQuantParams(
quant_type=QuantType.MXFP8,
mxfp=MoEMxfpParams(
act_quant_type=torch.float8_e4m3fn,
weight_quant_type=torch.float8_e4m3fn,
use_bf16=False,
),
),
fusion=True,
activation="silu",
need_trans=False,
dynamic_eplb=True,
)
with (
patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp", return_value=expected) as mock_quant,
patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp") as mock_unquant,
):
output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
self.assertTrue(output is expected)
mock_quant.assert_called_once()
quant_kwargs = mock_quant.call_args.kwargs
self.assertTrue(quant_kwargs["use_mxfp_quant"])
self.assertTrue(quant_kwargs["fusion"])
self.assertTrue(quant_kwargs["dynamic_eplb"])
self.assertFalse(quant_kwargs["use_bf16"])
mock_unquant.assert_not_called()
if __name__ == "__main__":
unittest.main(verbosity=2)