diff --git a/tests/ut/quantization/test_w4a4_flatquant_dynamic.py b/tests/ut/quantization/test_w4a4_flatquant_dynamic.py new file mode 100644 index 0000000..208ca7d --- /dev/null +++ b/tests/ut/quantization/test_w4a4_flatquant_dynamic.py @@ -0,0 +1,284 @@ +import unittest +from unittest.mock import MagicMock, patch + +import torch +import torch.nn as nn + +from vllm_ascend.quantization.w4a4_flatquant_dynamic import ( + AscendW4A4FlatQuantDynamicLinearMethod, get_decompose_dim, + pack_int4_to_int32, pack_int4_weights) + + +class TestW4A4FlatQuantDynamic(unittest.TestCase): + """ + Unit test suite for AscendW4A4FlatQuantDynamicLinearMethod and its helper functions. + """ + + def setUp(self): + """Set up the test environment before each test.""" + self.method = AscendW4A4FlatQuantDynamicLinearMethod() + self.output_size = 64 + self.input_size = 768 # 768 = 24 * 32, divisible by 8 + self.params_dtype = torch.float16 + + ## Test Helper Functions + ## -------------------- + + def test_get_decompose_dim(self): + """ + Tests the get_decompose_dim function with various inputs. + """ + self.assertEqual(get_decompose_dim(1024), (32, 32)) + self.assertEqual(get_decompose_dim(768), (24, 32)) + self.assertEqual(get_decompose_dim(100), (10, 10)) + self.assertEqual(get_decompose_dim(99), (9, 11)) + + def test_pack_int4_to_int32(self): + """ + Tests manual packing of an int4 tensor into an int32 tensor. + """ + int4_tensor = torch.arange(-8, 8, dtype=torch.int8).view(2, 8) + expected_packed = torch.tensor([[1985229328], [-19088744]], + dtype=torch.int32) + packed_tensor = pack_int4_to_int32(int4_tensor) + self.assertTrue(torch.equal(packed_tensor, expected_packed)) + + def test_pack_int4_to_int32_value_error(self): + """ + Tests that pack_int4_to_int32 raises ValueError for invalid input shapes. 
+ """ + invalid_tensor = torch.zeros((1, 7), dtype=torch.int8) + with self.assertRaisesRegex( + ValueError, "The last dimension must be a multiple of 8."): + pack_int4_to_int32(invalid_tensor) + + @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu') + def test_pack_int4_weights_npu_success(self, mock_torch_npu): + """ + Tests weight packing using the mocked NPU kernel. + """ + weight_tensor = torch.randn(self.output_size, self.input_size) + mock_packed_tensor = torch.randint( + 0, + 100, (self.output_size, self.input_size // 8), + dtype=torch.int32) + mock_npu_tensor = MagicMock() + mock_npu_tensor.to.return_value = mock_packed_tensor + mock_torch_npu.npu_convert_weight_to_int4pack.return_value = mock_npu_tensor + with patch('torch.Tensor.npu', return_value=weight_tensor): + result = pack_int4_weights(weight_tensor) + + mock_torch_npu.npu_convert_weight_to_int4pack.assert_called_once() + self.assertTrue(torch.equal(result, mock_packed_tensor)) + + @patch( + 'vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_to_int32') + @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu') + def test_pack_int4_weights_fallback(self, mock_torch_npu, + mock_pack_manual): + """ + Tests the fallback mechanism when the NPU kernel fails. 
+ """ + with patch('torch.Tensor.npu', + side_effect=Exception("NPU not available")): + weight_tensor = torch.randn(self.output_size, self.input_size) + mock_pack_manual.return_value = "fallback success" + result = pack_int4_weights(weight_tensor) + mock_torch_npu.npu_convert_weight_to_int4pack.assert_not_called() + mock_pack_manual.assert_called_once_with(weight_tensor) + self.assertEqual(result, "fallback success") + + ## Test AscendW4A4FlatQuantDynamicLinearMethod Class + ## -------------------------------------------------- + + def test_get_weight(self): + """Tests the get_weight static method for correct output.""" + params = self.method.get_weight(self.input_size, self.output_size, + self.params_dtype) + self.assertIn("weight", params) + self.assertEqual(params["weight"].shape, + (self.output_size, self.input_size)) + self.assertEqual(params["weight"].dtype, torch.int8) + self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.input_size, + self.input_size) + self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.output_size, + self.output_size) + + def test_get_weight_value_error(self): + """Tests that get_weight raises ValueError for invalid input_size.""" + with self.assertRaisesRegex(ValueError, "must be divisible by 8"): + self.method.get_weight(127, self.output_size, self.params_dtype) + + def test_get_pertensor_param(self): + """Tests the get_pertensor_param static method.""" + self.method.get_weight(self.input_size, self.output_size, + self.params_dtype) + params = self.method.get_pertensor_param(self.params_dtype) + left_dim, right_dim = get_decompose_dim(self.input_size) + self.assertIn("left_trans", params) + self.assertIn("right_trans", params) + self.assertIn("clip_ratio", params) + self.assertEqual(params["left_trans"].shape, (left_dim, left_dim)) + self.assertEqual(params["right_trans"].shape, (right_dim, right_dim)) + self.assertEqual(params["clip_ratio"].shape, (1, )) + self.assertEqual(params["left_trans"].dtype, self.params_dtype) + 
self.assertEqual(params["clip_ratio"].dtype, torch.float32) + + def test_get_perchannel_param(self): + """Tests the get_perchannel_param static method.""" + params = self.method.get_perchannel_param(self.output_size, + self.params_dtype) + self.assertIn("weight_scale", params) + self.assertIn("weight_offset", params) + self.assertEqual(params["weight_scale"].shape, (self.output_size, 1)) + self.assertEqual(params["weight_offset"].shape, (self.output_size, 1)) + self.assertEqual(params["weight_scale"].dtype, torch.float32) + self.assertEqual(params["weight_offset"].dtype, torch.float32) + + def test_get_pergroup_param(self): + """Tests the get_pergroup_param method.""" + params = self.method.get_pergroup_param(self.input_size, + self.output_size, + self.params_dtype) + self.assertEqual(params, {}) + + def _prepare_apply_mocks_and_layer(self, batch_size): + """Helper to create a mock layer and input tensor for apply tests.""" + layer = nn.Module() + m, n = get_decompose_dim(self.input_size) + layer.left_trans = torch.randn(m, m, dtype=self.params_dtype) + layer.right_trans = torch.randn(n, n, dtype=self.params_dtype) + layer.aclnn_clip_ratio = 0.95 + layer.weight_packed = torch.randint( + -8, 7, (self.output_size, self.input_size // 8), dtype=torch.int32) + layer.weight_scale = torch.randn(self.output_size, + 1, + dtype=torch.float32) + x = torch.randn(batch_size, self.input_size, dtype=self.params_dtype) + return layer, x, m, n + + @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu') + def test_apply_small_batch(self, mock_torch_npu): + """Tests the apply method with a batch size smaller than MAX_BATCH_SIZE.""" + batch_size = 128 + layer, x, m, n = self._prepare_apply_mocks_and_layer(batch_size) + mock_quant_x = torch.randint(0, + 255, (batch_size, self.input_size // 8), + dtype=torch.int32) + mock_act_scale = torch.randn(batch_size, 1, dtype=torch.float32) + mock_torch_npu.npu_kronecker_quant.return_value = (mock_quant_x.view( + batch_size, m, n // 
8), mock_act_scale) + mock_output = torch.randn(batch_size, + self.output_size, + dtype=self.params_dtype) + mock_torch_npu.npu_quant_matmul.return_value = mock_output + bias = torch.randn(self.output_size, dtype=self.params_dtype) + output = self.method.apply(layer, x, bias=bias) + mock_torch_npu.npu_kronecker_quant.assert_called_once() + mock_torch_npu.npu_quant_matmul.assert_called_once() + self.assertTrue( + torch.allclose(output, mock_output + bias.to(self.params_dtype))) + self.assertEqual(output.shape, (batch_size, self.output_size)) + + @patch( + 'vllm_ascend.quantization.w4a4_flatquant_dynamic.KRONECKER_QUANT_MAX_BATCH_SIZE', + 10) + @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu') + def test_apply_large_batch(self, mock_torch_npu): + """Tests the apply method with a batch size larger than MAX_BATCH_SIZE.""" + batch_size = 25 + layer, x, m, n = self._prepare_apply_mocks_and_layer(batch_size) + mock_quant_x = torch.randint(0, + 255, (batch_size, self.input_size // 8), + dtype=torch.int32) + mock_act_scale = torch.randn(batch_size, 1, dtype=torch.float32) + mock_torch_npu.npu_kronecker_quant.side_effect = [ + (mock_quant_x[:10].view(10, m, n // 8), mock_act_scale[:10]), + (mock_quant_x[10:20].view(10, m, n // 8), mock_act_scale[10:20]), + (mock_quant_x[20:].view(5, m, n // 8), mock_act_scale[20:]), + ] + mock_output = torch.randn(batch_size, + self.output_size, + dtype=self.params_dtype) + mock_torch_npu.npu_quant_matmul.return_value = mock_output + output = self.method.apply(layer, x, bias=None) + self.assertEqual(mock_torch_npu.npu_kronecker_quant.call_count, 3) + mock_torch_npu.npu_quant_matmul.assert_called_once() + self.assertTrue(torch.equal(output, mock_output)) + self.assertEqual(output.shape, (batch_size, self.output_size)) + + def test_apply_dimension_mismatch_error(self): + """Tests that apply raises ValueError on transform matrix dimension mismatch.""" + layer, x, _, _ = self._prepare_apply_mocks_and_layer(16) + 
layer.left_trans = torch.randn(20, 20) + layer.right_trans = torch.randn(30, 30) # 20 * 30 != 768 + with self.assertRaisesRegex( + ValueError, "FlatQuant transform matrices dimension mismatch"): + self.method.apply(layer, x) + + @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights') + def test_process_weights_after_loading(self, mock_pack_weights): + """Tests weight processing after loading, without transpose.""" + layer = nn.Module() + layer.weight = torch.randint(-8, + 7, (self.output_size, self.input_size), + dtype=torch.int8) + layer.weight_scale = torch.randn(self.output_size, + 1, + dtype=torch.bfloat16) + layer.weight_offset = torch.randn(self.output_size, + 1, + dtype=torch.bfloat16) + layer.left_trans = torch.randn(24, 24) + layer.right_trans = torch.randn(32, 32) + layer.clip_ratio = torch.tensor([0.9]) + mock_packed = torch.randint(0, + 100, + (self.output_size, self.input_size // 8), + dtype=torch.int32) + mock_pack_weights.return_value = mock_packed + self.method.transpose_weight = False + self.method.process_weights_after_loading(layer) + mock_pack_weights.assert_called_once() + self.assertFalse(hasattr(layer, 'weight')) + self.assertTrue(hasattr(layer, 'weight_packed')) + self.assertTrue(torch.equal(layer.weight_packed.data, mock_packed)) + self.assertEqual(layer.weight_scale.dtype, torch.float32) + self.assertEqual(layer.weight_offset.dtype, torch.float32) + self.assertEqual(layer.clip_ratio.dtype, torch.float32) + self.assertTrue(abs(layer.aclnn_clip_ratio - 0.9) < 0.01) + self.assertEqual(layer.left_trans.shape, (24, 24)) + self.assertTrue(layer.left_trans.is_contiguous()) + + @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights') + def test_process_weights_after_loading_with_transpose( + self, mock_pack_weights): + """Tests weight processing after loading, with transpose.""" + layer = nn.Module() + layer.weight = torch.randint(-8, + 7, (self.output_size, self.input_size), + dtype=torch.int8) + 
layer.weight_scale = torch.randn(self.output_size, + 1, + dtype=torch.bfloat16) + layer.weight_offset = torch.randn(self.output_size, + 1, + dtype=torch.bfloat16) + layer.left_trans = torch.randn(24, 24) + layer.right_trans = torch.randn(32, 32) + layer.clip_ratio = torch.tensor([0.9]) + mock_packed = torch.randint(0, + 100, + (self.output_size, self.input_size // 8), + dtype=torch.int32) + mock_pack_weights.return_value = mock_packed + self.method.transpose_weight = True + self.method.process_weights_after_loading(layer) + self.assertTrue(hasattr(layer, 'weight_packed')) + self.assertEqual(layer.weight_packed.shape, + (self.input_size // 8, self.output_size)) + self.assertTrue(layer.weight_packed.is_contiguous()) + + +if __name__ == '__main__': + unittest.main(argv=['first-arg-is-ignored'], exit=False) diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index dc5845a..6d914c0 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Optional, Type from vllm.logger import logger +from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod) from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, @@ -14,6 +15,9 @@ ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { "linear": AscendW4A8DynamicLinearMethod, "moe": AscendW4A8DynamicFusedMoEMethod, }, + "W4A4_FLATQUANT_DYNAMIC": { + "linear": AscendW4A4FlatQuantDynamicLinearMethod, + }, "W8A8": { "linear": AscendW8A8LinearMethod, "moe": AscendW8A8FusedMoEMethod, diff --git a/vllm_ascend/quantization/w4a4_flatquant_dynamic.py b/vllm_ascend/quantization/w4a4_flatquant_dynamic.py new file mode 100644 index 0000000..4398ca5 --- /dev/null +++ b/vllm_ascend/quantization/w4a4_flatquant_dynamic.py @@ -0,0 +1,223 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
import warnings
from typing import Any, Dict, Optional

import torch
import torch_npu

# Activation quantization is chunked so that each npu_kronecker_quant call
# sees at most this many tokens (presumably a kernel-side batch limit —
# TODO confirm against the torch_npu operator documentation).
KRONECKER_QUANT_MAX_BATCH_SIZE = 8192


def pack_int4_to_int32(int4_tensor: torch.Tensor) -> torch.Tensor:
    """
    Packs a tensor of 4-bit integers into a tensor of 32-bit integers.

    This function serves as a manual, device-agnostic fallback when a more
    optimized hardware-specific kernel (like for an NPU) is not available.
    It processes the tensor along its last dimension.

    Args:
        int4_tensor: A tensor with a dtype that can be represented as int4.
            The size of its last dimension must be a multiple of 8.

    Returns:
        A new tensor of dtype torch.int32 where every 8 values from the
        original tensor's last dimension are packed into a single int32 value.

    Raises:
        ValueError: If the last dimension is not a multiple of 8.
    """
    if int4_tensor.shape[-1] % 8 != 0:
        raise ValueError("The last dimension must be a multiple of 8.")
    # Clamp to the signed int4 range [-8, 7], then bias into unsigned
    # nibbles [0, 15]. NOTE: the uint8 cast of negative values relies on
    # modular wrap-around (e.g. -8 -> 248, +8 -> 0 after the add), which
    # still yields the correct nibble value.
    int4_clamped = torch.clamp(int4_tensor, -8, 7)
    uint4_tensor = int4_clamped.to(torch.uint8) + 8
    original_shape = uint4_tensor.shape
    packed_shape = list(original_shape[:-1]) + [original_shape[-1] // 8]
    uint4_reshaped = uint4_tensor.view(*original_shape[:-1], -1, 8)
    packed_tensor = torch.zeros(*packed_shape,
                                dtype=torch.int32,
                                device=uint4_tensor.device)
    # Little-nibble-first layout: element i occupies bits [4*i, 4*i + 4).
    # Nibble 7 may overflow into the sign bit; int32 wrap-around is the
    # intended encoding.
    for i in range(8):
        packed_tensor += (uint4_reshaped[..., i].to(torch.int32) << (i * 4))
    return packed_tensor


def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor:
    """
    Packs a weight tensor from int4 to int32, using an NPU-accelerated
    kernel if available, otherwise falling back to a manual implementation.

    The result is returned on the original device of ``weight_tensor``.
    """
    try:
        original_device = weight_tensor.device
        weight_tensor_npu = weight_tensor.npu()
        weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack(
            weight_tensor_npu.to(torch.int32), inner_k_tiles=1)
        return weight_int4_packed.to(original_device)
    except Exception as e:
        # Library code should not print to stdout; surface the degradation
        # through the warnings machinery so callers can filter or escalate it.
        warnings.warn(
            f"NPU kernel 'npu_convert_weight_to_int4pack' is not available. "
            f"Falling back to a manual packing implementation. Error: {e}",
            RuntimeWarning,
            stacklevel=2)
        return pack_int4_to_int32(weight_tensor)


def get_decompose_dim(n):
    """
    Decompose ``n`` into two factors ``(a - b, a + b)`` with product ``n``,
    as close to ``sqrt(n)`` as possible.

    Searches for the smallest ``a >= ceil(sqrt(n))`` such that ``a*a - n``
    is a perfect square ``b*b``; then ``(a - b) * (a + b) == n`` (Fermat
    factorization). Used to size the FlatQuant Kronecker transform matrices.
    """
    a = int(math.sqrt(n))
    if a * a < n:
        a += 1
    while True:
        tmp = a * a - n
        b = int(math.sqrt(tmp))
        if b * b == tmp:
            break
        a += 1
    return a - b, a + b


class AscendW4A4FlatQuantDynamicLinearMethod:
    """Linear method for Ascend W4A4_FLATQUANT_DYNAMIC.

    This class implements W4A4 quantization with FlatQuant approach and dynamic activation quantization.
    - Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8 and packed to int32 during loading
    - Activation: 4-bit dynamic quantization with FlatQuant transform matrices (left_trans, right_trans) for distribution smoothing
    - Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded from external weights
    """
    # Class-level state recorded by get_weight() so the static
    # get_pertensor_param() can size the transform matrices later.
    input_size = 0
    output_size = 0

    def __init__(self):
        self.transpose_weight = False
        self.sym = True

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the int8 weight placeholder and record the layer sizes.

        Raises:
            ValueError: If ``input_size`` is not divisible by 8 (required
                for int4-to-int32 packing).
        """
        if input_size % 8 != 0:
            raise ValueError(
                f"input_size ({input_size}) must be divisible by 8 for int4 packing"
            )
        AscendW4A4FlatQuantDynamicLinearMethod.input_size = input_size
        AscendW4A4FlatQuantDynamicLinearMethod.output_size = output_size
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return per-tensor FlatQuant parameters (transforms + clip ratio).

        Must be called after get_weight(), which records the input size the
        transform dimensions are derived from.
        """
        params_dict = {}
        left_trans_dim, right_trans_dim = get_decompose_dim(
            AscendW4A4FlatQuantDynamicLinearMethod.input_size)
        params_dict["left_trans"] = torch.empty(left_trans_dim,
                                                left_trans_dim,
                                                dtype=params_dtype)
        params_dict["right_trans"] = torch.empty(right_trans_dim,
                                                 right_trans_dim,
                                                 dtype=params_dtype)
        params_dict["clip_ratio"] = torch.empty(1, dtype=torch.float32)
        return params_dict

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Return per-output-channel weight scale/offset placeholders."""
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=torch.float32)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=torch.float32)
        return params_dict

    def get_pergroup_param(self, input_size: int, output_size: int,
                           params_dtype: torch.dtype) -> Dict[str, Any]:
        """No per-group parameters for this scheme."""
        return {}

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        """Quantize activations with the Kronecker transform and run the
        int4 matmul.

        Args:
            layer: Module carrying left_trans/right_trans, aclnn_clip_ratio,
                weight_packed, and weight_scale (as set up by
                process_weights_after_loading).
            x: Activations of shape (..., in_features).
            bias: Optional bias added in the original dtype.
            tp_rank: Unused here; kept for interface compatibility.

        Returns:
            Output of shape (..., out_features) in the dtype of ``x``.

        Raises:
            ValueError: If the transform matrix dimensions do not multiply
                to the activation's last dimension.
        """
        original_dtype = x.dtype
        input_shape = x.shape
        in_features = input_shape[-1]
        M = layer.left_trans.shape[0]
        N = layer.right_trans.shape[0]
        if M * N != in_features:
            raise ValueError(
                f"FlatQuant transform matrices dimension mismatch: M({M}) * N({N}) != in_features({in_features})"
            )
        left_trans_matched = layer.left_trans.to(original_dtype)
        right_trans_matched = layer.right_trans.to(original_dtype)
        # View tokens as (tokens, M, N) tiles for the Kronecker quant kernel.
        x_reshaped = x.view(-1, M, N)
        batch_tokens = x_reshaped.shape[0]
        if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE:
            x_quantized_int4, activation_scale = torch_npu.npu_kronecker_quant(
                x_reshaped,
                left_trans_matched,
                right_trans_matched,
                clip_ratio=layer.aclnn_clip_ratio,
                dst_dtype=torch.int32)
        else:
            # Chunk oversized batches and concatenate the per-chunk results.
            x_quantized_int4_list = []
            activation_scale_list = []
            for start_idx in range(0, batch_tokens,
                                   KRONECKER_QUANT_MAX_BATCH_SIZE):
                end_idx = min(start_idx + KRONECKER_QUANT_MAX_BATCH_SIZE,
                              batch_tokens)
                x_batch = x_reshaped[start_idx:end_idx]
                x_quantized_batch, activation_scale_batch = torch_npu.npu_kronecker_quant(
                    x_batch,
                    left_trans_matched,
                    right_trans_matched,
                    clip_ratio=layer.aclnn_clip_ratio,
                    dst_dtype=torch.int32)
                x_quantized_int4_list.append(x_quantized_batch)
                activation_scale_list.append(activation_scale_batch)
            x_quantized_int4 = torch.cat(x_quantized_int4_list, dim=0)
            activation_scale = torch.cat(activation_scale_list, dim=0)
        # Flatten the packed int4 payload back to (tokens, in_features // 8).
        x_quantized_reshaped = x_quantized_int4.view(-1, M * N // 8)
        pertoken_scale = activation_scale.view(-1).to(torch.float32)
        output = torch_npu.npu_quant_matmul(x_quantized_reshaped,
                                            layer.weight_packed.t(),
                                            layer.weight_scale.view(-1).to(
                                                torch.float32),
                                            pertoken_scale=pertoken_scale,
                                            bias=None,
                                            output_dtype=original_dtype)
        output = output.view(*input_shape[:-1], -1)
        if bias is not None:
            output = output + bias.to(original_dtype)
        return output

    def process_weights_after_loading(self, layer):
        """Pack the loaded int8 weight to int4-in-int32 and normalize the
        auxiliary parameters for the aclnn kernels.

        Replaces layer.weight with layer.weight_packed, casts scale/offset/
        clip_ratio to float32, pre-transposes left_trans, and caches the
        clip ratio as a Python float in layer.aclnn_clip_ratio.
        """
        weight_packed = pack_int4_weights(layer.weight.data)
        if self.transpose_weight:
            weight_packed = weight_packed.transpose(0, 1).contiguous()
        layer.register_parameter(
            'weight_packed',
            torch.nn.Parameter(weight_packed, requires_grad=False))
        # The unpacked int8 weight is no longer needed; free it.
        del layer.weight
        layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.to(torch.float32)
        # Pre-transpose left_trans once here so apply() can pass it directly
        # to npu_kronecker_quant.
        layer.left_trans = torch.nn.Parameter(
            layer.left_trans.data.t().contiguous())
        layer.right_trans = torch.nn.Parameter(layer.right_trans.data)
        layer.clip_ratio = torch.nn.Parameter(
            layer.clip_ratio.data.to(torch.float32))
        layer.aclnn_clip_ratio = layer.clip_ratio.item()