[quantization] Add w8a16 quantization support (#4541)
### What this PR does / why we need it?
Related to https://github.com/vllm-project/vllm-ascend/issues/4267.
### Does this PR introduce _any_ user-facing change?
Yes, w8a16 quantization is now supported.
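A minimal usage sketch (the checkpoint name mirrors the one used in the e2e test below; any W8A16 checkpoint produced for vllm-ascend should load the same way):

    from vllm import LLM

    # "ascend" selects the Ascend quantization backend, which picks the
    # W8A16 method up from the checkpoint's quantization config.
    llm = LLM(
        model="vllm-ascend/Qwen3-0.6B-W8A16",
        quantization="ascend",
        max_model_len=8192,
    )
    print(llm.generate(["vLLM is"])[0].outputs[0].text)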
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
### Test
Tested using [aisbench](https://gitee.com/aisbench/benchmark/) with tp2 (tensor parallel size 2).
#### Precision
|  | ceval | mmlu | gsm8k |
| -- | -- | -- | -- |
| bf16 | 90.46 | 89.17 | 96.21 |
| w8a16 | 89.51 | 89.29 | 95.98 |
#### Performance
|  | input_len | output_len | concurrency | TTFT (ms) | TPOT (ms) | TPS (Total) (tokens/s) |
| -- | -- | -- | -- | -- | -- | -- |
| bf16 | 2048 | 2048 | 10 | 1911.7136 | 77.988 | 253.9866 |
| w8a16 | 2048 | 2048 | 10 | 2128.6334 | 67.1633 | 293.9117 |
| bf16 | 3500 | 1024 | 10 | 3076.2509 | 84.3525 | 506.949 |
| w8a16 | 3500 | 1024 | 10 | 2685.2031 | 73.015 | 585.4717 |

In both settings w8a16 improves TPOT by roughly 13-14% and total TPS by roughly 15-16% over bf16, at the cost of a slightly higher TTFT in the 2048/2048 case.
---------
Signed-off-by: yyt <yangyit139@gmail.com>
Signed-off-by: TmacAaron <yangyit139@gmail.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
@@ -18,6 +18,7 @@
from modelscope import snapshot_download  # type: ignore[import-untyped]

from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal


def test_qwen3_w8a8_quant():
@@ -25,10 +26,53 @@ def test_qwen3_w8a8_quant():
    example_prompts = [
        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
    ]
    vllm_target_outputs = [([
        85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
        13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
    ], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be'
    )]

    with VllmRunner(
            snapshot_download("vllm-ascend/Qwen3-0.6B-W8A8"),
            max_model_len=8192,
            gpu_memory_utilization=0.7,
            quantization="ascend",
    ) as vllm_model:
        vllm_quant_w8a8_outputs = vllm_model.generate_greedy(
            example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=vllm_target_outputs,
        outputs_1_lst=vllm_quant_w8a8_outputs,
        name_0="vllm_target_outputs",
        name_1="vllm_w8a8_outputs",
    )


def test_qwen3_dense_w8a16():
    max_tokens = 5
    example_prompts = [
        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
    ]
    vllm_target_outputs = [([
        85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
        13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
    ], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be'
    )]

    with VllmRunner(
            snapshot_download("vllm-ascend/Qwen3-0.6B-W8A16"),
            max_model_len=8192,
            enforce_eager=False,
            gpu_memory_utilization=0.7,
            quantization="ascend",
    ) as vllm_model:
        vllm_quant_w8a16_outputs = vllm_model.generate_greedy(
            example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=vllm_target_outputs,
        outputs_1_lst=vllm_quant_w8a16_outputs,
        name_0="vllm_target_outputs",
        name_1="vllm_w8a16_outputs",
    )
tests/ut/quantization/test_w8a16.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import os
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a16 import AscendW8A16LinearMethod


class TestAscendW8A16LinearMethod(TestBase):

    def setUp(self):
        self.method = AscendW8A16LinearMethod()

    def test_get_weight(self):
        weight = self.method.get_weight(10, 20)
        self.assertEqual(weight['weight'].dtype, torch.int8)
        self.assertEqual(weight['weight'].shape, (20, 10))

    @patch("torch_npu.npu_weight_quant_batchmatmul")
    def test_apply(self, mock_npu_weight_quant_batchmatmul):
        # apply() should forward x, weight, scale, offset and bias to the
        # fused NPU kernel and return its result unchanged.
        layer = MagicMock()
        layer.weight.data = torch.randn(128, 256)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)

        x = torch.randn(32, 128)
        bias = torch.randn(256)

        expected_y_output = torch.randn(32, 256)
        mock_npu_weight_quant_batchmatmul.return_value = expected_y_output

        output = self.method.apply(layer, x, bias)
        self.assertTrue(torch.equal(output, expected_y_output))

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading_with_nz0(self,
                                                    mock_npu_format_cast):
        layer = MagicMock()
        layer.weight.data = torch.randint(-127,
                                          128, (128, 256),
                                          dtype=torch.int8)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)

        mock_npu_format_cast.return_value = MagicMock()
        self.method.process_weights_after_loading(layer)

        self.assertEqual(layer.weight_scale.data.shape, (128, ))
        self.assertEqual(layer.weight_offset.data.shape, (128, ))
        mock_npu_format_cast.assert_not_called()

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"})
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading_with_nz1(self,
                                                    mock_npu_format_cast):
        layer = MagicMock()
        layer.weight.data = torch.randint(-127,
                                          128, (128, 256),
                                          dtype=torch.int8)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)

        mock_npu_format_cast.return_value = MagicMock()
        self.method.process_weights_after_loading(layer)

        self.assertEqual(layer.weight_scale.data.shape, (128, ))
        self.assertEqual(layer.weight_offset.data.shape, (128, ))
        mock_npu_format_cast.assert_called_once()

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"})
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading_with_nz2(self,
                                                    mock_npu_format_cast):
        layer = MagicMock()
        layer.weight.data = torch.randint(-127,
                                          128, (128, 256),
                                          dtype=torch.int8)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)

        mock_npu_format_cast.return_value = MagicMock()
        self.method.process_weights_after_loading(layer)

        self.assertEqual(layer.weight_scale.data.shape, (128, ))
        self.assertEqual(layer.weight_offset.data.shape, (128, ))
        mock_npu_format_cast.assert_called_once()
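For orientation, a hypothetical sketch of the maybe_trans_nz behavior these tests pin down (the real helper lives in vllm_ascend.utils and may gate on more than the environment variable; the constant name and format id are assumptions): npu_format_cast is skipped when VLLM_ASCEND_ENABLE_NZ is "0" and invoked once for any non-zero value.

    import os

    import torch
    import torch_npu

    ACL_FORMAT_FRACTAL_NZ = 29  # assumed ACL format id for FRACTAL_NZ

    def maybe_trans_nz(tensor: torch.Tensor) -> torch.Tensor:
        # Skip the cast when NZ is disabled ("0"); any other value converts
        # the tensor to the NPU-friendly NZ layout, matching the nz0/nz1/nz2
        # cases above.
        if os.environ.get("VLLM_ASCEND_ENABLE_NZ", "0") == "0":
            return tensor
        return torch_npu.npu_format_cast(tensor, ACL_FORMAT_FRACTAL_NZ)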
@@ -14,6 +14,7 @@ from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
                           AscendW8A8DynamicLinearMethod)
from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
                         AscendW8A8PDMixLinearMethod)
from .w8a16 import AscendW8A16LinearMethod

ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
    "W4A16": {
@@ -36,6 +37,9 @@ ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
    "W8A8_MIX": {
        "linear": AscendW8A8PDMixLinearMethod,
        "moe": AscendW8A8PDMixFusedMoeMethod,
    },
    "W8A16": {
        "linear": AscendW8A16LinearMethod,
    }
}
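The registry is keyed by quantization method name, then layer kind; a minimal lookup sketch (exact call sites live elsewhere in vllm_ascend):

    # Resolve and instantiate the linear method for a W8A16-quantized layer.
    method_cls = ASCEND_QUANTIZATION_METHOD_MAP["W8A16"]["linear"]
    linear_method = method_cls()  # -> AscendW8A16LinearMethod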
vllm_ascend/quantization/w8a16.py (new file, 89 lines)
@@ -0,0 +1,89 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Optional

import torch
import torch_npu

from vllm_ascend.utils import maybe_trans_nz


class AscendW8A16LinearMethod:
    """Linear method for Ascend W8A16 (int8 weights, 16-bit activations)."""

    def __init__(self) -> None:
        pass

    @staticmethod
    def get_weight(
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        return {}

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        # One antiquant scale/offset pair per output channel.
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        return params_dict

    def get_pergroup_param(self,
                           input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype,
                           layer_type: Optional[str] = None) -> Dict[str, Any]:
        return {}

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        # Fused weight-dequant matmul: the int8 weight is dequantized on the
        # fly with the per-channel antiquant scale/offset and multiplied with
        # the 16-bit activations in a single kernel.
        output = torch_npu.npu_weight_quant_batchmatmul(
            x=x,
            weight=layer.weight,
            antiquant_scale=layer.weight_scale,
            antiquant_offset=layer.weight_offset,
            bias=bias)
        return output

    def process_weights_after_loading(self, layer):
        # The kernel expects the weight laid out as (input_size, output_size).
        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        # Optionally convert the weight to the NPU-friendly NZ layout.
        layer.weight.data = maybe_trans_nz(layer.weight.data)
        # Scale/offset are stored as (output_size, 1); flatten to 1-D.
        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
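For intuition, a pure-PyTorch reference of the matmul the fused kernel computes, assuming the usual antiquant convention y = x @ ((w_int8 + offset) * scale) + bias with per-output-channel scale and offset (a numerics sketch, not the kernel's exact implementation):

    from typing import Optional

    import torch

    def w8a16_matmul_reference(x: torch.Tensor,
                               weight_int8: torch.Tensor,
                               scale: torch.Tensor,
                               offset: torch.Tensor,
                               bias: Optional[torch.Tensor] = None
                               ) -> torch.Tensor:
        # weight_int8 is (input_size, output_size) after the transpose in
        # process_weights_after_loading; scale/offset are flattened to
        # (output_size,), broadcasting over the output channels.
        w = (weight_int8.to(x.dtype) + offset) * scale
        y = x @ w
        if bias is not None:
            y = y + bias
        return y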