diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 517acfc..769fc18 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -278,6 +278,7 @@ jobs: pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_data_parallel.py pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \ --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \ diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 224bf45..7ddd5c7 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -166,7 +166,22 @@ def test_models_distributed_Qwen3_W8A8(): with VllmRunner( snapshot_download("vllm-ascend/Qwen3-8B-W8A8"), max_model_len=8192, - enforce_eager=True, + dtype="auto", + tensor_parallel_size=2, + quantization="ascend", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_models_distributed_Qwen3_W4A8DYNAMIC(): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download("vllm-ascend/Qwen3-8B-W4A8"), + max_model_len=8192, dtype="auto", tensor_parallel_size=2, quantization="ascend", diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py new file mode 100644 index 0000000..445a0cd --- /dev/null +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -0,0 +1,27 @@ +import torch + 
class TestAscendW4A8DynamicLinearMethod(TestBase):
    """Dtype and shape checks for the tensors AscendW4A8DynamicLinearMethod allocates."""

    def setUp(self):
        # group_size is overridden to 8 so that the input size used in the
        # tests below (8) yields exactly one group column.
        self.method = AscendW4A8DynamicLinearMethod()
        self.method.group_size = 8

    def test_get_weight(self):
        """The weight tensor is int8 and laid out as (output_size, input_size)."""
        tensor = self.method.get_weight(8, 32, torch.bfloat16)["weight"]
        self.assertEqual(tensor.dtype, torch.int8)
        self.assertEqual(tensor.shape, (32, 8))

    def test_get_pergroup_param(self):
        """Every per-group tensor keeps params_dtype and has shape (32, 1)."""
        created = self.method.get_pergroup_param(8, 32, torch.bfloat16)
        expected_keys = (
            "weight_scale",
            "weight_offset",
            "weight_scale_second",
            "weight_offset_second",
        )
        for key in expected_keys:
            self.assertEqual(created[key].dtype, torch.bfloat16)
            self.assertEqual(created[key].shape, (32, 1))
class W4A8DYNAMICQuantizer(VLLMAscendQuantizer):
    """Quantizer registered under the "W4A8_DYNAMIC" key of SUPPORT_ASCEND_QUANTIZER_TYPE."""

    @staticmethod
    def build_linear_method():
        """Create and return a new AscendW4A8DynamicLinearMethod."""
        linear_method = AscendW4A8DynamicLinearMethod()
        return linear_method
class AscendW4A8DynamicLinearMethod:
    """Linear method for Ascend W4A8_DYNAMIC.

    Allocates an int8 weight of shape ``(output_size, input_size)`` together
    with two levels of scale/offset tensors: one column per output row
    (``weight_scale`` / ``weight_offset``) and one column per group of
    ``group_size`` input elements (``weight_scale_second`` /
    ``weight_offset_second``).  After loading, the weight is handed to
    ``torch_npu.npu_convert_weight_to_int4pack`` and the two scale levels are
    folded into a single per-group scale used by ``apply``.
    """

    def __init__(self):
        self.transpose_weight = True
        # Fall back to 256 when the current vLLM config carries no quant
        # description (any attribute in the chain may be missing or None).
        try:
            self.group_size = get_current_vllm_config(
            ).quant_config.quant_description.get("group_size", 256)
        except AttributeError:
            self.group_size = 256

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the uninitialised int8 weight, shape (output_size, input_size).

        ``params_dtype`` is accepted for interface parity but not used.
        """
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """This method defines no per-tensor parameters."""
        return {}

    @staticmethod
    def get_perchannel_param(output_size: int,
                             params_dtype: torch.dtype) -> Dict[str, Any]:
        """This method defines no per-channel parameters."""
        return {}

    def get_pergroup_param(self, input_size: int, output_size: int,
                           params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the uninitialised scale/offset tensors for one layer.

        Returns:
            ``weight_scale`` and ``weight_offset`` of shape
            ``(output_size, 1)``; ``weight_scale_second`` and
            ``weight_offset_second`` of shape
            ``(output_size, input_size // group_size)``; all in
            ``params_dtype``.

        Raises:
            ValueError: if ``group_size`` is not positive or does not evenly
                divide ``input_size``.  Floor division would otherwise
                allocate silently truncated (possibly zero-width) tensors.
        """
        group_size = self.group_size
        if group_size <= 0 or input_size % group_size != 0:
            raise ValueError(
                f"input_size ({input_size}) must be a positive multiple of "
                f"group_size ({group_size}) for W4A8_DYNAMIC quantization")
        group_num = input_size // group_size

        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        params_dict["weight_scale_second"] = torch.empty(output_size,
                                                         group_num,
                                                         dtype=params_dtype)
        params_dict["weight_offset_second"] = torch.empty(output_size,
                                                          group_num,
                                                          dtype=params_dtype)
        return params_dict

    @staticmethod
    def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
                             per_group_scale: torch.Tensor):
        """Fold the first-level scale into the per-group scale.

        Args:
            weight: 2-D tensor of shape ``(k, n)``.
            scale: first-level scale, broadcast against the last dimension.
            per_group_scale: 2-D tensor of shape ``(group_num, n)``.

        Returns:
            ``(antiquant_scale, bias)``: ``antiquant_scale`` is
            ``scale * per_group_scale`` reshaped to ``(group_num, n)`` and
            moved to the NPU; ``bias`` is ``8 *`` the column-wise sum of the
            weight scaled by both levels.
        """
        # Unpacking keeps the implicit "weight must be 2-D" check; the
        # column count is taken from per_group_scale.
        k, _ = weight.shape
        group_num, n = per_group_scale.shape
        # Multiplying the float32 weight already yields float32, so no
        # further cast is needed.
        weight_high = weight.to(torch.float32).reshape(
            group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
        weight_high = weight_high.reshape(k, n)
        bias = 8 * (weight_high * scale).sum(dim=0)
        antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
        return antiquant_scale.npu(), bias

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = None,
    ) -> torch.Tensor:
        """Run the weight-quantized matmul for ``layer`` on ``x``.

        ``tp_rank`` is accepted for interface parity but not used here.
        """
        output = torch_npu.npu_weight_quant_batchmatmul(
            x,
            layer.weight,
            antiquant_scale=layer.weight_scale_second.to(x.dtype),
            antiquant_group_size=self.group_size,
        )
        # Add the optional bias after the matmul so that it is not silently
        # discarded for layers that carry one.
        if bias is not None:
            output = output + bias
        return output

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Convert the loaded tensors of ``layer`` into the form ``apply`` uses."""
        # NOTE(review): only the transpose is treated as conditional here;
        # SOURCE's indentation was lost, confirm against the original file.
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight_scale.data = layer.weight_scale.data.flatten().to(
            torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()
        layer.weight_scale_second.data, scale_bias = self.process_scale_second(
            layer.weight.data,
            layer.weight_scale.data,
            layer.weight_scale_second.data.transpose(0, 1).contiguous(),
        )
        # Registered on the layer; apply() above does not read it.
        param = torch.nn.Parameter(scale_bias, requires_grad=False)
        layer.register_parameter("weight_scale_bias", param)
        layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
            layer.weight.data.to(torch.int32))
def get_pergroup_param(self, input_size: int, output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
    """Return the per-group parameters of this method, of which there are none.

    The arguments are accepted but unused; the result is always an empty
    dict.
    """
    no_pergroup_params: Dict[str, Any] = {}
    return no_pergroup_params