### What this PR does / why we need it?
Add unit tests for the `cumsum_group_list` function, which is related to the
precision issues stemming from `moe_mlp.py`.
The related PR is https://github.com/vllm-project/vllm-ascend/pull/5025
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: tanqingshan (A) <50050625@china.huawei.com>
Co-authored-by: tanqingshan (A) <50050625@china.huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
52 lines
2.0 KiB
Python
52 lines
2.0 KiB
Python
import unittest
|
|
from typing import ClassVar
|
|
|
|
import torch
|
|
|
|
from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list
|
|
|
|
|
|
class TestCumsumGroupList(unittest.TestCase):
    """Unit tests for ``cumsum_group_list``.

    Three behaviours are covered: conversions that are expected to succeed,
    rejection of an out-of-range source type, and conversions that are
    expected to raise ``NotImplementedError``.
    """

    # One sample group list per group_list_type key (0, 1, 2); populated once
    # by setUpClass.
    glist_dict: ClassVar[dict[int, torch.Tensor]]

    # (src_list_type, dst_list_type) pairs whose conversion must succeed.
    support_combine: ClassVar[list[tuple[int, int]]] = [
        (0, 0),
        (1, 0),
        (0, 1),
    ]
    # (src_list_type, dst_list_type) pairs that must raise
    # NotImplementedError.
    unsupport_combine: ClassVar[list[tuple[int, int]]] = [
        (0, 2),
        (2, 1),
        (1, 2),
    ]

    @classmethod
    def setUpClass(cls):
        """Build the shared sample group lists, keyed by group list type."""
        cls.glist_dict = {
            # Entry 0 is the running sum of entry 1, so the two describe the
            # same grouping (presumably type 0 = cumulative, type 1 =
            # per-group counts -- confirm against cumsum_group_list).
            0: torch.tensor([0, 2, 3, 3]),
            1: torch.tensor([0, 2, 1, 0]),
            # NOTE(review): type 2 is a 2-D layout; its column semantics are
            # defined by cumsum_group_list and are not visible from here.
            2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]]),
        }

    def test_cumsum_group_list_supported_conversion(self):
        """Each supported (src, dst) pair converts the src sample to the dst sample."""
        for src_list_type, dst_list_type in self.support_combine:
            with self.subTest(src=src_list_type, dst=dst_list_type):
                expected = self.glist_dict[dst_list_type]
                result = cumsum_group_list(self.glist_dict[src_list_type],
                                           src_list_type,
                                           dst_list_type,
                                           expert_num=4)
                # The message shows both tensors on failure instead of the
                # bare "False is not true" that assertTrue reports.
                self.assertTrue(
                    torch.equal(result, expected),
                    msg=f"expected {expected}, got {result}",
                )

    def test_cumsum_group_list_invalid_type_valueerror(self):
        """A source type outside [0, 1, 2] is rejected with ValueError."""
        with self.assertRaises(ValueError) as excinfo:
            cumsum_group_list(self.glist_dict[0], 4, 0)
        self.assertIn("group_list_type should be in [0, 1, 2], but received",
                      str(excinfo.exception))

    def test_cumsum_group_list_unsupported_conversion_notimplementederror(
            self):
        """Each unsupported (src, dst) pair raises NotImplementedError."""
        for src_list_type, dst_list_type in self.unsupport_combine:
            with self.subTest(src=src_list_type, dst=dst_list_type):
                with self.assertRaises(NotImplementedError) as excinfo:
                    # Feed the sample that matches the declared source type,
                    # so the input is well-formed for the conversion under
                    # test (mirrors the supported-conversion test).
                    cumsum_group_list(self.glist_dict[src_list_type],
                                      src_list_type, dst_list_type)
                self.assertIn("This feature is under development.",
                              str(excinfo.exception))
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: run every test case in this module and
    # report each test on its own line.
    unittest.main(verbosity=2)
|