Files
xc-llm-ascend/tests/ut/ops/test_moe_mlp.py
Clorist33 9045843c90 [UT]Ut for function cumsum_group_list in moe_mlp (ref #5025) (#5036)
### What this PR does / why we need it?
Add ut for the cumsum_group_list function, which is related to the
precision issues stemming from the moe_mlp.py .
The ralated PR is https://github.com/vllm-project/vllm-ascend/pull/5025

### Does this PR introduce _any_ user-facing change?
No

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: tanqingshan (A)  <50050625@china.huawei.com>
Co-authored-by: tanqingshan (A) <50050625@china.huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
2025-12-18 15:00:16 +08:00

52 lines
2.0 KiB
Python

import unittest
from typing import ClassVar
import torch
from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list
class TestCumsumGroupList(unittest.TestCase):
glist_dict: ClassVar[dict[int, torch.Tensor]]
@classmethod
def setUpClass(cls):
cls.glist_dict = {
0: torch.tensor([0, 2, 3, 3]),
1: torch.tensor([0, 2, 1, 0]),
2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]])
}
support_combine = [(0, 0), (1, 0), (0, 1)]
unsupport_combine = [(0, 2), (2, 1), (1, 2)]
def test_cumsum_group_list_supported_conversion(self):
for src_list_type, dst_list_type in self.support_combine:
with self.subTest(src=src_list_type, dst=dst_list_type):
result = cumsum_group_list(self.glist_dict[src_list_type],
src_list_type,
dst_list_type,
expert_num=4)
self.assertTrue(
torch.equal(result, self.glist_dict[dst_list_type]))
def test_cumsum_group_list_invalid_type_valueerror(self):
with self.assertRaises(ValueError) as excinfo:
cumsum_group_list(self.glist_dict[0], 4, 0)
self.assertIn("group_list_type should be in [0, 1, 2], but received",
str(excinfo.exception))
def test_cumsum_group_list_unsupported_conversion_notimplementederror(
self):
for src_list_type, dst_list_type in self.unsupport_combine:
with self.subTest(src=src_list_type, dst=dst_list_type):
with self.assertRaises(NotImplementedError) as excinfo:
cumsum_group_list(self.glist_dict[0], src_list_type,
dst_list_type)
self.assertIn("This feature is under development.",
str(excinfo.exception))
if __name__ == '__main__':
unittest.main(verbosity=2)