xc-llm-ascend/tests/ut/quantization/test_w8a8_dynamic.py
wangxiyuan f10acddb78 drop ascend scheduler (#4498)
The Ascend scheduler was added for the non-chunked-prefill case because, at
the time, the NPU ops didn't work well with chunked prefill.

Now that the ops work well with chunked prefill, it's time to remove the
Ascend scheduler and use vLLM's default scheduler, as sketched below.
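
For context, a minimal sketch of the resulting default path, assuming vLLM's
standard offline-inference API; the model name and prompt are placeholders,
not part of this change:

from vllm import LLM, SamplingParams

# With the Ascend scheduler gone, vLLM's default scheduler is used;
# chunked prefill is governed by the standard engine argument.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
          enable_chunked_prefill=True)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)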

- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-11-29 16:18:34 +08:00

63 lines · 2.7 KiB · Python

from unittest.mock import Mock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import \
    AscendW8A8DynamicFusedMoEMethod


class TestAscendW8A8FusedMoEMethod(TestBase):
    num_experts = 8
    hidden_size = 128
    intermediate_size = 128

    @patch("torch.distributed.get_rank")
    @patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config")
    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
    def setUp(self, mock_get_ep_group, mock_get_ascend_config,
              mock_get_mc2_group, mock_get_rank):
        # Stub out the vLLM and Ascend config lookups so the quant method
        # can be constructed without a running engine or NPU device.
        with patch(
                'vllm_ascend.quantization.w8a8_dynamic.get_current_vllm_config'
        ) as mock_get_current_vllm_config:
            mock_vllm_config = Mock()
            mock_vllm_config.quant_config = Mock(
                quant_description={"group_size": 256})
            mock_vllm_config.scheduler_config = Mock(
                max_num_batched_tokens=2048,
                max_model_len=2048,
                enable_chunked_prefill=False)
            mock_get_current_vllm_config.return_value = mock_vllm_config
            mock_ep_group = Mock()
            mock_get_ep_group.return_value = mock_ep_group
            mock_ascend_config = Mock()
            mock_ascend_config.torchair_graph_config = Mock(enabled=False)
            mock_ascend_config.enable_chunked_prefill = False
            mock_get_ascend_config.return_value = mock_ascend_config
            mock_mc2_group = Mock(device_group=0)
            mock_get_mc2_group.return_value = mock_mc2_group
            mock_rank = Mock()
            mock_get_rank.return_value = mock_rank
            self.quant_method = AscendW8A8DynamicFusedMoEMethod()

    def test_get_weight(self):
        # w13 packs the gate and up projections, hence the doubled
        # intermediate dimension; weights are quantized to int8.
        param_dict = self.quant_method.get_weight(self.num_experts,
                                                  self.intermediate_size,
                                                  self.hidden_size,
                                                  torch.bfloat16)
        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
        self.assertEqual(
            param_dict["w13_weight"].shape,
            (self.num_experts, 2 * self.intermediate_size, self.hidden_size))

    def test_get_dynamic_quant_param(self):
        # Weight scales stay in bfloat16, one scale per output channel.
        param_dict = self.quant_method.get_dynamic_quant_param(
            self.num_experts, self.intermediate_size, self.hidden_size,
            torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale"].shape,
                         (self.num_experts, 2 * self.intermediate_size, 1))