From 4fcca137a70c11daa4070ae014288be154715939 Mon Sep 17 00:00:00 2001 From: Ruri <33858552+zhoux77899@users.noreply.github.com> Date: Wed, 30 Jul 2025 14:57:14 +0800 Subject: [PATCH] [main][Feature] Support Qwen3 W4A8 quantization (#2060) ### What this PR does / why we need it? Adding `W4A8_DYNAMIC` quantization support for linear. Dense models like Qwen3 can infer with `W4A8_DYNAMIC` quantization. ### Does this PR introduce _any_ user-facing change? None ### How was this patch tested? Adding ut case in `tests/ut/quantization/test_w4a8_dynamic.py` Adding e2e case in `tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC` to test qwen3 w4a8_dynamic quantized model Note the w4a8_dynamic quantized model is quantized by `msit/msmodelslim` of commit `d0abb0a47e1f1a473b866ad41b737fbc28fb1409` 1. Generate `W4A8_DYNAMIC` quantization weights using `msmodelslim` ```shell git clone https://gitee.com/ascend/msit.git cd msit/msmodelslim git checkout d0abb0a47e1f1a473b866ad41b737fbc28fb1409 bash install.sh ``` 2. Serve model using `vllm` ```shell VLLM_USE_V1=1 python -m vllm.entrypoints.openai.api_server \ --model vllm-ascend/Qwen3-8B-W4A8 \ --port 8000 \ --quantization ascend \ --tensor_parallel_size 2 \ --enforce-eager ``` - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4cd7fe6ceaf5ad7d8ac2ba5597cd964c6db7e306 --------- Signed-off-by: ZhouXiang --- .github/workflows/vllm_ascend_test.yaml | 1 + .../test_offline_inference_distributed.py | 17 ++- tests/ut/quantization/test_w4a8_dynamic.py | 27 +++++ vllm_ascend/quantization/quant_config.py | 11 ++ vllm_ascend/quantization/quantizer.py | 9 ++ vllm_ascend/quantization/w4a8_dynamic.py | 113 ++++++++++++++++++ vllm_ascend/quantization/w8a8.py | 4 + vllm_ascend/quantization/w8a8_dynamic.py | 4 + 8 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 tests/ut/quantization/test_w4a8_dynamic.py create mode 100644 vllm_ascend/quantization/w4a8_dynamic.py diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 517acfc..769fc18 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -278,6 +278,7 @@ jobs: pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_data_parallel.py pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \ --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \ diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 224bf45..7ddd5c7 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -166,7 +166,22 @@ def test_models_distributed_Qwen3_W8A8(): with VllmRunner( snapshot_download("vllm-ascend/Qwen3-8B-W8A8"), max_model_len=8192, - enforce_eager=True, + dtype="auto", + tensor_parallel_size=2, + quantization="ascend", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_models_distributed_Qwen3_W4A8DYNAMIC(): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download("vllm-ascend/Qwen3-8B-W4A8"), + max_model_len=8192, dtype="auto", tensor_parallel_size=2, quantization="ascend", diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py new file mode 100644 index 0000000..445a0cd --- /dev/null +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -0,0 +1,27 @@ +import torch + +from tests.ut.base import TestBase +from vllm_ascend.quantization.w4a8_dynamic import AscendW4A8DynamicLinearMethod + + +class TestAscendW4A8DynamicLinearMethod(TestBase): + + def setUp(self): + self.method = AscendW4A8DynamicLinearMethod() + self.method.group_size = 8 + + def test_get_weight(self): + weight = self.method.get_weight(8, 32, torch.bfloat16) + self.assertEqual(weight["weight"].dtype, torch.int8) + self.assertEqual(weight["weight"].shape, (32, 8)) + + def test_get_pergroup_param(self): + params = self.method.get_pergroup_param(8, 32, torch.bfloat16) + self.assertEqual(params["weight_scale"].dtype, torch.bfloat16) + self.assertEqual(params["weight_scale"].shape, (32, 1)) + self.assertEqual(params["weight_offset"].dtype, torch.bfloat16) + self.assertEqual(params["weight_offset"].shape, (32, 1)) + self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16) + self.assertEqual(params["weight_scale_second"].shape, (32, 1)) + self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16) + self.assertEqual(params["weight_offset_second"].shape, (32, 1)) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 5984dc7..2577b35 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -205,6 +205,17 @@ class AscendLinearMethod(LinearMethodBase): layer.register_parameter(perchannel_name, param) set_weight_attrs(param, extra_weight_attrs) + pergroup_dict = self.quant_method.get_pergroup_param( + input_size_per_partition, output_size_per_partition, params_dtype) + for pergroup_name, pergroup_param in pergroup_dict.items(): + param = torch.nn.Parameter(pergroup_param, requires_grad=False) + set_weight_attrs(param, {"output_dim": 0}) + layer.register_parameter(pergroup_name, param) + set_weight_attrs(param, extra_weight_attrs) + if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name: + setattr(param, "input_dim", 1) + param.input_dim = 1 + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index 8178d5e..c0d2241 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -23,6 +23,7 @@ from typing import Any, Dict, List, Optional from vllm.logger import logger from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init +from .w4a8_dynamic import AscendW4A8DynamicLinearMethod from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, AscendW8A8LinearMethod) from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, @@ -263,6 +264,13 @@ class VLLMAscendQuantizer: f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}") +class W4A8DYNAMICQuantizer(VLLMAscendQuantizer): + + @staticmethod + def build_linear_method(): + return AscendW4A8DynamicLinearMethod() + + class W8A8Quantizer(VLLMAscendQuantizer): @staticmethod @@ -290,6 +298,7 @@ class W8A8DYNAMICQuantizer(VLLMAscendQuantizer): SUPPORT_ASCEND_QUANTIZER_TYPE = { + "W4A8_DYNAMIC": W4A8DYNAMICQuantizer, "W8A8": W8A8Quantizer, "W8A8_DYNAMIC": W8A8DYNAMICQuantizer, "C8": W8A8Quantizer, diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py new file mode 100644 index 0000000..c37b1c4 --- /dev/null +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -0,0 +1,113 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Dict, Optional + +import torch +import torch_npu +from vllm.config import get_current_vllm_config + + +class AscendW4A8DynamicLinearMethod: + """Linear method for Ascend W4A8_DYNAMIC + """ + + def __init__(self): + self.transpose_weight = True + try: + self.group_size = get_current_vllm_config( + ).quant_config.quant_description.get("group_size", 256) + except AttributeError: + self.group_size = 256 + + @staticmethod + def get_weight(input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = { + "weight": torch.empty(output_size, input_size, dtype=torch.int8) + } + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def get_perchannel_param(output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_scale_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + params_dict["weight_offset_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + return params_dict + + @staticmethod + def process_scale_second(weight: torch.Tensor, scale: torch.Tensor, + per_group_scale: torch.Tensor): + k, n = weight.shape + group_num, n = per_group_scale.shape + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + antiquant_scale = (scale * per_group_scale).reshape(group_num, n) + return antiquant_scale.npu(), bias + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = None, + ) -> torch.Tensor: + return torch_npu.npu_weight_quant_batchmatmul( + x, + layer.weight, + antiquant_scale=layer.weight_scale_second.to(x.dtype), + antiquant_group_size=self.group_size, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module): + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight_scale.data = layer.weight_scale.data.flatten().to( + torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight_scale_second.data, scale_bias = self.process_scale_second( + layer.weight.data, + layer.weight_scale.data, + layer.weight_scale_second.data.transpose(0, 1).contiguous(), + ) + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32)) diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 09080de..d3bff93 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -84,6 +84,10 @@ class AscendW8A8LinearMethod: dtype=params_dtype) return params_dict + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + @staticmethod def apply( layer: torch.nn.Module, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 36549e7..b20ffa3 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -699,6 +699,10 @@ class AscendW8A8DynamicLinearMethod: dtype=params_dtype) return params_dict + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + @staticmethod def apply( layer: torch.nn.Module,