From 5f8b1699ae35bb8c046d8b385d215cd208bc3fcb Mon Sep 17 00:00:00 2001 From: Anion <123177548+Anionex@users.noreply.github.com> Date: Tue, 21 Oct 2025 20:18:39 +0800 Subject: [PATCH] [Feat][quantization] Support new version w4a8 dynamic quantization for Linear layers (#3311) ### What this PR does / why we need it? **Problem Description:** The existing implementation for the w4a8-dynamic linear method only supports the old quantization format from msmodelslim. When attempting to load models quantized with the new version, vLLM encounters errors due to mismatched tensor shapes and unprocessed quantization parameters. Relevant issues: - https://github.com/vllm-project/vllm-ascend/issues/3192 - https://github.com/vllm-project/vllm-ascend/issues/3152 **Proposed Changes:** 1. Add support for w4a8 dynamic (new format) in AscendW4A8DynamicLinearMethod and TorchairAscendW4A8DynamicLinearMethod 2. Add unit tests and e2e tests for both new- and old-format w4a8 dynamic models
details 1. **Support for new w4a8-dynamic format:** * Detects the quantization format by reading the "version" field in quant_description to ensure backward compatibility. * Handles the new pre-packed weight format (two `int4` values per `int8`), which halves the output dimension, and tells the vLLM loader how to unpack it via `_packed_dim` and `_packed_factor` (see the sketch after this list). * Supports the new `scale_bias` parameter, setting its shape based on the layer type, as required by msmodelslim. For API consistency and future use, the `layer_type` parameter was also added to the other quantization methods. * Updates the weight processing logic: new-format weights are reinterpreted with `.view(torch.int32)` since they are already packed, while old-format weights are packed via `npu_convert_weight_to_int4pack`. 2. **New unit and E2E tests:** * Added unit tests that verify the logic for both the old and new formats. * Split the distributed E2E test to confirm that both old- and new-format models work correctly.
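For reference, here is a minimal, illustrative sketch of the detection and packing behavior described above (the helper name `build_weight_params` is hypothetical; the version check, shapes, and field names mirror the diff below):

```python
import torch


def build_weight_params(input_size: int, output_size: int,
                        quant_description: dict) -> dict:
    """Sketch of per-layer weight creation for old vs. new w4a8 formats."""
    # New msmodelslim exports carry a "version" field; "1.0.0" marks the new format.
    new_quant_version = quant_description.get("version", "0") == "1.0.0"

    params: dict = {}
    if new_quant_version:
        # Two int4 values are pre-packed into one int8, so the stored output
        # dimension is half of the logical output size.
        pack_factor = 2
        params["weight"] = torch.empty(output_size // pack_factor,
                                       input_size,
                                       dtype=torch.int8)
        # Hints consumed by vLLM's weight loader so it can unpack correctly.
        params["_packed_dim"] = 0
        params["_packed_factor"] = pack_factor
        # The extra scale_bias parameter (new format only) is shaped
        # [output_size, 16] for row-parallel layers (down_proj / o_proj)
        # and [output_size, 1] for the others.
    else:
        params["weight"] = torch.empty(output_size,
                                       input_size,
                                       dtype=torch.int8)
    return params
```

After loading, new-format weights only need `weight.view(torch.int32)` (packing four int8 values, i.e. eight int4 values, into one int32), while old-format weights still go through `torch_npu.npu_convert_weight_to_int4pack`, as shown in the diff below.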
Theoretically, these changes provide support for all common new-version w4a8 (dynamic) models from msmodelslim. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? I implemented the relevant unit tests and e2e tests and tested the changes with the following commands: ```bash # unit tests python -m pytest tests/ut/quantization/test_w4a8_dynamic.py tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py -v # e2e tests pytest tests/e2e/singlecard/test_quantization.py -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC -v -s ``` I also tested Hunyuan-1.8B-Instruct quantized with the new w4a8-dynamic format: ``` vllm serve ./models/Hunyuan-1.8B-Instruct-quantized --gpu-memory-utilization 0.96 --quantization ascend --max-model-len 9600 --seed 0 --max-num-batched-tokens 16384 ``` All tests mentioned passed locally. **NOTE: I use a quantized model from my own repo in test_offline_inference_distributed.py**. Here is the description: [Anionex/Qwen3-1.7B-W4A8-V1](https://modelscope.cn/models/Anionex/Qwen3-1.7B-W4A8-V1/summary) (including quantization steps). This should be replaced by a model in the vllm-ascend CI modelscope repo. Thanks for reading! - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Anionex <1005128408@qq.com> --- .github/workflows/_e2e_test.yaml | 3 +- .../test_offline_inference_distributed.py | 35 ++++- tests/ut/quantization/test_w4a8_dynamic.py | 103 ++++++++++++-- .../test_torchair_w4a8_dynamic.py | 91 +++++++++++- vllm_ascend/quantization/quant_config.py | 24 +++- vllm_ascend/quantization/w4a8_dynamic.py | 134 +++++++++++++++--- vllm_ascend/quantization/w8a8.py | 7 +- vllm_ascend/quantization/w8a8_dynamic.py | 7 +- .../quantization/torchair_w4a8_dynamic.py | 97 ++++++++++--- .../quantization/torchair_w8a8_dynamic.py | 7 +- 10 files changed, 433 insertions(+), 75 deletions(-) diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index a330b62..ddbf4b3 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -185,7 +185,8 @@ jobs: pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W8A8 - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1 diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py
b/tests/e2e/multicard/test_offline_inference_distributed.py index 60f3c1b..849bdf7 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -30,11 +30,20 @@ from vllm import SamplingParams from tests.e2e.conftest import VllmRunner os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" QWEN_DENSE_MODELS = [ "vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8" ] +QWEN_W4A8_OLD_VERSION_MODELS = [ + "vllm-ascend/Qwen3-8B-W4A8", +] + +QWEN_W4A8_NEW_VERSION_MODELS = [ + "vllm-ascend/Qwen3-1.7B-W4A8-V1", +] + DEEPSEEK_W4A8_MODELS = [ "vllm-ascend/DeepSeek-V3-W4A8-Pruing", "vllm-ascend/DeepSeek-V3.1-W4A8-puring" @@ -98,20 +107,36 @@ def test_models_distributed_Qwen3_W8A8(): vllm_model.generate_greedy(example_prompts, max_tokens) -def test_models_distributed_Qwen3_W4A8DYNAMIC(): - example_prompts = [ +@pytest.mark.parametrize("model", QWEN_W4A8_OLD_VERSION_MODELS) +def test_models_distributed_Qwen3_W4A8DYNAMIC_old_version(model): + prompts = [ "Hello, my name is", ] max_tokens = 5 - with VllmRunner( - snapshot_download("vllm-ascend/Qwen3-8B-W4A8"), + snapshot_download(model), max_model_len=8192, dtype="auto", tensor_parallel_size=2, quantization="ascend", ) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) + vllm_model.generate_greedy(prompts, max_tokens) + + +@pytest.mark.parametrize("model", QWEN_W4A8_NEW_VERSION_MODELS) +def test_models_distributed_Qwen3_W4A8DYNAMIC_new_version(model): + prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + with VllmRunner( + snapshot_download(model), + max_model_len=8192, + dtype="auto", + tensor_parallel_size=2, + quantization="ascend", + ) as vllm_model: + vllm_model.generate_greedy(prompts, max_tokens) @pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS) diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py index a14702b..2116b0c 100644 --- a/tests/ut/quantization/test_w4a8_dynamic.py +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -9,25 +9,31 @@ from vllm_ascend.quantization.w4a8_dynamic import ( class TestAscendW4A8DynamicLinearMethod(TestBase): - def setUp(self): - with patch( - 'vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config' - ) as mock_get_current_vllm_config: - mock_vllm_config = Mock() - mock_vllm_config.quant_config = Mock( - quant_description={"group_size": 256}) - mock_vllm_config.scheduler_config = Mock( - max_num_batched_tokens=2048, - max_model_len=2048, - enable_chunked_prefill=False) - mock_get_current_vllm_config.return_value = mock_vllm_config - self.method = AscendW4A8DynamicLinearMethod() - self.method.group_size = 8 + @patch('vllm.distributed.get_tensor_model_parallel_world_size') + @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config') + def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size): + mock_get_tp_world_size.return_value = 1 + mock_vllm_config = Mock() + mock_vllm_config.quant_config = Mock( + quant_description={"group_size": 256}) + mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048, + max_model_len=2048, + enable_chunked_prefill=False) + mock_get_current_vllm_config.return_value = mock_vllm_config + self.method = AscendW4A8DynamicLinearMethod() + self.method.group_size = 8 def test_get_weight(self): weight = self.method.get_weight(8, 32, torch.bfloat16) self.assertEqual(weight["weight"].dtype, torch.int8) 
self.assertEqual(weight["weight"].shape, (32, 8)) + # new quant version weight + self.method.new_quant_version = True + weight = self.method.get_weight(8, 32, torch.bfloat16) + self.assertEqual(weight["weight"].dtype, torch.int8) + self.assertEqual(weight["weight"].shape, (16, 8)) + self.assertEqual(weight["_packed_dim"], 0) + self.assertEqual(weight["_packed_factor"], 2) def test_get_pergroup_param(self): params = self.method.get_pergroup_param(8, 32, torch.bfloat16) @@ -39,6 +45,75 @@ class TestAscendW4A8DynamicLinearMethod(TestBase): self.assertEqual(params["weight_scale_second"].shape, (32, 1)) self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16) self.assertEqual(params["weight_offset_second"].shape, (32, 1)) + # new quant version weight + self.method.new_quant_version = True + params = self.method.get_pergroup_param(8, + 32, + torch.bfloat16, + layer_type="column") + self.assertEqual(params["scale_bias"].dtype, torch.float32) + self.assertEqual(params["scale_bias"].shape, (32, 1)) + params = self.method.get_pergroup_param(8, + 32, + torch.bfloat16, + layer_type="row") + self.assertEqual(params["scale_bias"].dtype, torch.float32) + self.assertEqual(params["scale_bias"].shape, (32, 16)) + + @patch('torch_npu.npu_convert_weight_to_int4pack') + @patch('torch.Tensor.npu') + def test_process_weights_after_loading(self, mock_npu, + mock_npu_convert_weight): + mock_npu.side_effect = lambda: torch.zeros( + (1, 32), dtype=torch.float32) + mock_npu_convert_weight.return_value = torch.zeros((32, 4), + dtype=torch.int32) + # old quant version weight + layer = torch.nn.Module() + layer.weight = torch.nn.Parameter(torch.zeros((32, 8), + dtype=torch.int8), + requires_grad=False) + layer.weight_scale = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + layer.weight_offset = torch.nn.Parameter(torch.empty_like( + layer.weight_scale.data), + requires_grad=False) + layer.weight_scale_second = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + layer.weight_offset_second = torch.nn.Parameter(torch.empty_like( + layer.weight_scale_second.data), + requires_grad=False) + self.method.process_weights_after_loading(layer) + self.assertTrue(hasattr(layer, "weight_scale_bias")) + self.assertEqual(layer.weight_scale_bias.data.shape, (32, )) + self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32) + # new quant version weight + self.method.new_quant_version = True + new_layer = torch.nn.Module() + new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8), + dtype=torch.int8), + requires_grad=False) + new_layer.weight_scale = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + new_layer.weight_offset = torch.nn.Parameter(torch.empty_like( + new_layer.weight_scale.data), + requires_grad=False) + new_layer.weight_scale_second = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + new_layer.weight_offset_second = torch.nn.Parameter( + torch.empty_like(new_layer.weight_scale_second.data), + requires_grad=False) + new_layer.scale_bias = torch.nn.Parameter(torch.zeros( + (32, 1), dtype=torch.float32), + requires_grad=False) + self.method.process_weights_after_loading(new_layer) + self.assertEqual(new_layer.scale_bias.data.shape, (32, )) + self.assertTrue(hasattr(new_layer, "weight_scale_second")) + self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32)) class TestAscendW4A8DynamicFusedMoEMethod(TestBase): diff --git 
a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py index 9fd3f29..f29cafc 100644 --- a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py +++ b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py @@ -10,7 +10,16 @@ from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import ( class TestAscendW4A8DynamicLinearMethod(TestBase): - def setUp(self): + @patch('vllm.distributed.get_tensor_model_parallel_world_size') + @patch( + 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config' + ) + def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size): + mock_get_tp_world_size.return_value = 1 + mock_vllm_config = Mock() + mock_vllm_config.quant_config = Mock( + quant_description={"group_size": 256}) + mock_get_current_vllm_config.return_value = mock_vllm_config self.method = TorchairAscendW4A8DynamicLinearMethod() self.method.group_size = 8 @@ -18,6 +27,13 @@ class TestAscendW4A8DynamicLinearMethod(TestBase): weight = self.method.get_weight(8, 32, torch.bfloat16) self.assertEqual(weight["weight"].dtype, torch.int8) self.assertEqual(weight["weight"].shape, (32, 8)) + # new quant version weight + self.method.new_quant_version = True + weight = self.method.get_weight(8, 32, torch.bfloat16) + self.assertEqual(weight["weight"].dtype, torch.int8) + self.assertEqual(weight["weight"].shape, (16, 8)) + self.assertEqual(weight["_packed_dim"], 0) + self.assertEqual(weight["_packed_factor"], 2) def test_get_pergroup_param(self): params = self.method.get_pergroup_param(8, 32, torch.bfloat16) @@ -29,6 +45,75 @@ class TestAscendW4A8DynamicLinearMethod(TestBase): self.assertEqual(params["weight_scale_second"].shape, (32, 1)) self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16) self.assertEqual(params["weight_offset_second"].shape, (32, 1)) + # new quant version weight + self.method.new_quant_version = True + params = self.method.get_pergroup_param(8, + 32, + torch.bfloat16, + layer_type="column") + self.assertEqual(params["scale_bias"].dtype, torch.float32) + self.assertEqual(params["scale_bias"].shape, (32, 1)) + params = self.method.get_pergroup_param(8, + 32, + torch.bfloat16, + layer_type="row") + self.assertEqual(params["scale_bias"].dtype, torch.float32) + self.assertEqual(params["scale_bias"].shape, (32, 16)) + + @patch('torch_npu.npu_convert_weight_to_int4pack') + @patch('torch.Tensor.npu') + def test_process_weights_after_loading(self, mock_npu, + mock_npu_convert_weight): + mock_npu.side_effect = lambda: torch.zeros( + (1, 32), dtype=torch.float32) + mock_npu_convert_weight.return_value = torch.zeros((32, 4), + dtype=torch.int32) + # old quant version weight + layer = torch.nn.Module() + layer.weight = torch.nn.Parameter(torch.zeros((32, 8), + dtype=torch.int8), + requires_grad=False) + layer.weight_scale = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + layer.weight_offset = torch.nn.Parameter(torch.empty_like( + layer.weight_scale.data), + requires_grad=False) + layer.weight_scale_second = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + layer.weight_offset_second = torch.nn.Parameter(torch.empty_like( + layer.weight_scale_second.data), + requires_grad=False) + self.method.process_weights_after_loading(layer) + self.assertTrue(hasattr(layer, "weight_scale_bias")) + self.assertEqual(layer.weight_scale_bias.data.shape, (32, )) + 
self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32) + # new quant version weight + self.method.new_quant_version = True + new_layer = torch.nn.Module() + new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8), + dtype=torch.int8), + requires_grad=False) + new_layer.weight_scale = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + new_layer.weight_offset = torch.nn.Parameter(torch.empty_like( + new_layer.weight_scale.data), + requires_grad=False) + new_layer.weight_scale_second = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + new_layer.weight_offset_second = torch.nn.Parameter( + torch.empty_like(new_layer.weight_scale_second.data), + requires_grad=False) + new_layer.scale_bias = torch.nn.Parameter(torch.zeros( + (32, 1), dtype=torch.float32), + requires_grad=False) + self.method.process_weights_after_loading(new_layer) + self.assertEqual(new_layer.scale_bias.data.shape, (32, )) + self.assertTrue(hasattr(new_layer, "weight_scale_second")) + self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32)) class TestAscendW4A8DynamicFusedMoEMethod(TestBase): @@ -42,7 +127,9 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): ) @patch( 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group') - @patch("vllm_ascend.ascend_config.get_ascend_config") + @patch( + 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ascend_config' + ) @patch( 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group' ) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index dea8c0d..e742852 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -271,9 +271,22 @@ class AscendLinearMethod(LinearMethodBase): weight_dict = self.quant_method.get_weight(input_size_per_partition, output_size_per_partition, params_dtype) + + # Extract packing information (if present) + packed_dim = weight_dict.pop("_packed_dim", None) + packed_factor = weight_dict.pop("_packed_factor", None) + for weight_name, weight_param in weight_dict.items(): param = torch.nn.Parameter(weight_param, requires_grad=False) set_weight_attrs(param, {"input_dim": 1, "output_dim": 0}) + + # Set packing attributes if the weight is packed + if packed_dim is not None and packed_factor is not None: + set_weight_attrs(param, { + "packed_dim": packed_dim, + "packed_factor": packed_factor + }) + layer.register_parameter(weight_name, param) set_weight_attrs(param, extra_weight_attrs) @@ -294,8 +307,17 @@ class AscendLinearMethod(LinearMethodBase): layer.register_parameter(perchannel_name, param) set_weight_attrs(param, extra_weight_attrs) + # NOTE: In w4a8 quantization implementation, + # for down_proj and o_proj scale_bias shape is [output_size, 16], + # others are [output_size, 1] + layer_type = "row" if isinstance(layer, + RowParallelLinear) else "others" + pergroup_dict = self.quant_method.get_pergroup_param( - input_size_per_partition, output_size_per_partition, params_dtype) + input_size_per_partition, + output_size_per_partition, + params_dtype, + layer_type=layer_type) for pergroup_name, pergroup_param in pergroup_dict.items(): param = torch.nn.Parameter(pergroup_param, requires_grad=False) set_weight_attrs(param, {"output_dim": 0}) diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index 4b5632f..2133607 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ 
b/vllm_ascend/quantization/w4a8_dynamic.py @@ -36,18 +36,42 @@ class AscendW4A8DynamicLinearMethod: def __init__(self): self.transpose_weight = True - try: - self.group_size = get_current_vllm_config( - ).quant_config.quant_description.get("group_size", 256) - except AttributeError: - self.group_size = 256 - @staticmethod - def get_weight(input_size: int, output_size: int, + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 256) + quant_version = vllm_config.quant_config.quant_description.get( + "version", "0") + self.new_quant_version = quant_version == "1.0.0" + + from vllm.distributed import get_tensor_model_parallel_world_size + self.tp_size = get_tensor_model_parallel_world_size() + + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: - params_dict = { - "weight": torch.empty(output_size, input_size, dtype=torch.int8) - } + """Create weight parameters. + + For new quantization version (double int4 pack into int8), the output dimension + is compressed by factor 2 (e.g., [2048, 3072] -> [1024, 3072]). The returned + dict includes "_packed_dim" and "_packed_factor" for vLLM's weight loader. + """ + params_dict = {} + + if self.new_quant_version: + # double int4 pack into int8: output dimension is compressed + pack_factor = 2 + actual_output_size = output_size // pack_factor + params_dict["weight"] = torch.empty(actual_output_size, + input_size, + dtype=torch.int8) + # Add packing information for vLLM's weight_loader + params_dict["_packed_dim"] = 0 + params_dict["_packed_factor"] = pack_factor + else: + params_dict["weight"] = torch.empty(output_size, + input_size, + dtype=torch.int8) + return params_dict @staticmethod @@ -59,8 +83,14 @@ class AscendW4A8DynamicLinearMethod: params_dtype: torch.dtype) -> Dict[str, Any]: return {} - def get_pergroup_param(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: + """ + Create per-group quantization parameters. + """ params_dict = {} params_dict["weight_scale"] = torch.empty(output_size, 1, @@ -76,17 +106,52 @@ class AscendW4A8DynamicLinearMethod: input_size // self.group_size, dtype=params_dtype) + + # NOTE: In w4a8 quantization implementation, + # for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16], + # others are [output_size, 1] + if self.new_quant_version: + scale_bias_dim = 16 if layer_type == "row" else 1 + + params_dict["scale_bias"] = torch.empty(output_size, + scale_bias_dim, + dtype=torch.float32) return params_dict @staticmethod - def process_scale_second(weight: torch.Tensor, scale: torch.Tensor, - per_group_scale: torch.Tensor): + def process_scale_second(weight: torch.Tensor, + scale: torch.Tensor, + per_group_scale: torch.Tensor, + is_new_quant: bool = False): + """ + Process the scale for second-level quantization. 
+ + Args: + weight: weight tensor [k, n] (in new version, n is already compressed to n/2) + scale: first-level quantization scale [output_size] + per_group_scale: second-level per-group quantization scale [group_num, n_scale] + is_new_quant: whether it's the new quantization version (weight already compressed) + + Returns: + (antiquant_scale, bias): dequantization scale and bias (bias=None for new version) + """ k, n = weight.shape - group_num, n = per_group_scale.shape - weight_high = weight.to(torch.float32).reshape( - group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) - weight_high = weight_high.reshape(k, n) - bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + group_num, n_scale = per_group_scale.shape + + if is_new_quant: + # Restore logical dimension for compressed weight + n = n * 2 + + bias = None + if not is_new_quant: + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + # NOTE: scale_bias is not used currently + # because in msmodelslim w4a8 uses symmetric quantization + + # TODO: support potential future asymmetric quantization antiquant_scale = (scale * per_group_scale).reshape(group_num, n) return antiquant_scale.npu(), bias @@ -114,11 +179,34 @@ class AscendW4A8DynamicLinearMethod: layer.weight.data, layer.weight_scale.data, layer.weight_scale_second.data.transpose(0, 1).contiguous(), + is_new_quant=self.new_quant_version, ) - param = torch.nn.Parameter(scale_bias, requires_grad=False) - layer.register_parameter("weight_scale_bias", param) - layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( - layer.weight.data.to(torch.int32)) + + if self.new_quant_version: + # Process the loaded data based on layer type + if hasattr(layer, "scale_bias"): + if layer.scale_bias.data.shape[1] == 1: + layer.scale_bias.data = layer.scale_bias.data.flatten() + else: + layer.scale_bias.data = layer.scale_bias.data.contiguous() + else: + if scale_bias is not None: + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + + # Convert to NPU-specific int4pack format + if self.new_quant_version: + # weights on disk are already in packed int4 format + # pack 4 int8(int4*2) to int32 + assert layer.weight.data.shape[-1] % 4 == 0, \ + f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}" + layer.weight.data = layer.weight.data.view( + torch.int32).contiguous() + else: + # weights are not compressed + # need to be packed via npu_convert_weight_to_int4pack + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32)) class AscendW4A8DynamicFusedMoEMethod: diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 615bece..9df640c 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -86,8 +86,11 @@ class AscendW8A8LinearMethod: dtype=params_dtype) return params_dict - def get_pergroup_param(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: return {} @staticmethod diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 8055b53..0f96e8c 100644 --- 
a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -62,8 +62,11 @@ class AscendW8A8DynamicLinearMethod: dtype=params_dtype) return params_dict - def get_pergroup_param(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: return {} @staticmethod diff --git a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py index 732e943..c61ddf3 100644 --- a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py @@ -39,18 +39,34 @@ class TorchairAscendW4A8DynamicLinearMethod: def __init__(self): self.transpose_weight = True - try: - self.group_size = get_current_vllm_config( - ).quant_config.quant_description.get("group_size", 256) - except AttributeError: - self.group_size = 256 - @staticmethod - def get_weight(input_size: int, output_size: int, + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 256) + quant_version = vllm_config.quant_config.quant_description.get( + "version", "0") + self.new_quant_version = quant_version == "1.0.0" + + from vllm.distributed import get_tensor_model_parallel_world_size + self.tp_size = get_tensor_model_parallel_world_size() + + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: - params_dict = { - "weight": torch.empty(output_size, input_size, dtype=torch.int8) - } + params_dict = {} + + if self.new_quant_version: + pack_factor = 2 + actual_output_size = output_size // pack_factor + params_dict["weight"] = torch.empty(actual_output_size, + input_size, + dtype=torch.int8) + params_dict["_packed_dim"] = 0 + params_dict["_packed_factor"] = pack_factor + else: + params_dict["weight"] = torch.empty(output_size, + input_size, + dtype=torch.int8) + return params_dict @staticmethod @@ -62,8 +78,11 @@ class TorchairAscendW4A8DynamicLinearMethod: params_dtype: torch.dtype) -> Dict[str, Any]: return {} - def get_pergroup_param(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: params_dict = {} params_dict["weight_scale"] = torch.empty(output_size, 1, @@ -79,17 +98,32 @@ class TorchairAscendW4A8DynamicLinearMethod: input_size // self.group_size, dtype=params_dtype) + + if self.new_quant_version: + scale_bias_dim = 16 if layer_type == "row" else 1 + params_dict["scale_bias"] = torch.empty(output_size, + scale_bias_dim, + dtype=torch.float32) return params_dict @staticmethod - def process_scale_second(weight: torch.Tensor, scale: torch.Tensor, - per_group_scale: torch.Tensor): + def process_scale_second(weight: torch.Tensor, + scale: torch.Tensor, + per_group_scale: torch.Tensor, + is_new_quant: bool = False): k, n = weight.shape - group_num, n = per_group_scale.shape - weight_high = weight.to(torch.float32).reshape( - group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) - weight_high = weight_high.reshape(k, n) - bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + group_num, n_scale = per_group_scale.shape + + if is_new_quant: + n = n * 2 + + bias = None + if not is_new_quant: + 
weight_high = weight.to(torch.float32).reshape( + group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + antiquant_scale = (scale * per_group_scale).reshape(group_num, n) return antiquant_scale.npu(), bias @@ -117,11 +151,28 @@ class TorchairAscendW4A8DynamicLinearMethod: layer.weight.data, layer.weight_scale.data, layer.weight_scale_second.data.transpose(0, 1).contiguous(), + is_new_quant=self.new_quant_version, ) - param = torch.nn.Parameter(scale_bias, requires_grad=False) - layer.register_parameter("weight_scale_bias", param) - layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( - layer.weight.data.to(torch.int32)) + + if self.new_quant_version: + if hasattr(layer, "scale_bias"): + if layer.scale_bias.data.shape[1] == 1: + layer.scale_bias.data = layer.scale_bias.data.flatten() + else: + layer.scale_bias.data = layer.scale_bias.data.contiguous() + else: + if scale_bias is not None: + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + + if self.new_quant_version: + assert layer.weight.data.shape[-1] % 4 == 0, \ + f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}" + layer.weight.data = layer.weight.data.view( + torch.int32).contiguous() + else: + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32)) class TorchairAscendW4A8DynamicFusedMoEMethod: diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index 23d59a8..bc0a8d3 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -793,8 +793,11 @@ class TorchairAscendW8A8DynamicLinearMethod: dtype=params_dtype) return params_dict - def get_pergroup_param(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: return {} @staticmethod