xc-llm-ascend/tests/ut/quantization/test_w8a16.py
TmacAaron 5018f2d8fd [quantization] Add w8a16 quantization support (#4541)
### What this PR does / why we need it?
Related to https://github.com/vllm-project/vllm-ascend/issues/4267.

### Does this PR introduce _any_ user-facing change?
Yes. w8a16 quantization is now supported; a usage sketch follows.
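
For reference, here is a minimal sketch of loading a w8a16-quantized checkpoint through vLLM's Python API. This is illustrative only: the model path is a placeholder, `quantization="ascend"` assumes vllm-ascend's registered quantization method name, and `tensor_parallel_size=2` matches the tp2 setup used in the benchmarks below.

```python
# Illustrative sketch, not part of this PR.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/w8a16-quantized-model",  # placeholder path
    quantization="ascend",  # assumed vllm-ascend method name
    tensor_parallel_size=2,  # tp2, as in the benchmarks below
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```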

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

### Test
Tested using [aisbench](https://gitee.com/aisbench/benchmark/) with tp2 (tensor parallel size 2).
#### Precision
| dtype | ceval | mmlu  | gsm8k |
| --    | --    | --    | --    |
| bf16  | 90.46 | 89.17 | 96.21 |
| w8a16 | 89.51 | 89.29 | 95.98 |

#### Performance
| dtype | input_len | output_len | concurrency | TTFT (ms) | TPOT (ms) | TPS (Total) (tokens/s) |
| --    | --        | --         | --          | --        | --        | --                     |
| bf16  | 2048      | 2048       | 10          | 1911.7136 | 77.988    | 253.9866               |
| w8a16 | 2048      | 2048       | 10          | 2128.6334 | 67.1633   | 293.9117               |
| bf16  | 3500      | 1024       | 10          | 3076.2509 | 84.3525   | 506.949                |
| w8a16 | 3500      | 1024       | 10          | 2685.2031 | 73.015    | 585.4717               |

---------

Signed-off-by: yyt <yangyit139@gmail.com>
Signed-off-by: TmacAaron <yangyit139@gmail.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
2025-12-24 19:49:32 +08:00


import os
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a16 import AscendW8A16LinearMethod


class TestAscendW8A16LinearMethod(TestBase):

    def setUp(self):
        self.method = AscendW8A16LinearMethod()

    def test_get_weight(self):
        # Weights are created as int8 with shape (output_size, input_size).
        weight = self.method.get_weight(10, 20)
        self.assertEqual(weight['weight'].dtype, torch.int8)
        self.assertEqual(weight['weight'].shape, (20, 10))

    @patch("torch_npu.npu_weight_quant_batchmatmul")
    def test_apply_with_x_is_int8(self, mock_npu_weight_quant_batchmatmul):
        layer = MagicMock()
        layer.weight.data = torch.randn(128, 256)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)
        x = torch.randn(32, 128)
        bias = torch.randn(256)
        expected_y_output = torch.randn(32, 256)
        mock_npu_weight_quant_batchmatmul.return_value = expected_y_output

        output = self.method.apply(layer, x, bias)

        # apply() is expected to add the bias to the matmul result.
        expected_y_output += bias
        self.assertTrue(torch.equal(output, expected_y_output))

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading_with_nz0(self,
                                                    mock_npu_format_cast):
        layer = MagicMock()
        layer.weight.data = torch.randint(-127,
                                          128, (128, 256),
                                          dtype=torch.int8)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)
        mock_npu_format_cast.return_value = MagicMock()

        self.method.process_weights_after_loading(layer)

        # Scale/offset are flattened to 1-D; the NZ format cast is
        # skipped when VLLM_ASCEND_ENABLE_NZ is disabled.
        self.assertEqual(layer.weight_scale.data.shape, (128, ))
        self.assertEqual(layer.weight_offset.data.shape, (128, ))
        mock_npu_format_cast.assert_not_called()

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"})
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading_with_nz1(self,
                                                    mock_npu_format_cast):
        layer = MagicMock()
        layer.weight.data = torch.randint(-127,
                                          128, (128, 256),
                                          dtype=torch.int8)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)
        mock_npu_format_cast.return_value = MagicMock()

        self.method.process_weights_after_loading(layer)

        # With NZ enabled, the weight is cast to the NZ layout exactly once.
        self.assertEqual(layer.weight_scale.data.shape, (128, ))
        self.assertEqual(layer.weight_offset.data.shape, (128, ))
        mock_npu_format_cast.assert_called_once()

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"})
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading_with_nz2(self,
                                                    mock_npu_format_cast):
        layer = MagicMock()
        layer.weight.data = torch.randint(-127,
                                          128, (128, 256),
                                          dtype=torch.int8)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)
        mock_npu_format_cast.return_value = MagicMock()

        self.method.process_weights_after_loading(layer)

        # A value of "2" also triggers the NZ cast, i.e. any non-zero
        # VLLM_ASCEND_ENABLE_NZ value enables it.
        self.assertEqual(layer.weight_scale.data.shape, (128, ))
        self.assertEqual(layer.weight_offset.data.shape, (128, ))
        mock_npu_format_cast.assert_called_once()
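
For orientation, the `apply` path tested above delegates the matmul to `torch_npu.npu_weight_quant_batchmatmul` and then adds the bias. Below is a rough pure-PyTorch reference of the assumed w8a16 semantics (per-output-channel antiquantization of the int8 weight before a floating-point matmul); it is a sketch of the idea, not the NPU kernel, and the shape conventions are assumptions for illustration.

```python
from typing import Optional

import torch


def w8a16_reference(x: torch.Tensor,
                    weight_int8: torch.Tensor,
                    scale: torch.Tensor,
                    offset: torch.Tensor,
                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Assumed reference semantics for weight-only int8 (w8a16) matmul.

    x:            (batch, in_features), floating-point activations
    weight_int8:  (out_features, in_features), int8 weights
    scale/offset: per-output-channel antiquant params, shape (out_features,)
    """
    # Dequantize the weight per output channel, then run a normal matmul.
    w = (weight_int8.to(x.dtype) - offset[:, None]) * scale[:, None]
    y = x @ w.t()
    return y + bias if bias is not None else y
```

The unit tests above only verify wiring (the bias add in `apply`, the scale/offset reshape to 1-D, and whether `npu_format_cast` runs depending on `VLLM_ASCEND_ENABLE_NZ`), not the kernel numerics.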