[main][Feature] Support Qwen3 W4A8 quantization (#2060)
### What this PR does / why we need it?
Add `W4A8_DYNAMIC` quantization support for linear layers.
Dense models such as Qwen3 can now run inference with `W4A8_DYNAMIC` quantization.
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
Added a unit test in `tests/ut/quantization/test_w4a8_dynamic.py`.
Added an e2e test in
`tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC`
to test the Qwen3 `W4A8_DYNAMIC` quantized model.
Note: the `W4A8_DYNAMIC` quantized model was produced with `msit/msmodelslim`
at commit `d0abb0a47e1f1a473b866ad41b737fbc28fb1409`.
1. Generate `W4A8_DYNAMIC` quantization weights using `msmodelslim`
```shell
git clone https://gitee.com/ascend/msit.git
cd msit/msmodelslim
git checkout d0abb0a47e1f1a473b866ad41b737fbc28fb1409
bash install.sh
```
2. Serve model using `vllm`
```shell
VLLM_USE_V1=1 python -m vllm.entrypoints.openai.api_server \
--model vllm-ascend/Qwen3-8B-W4A8 \
--port 8000 \
--quantization ascend \
--tensor_parallel_size 2 \
--enforce-eager
```
- vLLM version: v0.10.0
- vLLM main:
4cd7fe6cea
---------
Signed-off-by: ZhouXiang <zhouxiang100@huawei.com>
This commit is contained in:
1
.github/workflows/vllm_ascend_test.yaml
vendored
1
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -278,6 +278,7 @@ jobs:
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
|
||||
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
||||
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
|
||||
--ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
|
||||
|
||||
@@ -166,7 +166,22 @@ def test_models_distributed_Qwen3_W8A8():
|
||||
with VllmRunner(
|
||||
snapshot_download("vllm-ascend/Qwen3-8B-W8A8"),
|
||||
max_model_len=8192,
|
||||
enforce_eager=True,
|
||||
dtype="auto",
|
||||
tensor_parallel_size=2,
|
||||
quantization="ascend",
|
||||
) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
|
||||
def test_models_distributed_Qwen3_W4A8DYNAMIC():
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
]
|
||||
max_tokens = 5
|
||||
|
||||
with VllmRunner(
|
||||
snapshot_download("vllm-ascend/Qwen3-8B-W4A8"),
|
||||
max_model_len=8192,
|
||||
dtype="auto",
|
||||
tensor_parallel_size=2,
|
||||
quantization="ascend",
|
||||
|
||||
27
tests/ut/quantization/test_w4a8_dynamic.py
Normal file
27
tests/ut/quantization/test_w4a8_dynamic.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.w4a8_dynamic import AscendW4A8DynamicLinearMethod
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicLinearMethod(TestBase):
    """Unit tests for the tensor-allocation helpers of the W4A8 linear method."""

    # Names of every tensor get_pergroup_param is expected to return.
    PERGROUP_NAMES = (
        "weight_scale",
        "weight_offset",
        "weight_scale_second",
        "weight_offset_second",
    )

    def setUp(self):
        self.method = AscendW4A8DynamicLinearMethod()
        # Use a small group size so the expected shapes are easy to verify.
        self.method.group_size = 8

    def test_get_weight(self):
        weight = self.method.get_weight(8, 32, torch.bfloat16)
        # The weight is always allocated as int8 in (output, input) layout.
        self.assertEqual(weight["weight"].dtype, torch.int8)
        self.assertEqual(weight["weight"].shape, (32, 8))

    def test_get_pergroup_param(self):
        # input_size == group_size, so there is exactly one group per row.
        params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
        for name in self.PERGROUP_NAMES:
            with self.subTest(param=name):
                self.assertEqual(params[name].dtype, torch.bfloat16)
                self.assertEqual(params[name].shape, (32, 1))

    def test_get_pergroup_param_multiple_groups(self):
        # With input_size == 2 * group_size the second-level tensors must
        # have two columns while the first-level tensors keep a single one;
        # the single-group case above cannot tell the two apart.
        params = self.method.get_pergroup_param(16, 32, torch.bfloat16)
        for name in ("weight_scale", "weight_offset"):
            with self.subTest(param=name):
                self.assertEqual(params[name].dtype, torch.bfloat16)
                self.assertEqual(params[name].shape, (32, 1))
        for name in ("weight_scale_second", "weight_offset_second"):
            with self.subTest(param=name):
                self.assertEqual(params[name].dtype, torch.bfloat16)
                self.assertEqual(params[name].shape, (32, 2))
|
||||
@@ -205,6 +205,17 @@ class AscendLinearMethod(LinearMethodBase):
|
||||
layer.register_parameter(perchannel_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
pergroup_dict = self.quant_method.get_pergroup_param(
|
||||
input_size_per_partition, output_size_per_partition, params_dtype)
|
||||
for pergroup_name, pergroup_param in pergroup_dict.items():
|
||||
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"output_dim": 0})
|
||||
layer.register_parameter(pergroup_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
|
||||
setattr(param, "input_dim", 1)
|
||||
param.input_dim = 1
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Forward the post-load hook to the wrapped quant method, if it has one.

    Args:
        layer: Module whose weights have just been loaded; passed through
            unchanged to ``self.quant_method.process_weights_after_loading``.
    """
    # Quant methods that do not define the hook need no post-processing.
    if not hasattr(self.quant_method, "process_weights_after_loading"):
        return
    self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
@@ -23,6 +23,7 @@ from typing import Any, Dict, List, Optional
|
||||
from vllm.logger import logger
|
||||
|
||||
from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
|
||||
from .w4a8_dynamic import AscendW4A8DynamicLinearMethod
|
||||
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod)
|
||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||
@@ -263,6 +264,13 @@ class VLLMAscendQuantizer:
|
||||
f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}")
|
||||
|
||||
|
||||
class W4A8DYNAMICQuantizer(VLLMAscendQuantizer):
    """Quantizer that provides the W4A8 dynamic linear method."""

    @staticmethod
    def build_linear_method():
        """Return a new ``AscendW4A8DynamicLinearMethod`` instance."""
        linear_method = AscendW4A8DynamicLinearMethod()
        return linear_method
|
||||
|
||||
|
||||
class W8A8Quantizer(VLLMAscendQuantizer):
|
||||
|
||||
@staticmethod
|
||||
@@ -290,6 +298,7 @@ class W8A8DYNAMICQuantizer(VLLMAscendQuantizer):
|
||||
|
||||
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE = {
|
||||
"W4A8_DYNAMIC": W4A8DYNAMICQuantizer,
|
||||
"W8A8": W8A8Quantizer,
|
||||
"W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
|
||||
"C8": W8A8Quantizer,
|
||||
|
||||
113
vllm_ascend/quantization/w4a8_dynamic.py
Normal file
113
vllm_ascend/quantization/w4a8_dynamic.py
Normal file
@@ -0,0 +1,113 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
|
||||
class AscendW4A8DynamicLinearMethod:
    """Linear method for Ascend W4A8_DYNAMIC.

    The weight is allocated as an int8 tensor and, after loading, repacked
    with ``torch_npu.npu_convert_weight_to_int4pack``.  Two levels of scale
    and offset tensors are allocated: ``weight_scale`` / ``weight_offset``
    with a single column per output row, and ``weight_scale_second`` /
    ``weight_offset_second`` with one column per group of ``group_size``
    input elements.
    """

    def __init__(self) -> None:
        # The weight is allocated as (output, input) and transposed once in
        # process_weights_after_loading.
        self.transpose_weight = True
        try:
            self.group_size = get_current_vllm_config(
            ).quant_config.quant_description.get("group_size", 256)
        except AttributeError:
            # Some link of the attribute chain is missing or None (for
            # example no quant config is set); use the default group size.
            self.group_size = 256

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Allocate the uninitialised weight tensor.

        Args:
            input_size: Number of input features (second dimension).
            output_size: Number of output features (first dimension).
            params_dtype: Unused; the weight is always allocated as int8.

        Returns:
            ``{"weight": tensor}`` with an int8 tensor of shape
            ``(output_size, input_size)``.
        """
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the per-tensor parameters; this method defines none."""
        return {}

    @staticmethod
    def get_perchannel_param(output_size: int,
                             params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the per-channel parameters; this method defines none."""
        return {}

    def get_pergroup_param(self, input_size: int, output_size: int,
                           params_dtype: torch.dtype) -> Dict[str, Any]:
        """Allocate the uninitialised scale and offset tensors.

        Args:
            input_size: Number of input features of the layer.
            output_size: Number of output features of the layer.
            params_dtype: dtype used for every returned tensor.

        Returns:
            Dict with ``weight_scale`` and ``weight_offset`` of shape
            ``(output_size, 1)`` and ``weight_scale_second`` and
            ``weight_offset_second`` of shape
            ``(output_size, input_size // self.group_size)``.
        """
        # Floor division: a trailing partial group gets no column.
        group_num = input_size // self.group_size
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        params_dict["weight_scale_second"] = torch.empty(output_size,
                                                         group_num,
                                                         dtype=params_dtype)
        params_dict["weight_offset_second"] = torch.empty(output_size,
                                                          group_num,
                                                          dtype=params_dtype)
        return params_dict

    @staticmethod
    def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
                             per_group_scale: torch.Tensor):
        """Fold the first-level scale into the per-group scale.

        Args:
            weight: 2-D weight of shape ``(k, n)``.
            scale: First-level scale, broadcast against the ``(k, n)``
                weight and the ``(group_num, n)`` per-group scale.
            per_group_scale: 2-D second-level scale of shape
                ``(group_num, n)``.

        Returns:
            Tuple ``(antiquant_scale, bias)``: ``antiquant_scale`` is
            ``scale * per_group_scale`` reshaped to ``(group_num, n)`` and
            moved to the NPU; ``bias`` is 8 times the sum over dim 0 of
            the weight multiplied by both scales.
        """
        k, _ = weight.shape
        group_num, n = per_group_scale.shape
        # Apply each group's scale to the rows of the weight in that group.
        weight_high = weight.to(torch.float32).reshape(
            group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
        weight_high = weight_high.reshape(k, n)
        # NOTE(review): the factor 8 presumably compensates for the int4
        # zero-point offset -- confirm against the kernel documentation.
        bias = 8 * (weight_high * scale).sum(dim=0)
        antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
        return antiquant_scale.npu(), bias

    @staticmethod
    def _add_bias(output: torch.Tensor,
                  bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Return ``output`` with ``bias`` added, or unchanged if it is None."""
        if bias is None:
            return output
        return output + bias

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = None,
    ) -> torch.Tensor:
        """Run the weight-quantized matmul for ``layer`` on ``x``.

        Args:
            layer: Module holding ``weight`` and ``weight_scale_second``.
            x: Input activations.
            bias: Optional bias added to the matmul result.
            tp_rank: Accepted for interface compatibility; not used here.

        Returns:
            The matmul output, plus ``bias`` when one is given.
        """
        output = torch_npu.npu_weight_quant_batchmatmul(
            x,
            layer.weight,
            antiquant_scale=layer.weight_scale_second.to(x.dtype),
            antiquant_group_size=self.group_size,
        )
        # The bias used to be dropped silently; add it after the kernel so
        # no assumption is made about the kernel's own bias dtype rules.
        return self._add_bias(output, bias)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded tensors of ``layer`` for use by ``apply``.

        Transposes the weight, flattens the first-level scale and offset,
        replaces ``weight_scale_second`` with the combined scale, registers
        ``weight_scale_bias`` and packs the weight to int4.
        """
        # NOTE(review): process_scale_second unpacks the weight as (k, n),
        # which looks like it relies on this transpose -- confirm before
        # ever setting transpose_weight to False.
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight_scale.data = layer.weight_scale.data.flatten().to(
            torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()
        layer.weight_scale_second.data, scale_bias = self.process_scale_second(
            layer.weight.data,
            layer.weight_scale.data,
            layer.weight_scale_second.data.transpose(0, 1).contiguous(),
        )
        # Registered on the layer; apply() in this class does not read it.
        param = torch.nn.Parameter(scale_bias, requires_grad=False)
        layer.register_parameter("weight_scale_bias", param)
        layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
            layer.weight.data.to(torch.int32))
|
||||
@@ -84,6 +84,10 @@ class AscendW8A8LinearMethod:
|
||||
dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self, input_size: int, output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
    """Return the per-group parameters; this method defines none.

    None of the arguments are read. A new empty dict is returned on
    every call.
    """
    no_pergroup_params: Dict[str, Any] = {}
    return no_pergroup_params
|
||||
|
||||
@staticmethod
|
||||
def apply(
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@@ -699,6 +699,10 @@ class AscendW8A8DynamicLinearMethod:
|
||||
dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self, input_size: int, output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
    """Return an empty dict: no per-group parameters are allocated here.

    The sizes and dtype are accepted but not used.
    """
    pergroup_params: Dict[str, Any] = dict()
    return pergroup_params
|
||||
|
||||
@staticmethod
|
||||
def apply(
|
||||
layer: torch.nn.Module,
|
||||
|
||||
Reference in New Issue
Block a user