diff --git a/tests/ut/quantization/test_w4a4_flatquant_dynamic.py b/tests/ut/quantization/test_w4a4_flatquant_dynamic.py index 208ca7d..ce8cfec 100644 --- a/tests/ut/quantization/test_w4a4_flatquant_dynamic.py +++ b/tests/ut/quantization/test_w4a4_flatquant_dynamic.py @@ -6,7 +6,7 @@ import torch.nn as nn from vllm_ascend.quantization.w4a4_flatquant_dynamic import ( AscendW4A4FlatQuantDynamicLinearMethod, get_decompose_dim, - pack_int4_to_int32, pack_int4_weights) + pack_int4_weights) class TestW4A4FlatQuantDynamic(unittest.TestCase): @@ -33,25 +33,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase): self.assertEqual(get_decompose_dim(100), (10, 10)) self.assertEqual(get_decompose_dim(99), (9, 11)) - def test_pack_int4_to_int32(self): - """ - Tests manual packing of an int4 tensor into an int32 tensor. - """ - int4_tensor = torch.arange(-8, 8, dtype=torch.int8).view(2, 8) - expected_packed = torch.tensor([[1985229328], [-19088744]], - dtype=torch.int32) - packed_tensor = pack_int4_to_int32(int4_tensor) - self.assertTrue(torch.equal(packed_tensor, expected_packed)) - - def test_pack_int4_to_int32_value_error(self): - """ - Tests that pack_int4_to_int32 raises ValueError for invalid input shapes. - """ - invalid_tensor = torch.zeros((1, 7), dtype=torch.int8) - with self.assertRaisesRegex( - ValueError, "The last dimension must be a multiple of 8."): - pack_int4_to_int32(invalid_tensor) - @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu') def test_pack_int4_weights_npu_success(self, mock_torch_npu): """ @@ -71,23 +52,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase): mock_torch_npu.npu_convert_weight_to_int4pack.assert_called_once() self.assertTrue(torch.equal(result, mock_packed_tensor)) - @patch( - 'vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_to_int32') - @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu') - def test_pack_int4_weights_fallback(self, mock_torch_npu, - mock_pack_manual): - """ - Tests the fallback mechanism when the NPU kernel fails. - """ - with patch('torch.Tensor.npu', - side_effect=Exception("NPU not available")): - weight_tensor = torch.randn(self.output_size, self.input_size) - mock_pack_manual.return_value = "fallback success" - result = pack_int4_weights(weight_tensor) - mock_torch_npu.npu_convert_weight_to_int4pack.assert_not_called() - mock_pack_manual.assert_called_once_with(weight_tensor) - self.assertEqual(result, "fallback success") - ## Test AscendW4A4FlatQuantDynamicLinearMethod Class ## -------------------------------------------------- @@ -101,8 +65,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase): self.assertEqual(params["weight"].dtype, torch.int8) self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.input_size, self.input_size) - self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.output_size, - self.output_size) def test_get_weight_value_error(self): """Tests that get_weight raises ValueError for invalid input_size.""" diff --git a/vllm_ascend/quantization/w4a4_flatquant_dynamic.py b/vllm_ascend/quantization/w4a4_flatquant_dynamic.py index 4398ca5..efc643c 100644 --- a/vllm_ascend/quantization/w4a4_flatquant_dynamic.py +++ b/vllm_ascend/quantization/w4a4_flatquant_dynamic.py @@ -15,61 +15,20 @@ # limitations under the License. # import math -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple import torch import torch_npu -KRONECKER_QUANT_MAX_BATCH_SIZE = 8192 - - -def pack_int4_to_int32(int4_tensor: torch.Tensor) -> torch.Tensor: - """ - Packs a tensor of 4-bit integers into a tensor of 32-bit integers. - - This function serves as a manual, device-agnostic fallback when a more - optimized hardware-specific kernel (like for an NPU) is not available. - It processes the tensor along its last dimension. - - Args: - int4_tensor: A tensor with a dtype that can be represented as int4. - The size of its last dimension must be a multiple of 8. - - Returns: - A new tensor of dtype torch.int32 where every 8 values from the - original tensor's last dimension are packed into a single int32 value. - """ - if int4_tensor.shape[-1] % 8 != 0: - raise ValueError("The last dimension must be a multiple of 8.") - int4_clamped = torch.clamp(int4_tensor, -8, 7) - uint4_tensor = int4_clamped.to(torch.uint8) + 8 - original_shape = uint4_tensor.shape - packed_shape = list(original_shape[:-1]) + [original_shape[-1] // 8] - uint4_reshaped = uint4_tensor.view(*original_shape[:-1], -1, 8) - packed_tensor = torch.zeros(*packed_shape, - dtype=torch.int32, - device=uint4_tensor.device) - for i in range(8): - packed_tensor += (uint4_reshaped[..., i].to(torch.int32) << (i * 4)) - return packed_tensor +KRONECKER_QUANT_MAX_BATCH_SIZE = 32768 def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor: - """ - Packs a weight tensor from int4 to int32, using an NPU-accelerated - kernel if available, otherwise falling back to a manual implementation. - """ - try: - original_device = weight_tensor.device - weight_tensor_npu = weight_tensor.npu() - weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack( - weight_tensor_npu.to(torch.int32), inner_k_tiles=1) - return weight_int4_packed.to(original_device) - except Exception as e: - print( - f"Warning: NPU kernel 'npu_convert_weight_to_int4pack' is not available. " - f"Falling back to a manual packing implementation. Error: {e}") - return pack_int4_to_int32(weight_tensor) + original_device = weight_tensor.device + weight_tensor_npu = weight_tensor.npu() + weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack( + weight_tensor_npu.to(torch.int32), inner_k_tiles=1) + return weight_int4_packed.to(original_device) def get_decompose_dim(n): @@ -85,6 +44,37 @@ def get_decompose_dim(n): return a - b, a + b +# TODO: This function is a temporary workaround for the npu_kronecker_quant operator, +# which has a limitation on the maximum batch size (dim0). This wrapper should be +# removed once the operator supports larger inputs natively. +def batched_kronecker_quant( + x: torch.Tensor, + left_trans: torch.Tensor, + right_trans: torch.Tensor, + clip_ratio: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + batch_tokens = x.shape[0] + if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE: + return torch_npu.npu_kronecker_quant(x, + left_trans, + right_trans, + clip_ratio=clip_ratio, + dst_dtype=torch.int32) + x_chunks = torch.split(x, KRONECKER_QUANT_MAX_BATCH_SIZE, dim=0) + processed_chunks = [ + torch_npu.npu_kronecker_quant(chunk, + left_trans, + right_trans, + clip_ratio=clip_ratio, + dst_dtype=torch.int32) + for chunk in x_chunks + ] + quantized_list, scale_list = zip(*processed_chunks) + x_quantized_int4 = torch.cat(quantized_list, dim=0) + activation_scale = torch.cat(scale_list, dim=0) + return x_quantized_int4, activation_scale + + class AscendW4A4FlatQuantDynamicLinearMethod: """Linear method for Ascend W4A4_FLATQUANT_DYNAMIC. @@ -94,7 +84,6 @@ class AscendW4A4FlatQuantDynamicLinearMethod: - Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded from external weights """ input_size = 0 - output_size = 0 def __init__(self): self.transpose_weight = False @@ -108,7 +97,6 @@ class AscendW4A4FlatQuantDynamicLinearMethod: f"input_size ({input_size}) must be divisible by 8 for int4 packing" ) AscendW4A4FlatQuantDynamicLinearMethod.input_size = input_size - AscendW4A4FlatQuantDynamicLinearMethod.output_size = output_size params_dict = { "weight": torch.empty(output_size, input_size, dtype=torch.int8) } @@ -156,42 +144,21 @@ class AscendW4A4FlatQuantDynamicLinearMethod: original_dtype = x.dtype input_shape = x.shape in_features = input_shape[-1] - M = layer.left_trans.shape[0] - N = layer.right_trans.shape[0] - if M * N != in_features: + left_dim = layer.left_trans.shape[0] + right_dim = layer.right_trans.shape[0] + if left_dim * right_dim != in_features: raise ValueError( - f"FlatQuant transform matrices dimension mismatch: M({M}) * N({N}) != in_features({in_features})" + f"FlatQuant transform matrices dimension mismatch: " + f"left_dim({left_dim}) * right_dim({right_dim}) != in_features({in_features})" ) left_trans_matched = layer.left_trans.to(original_dtype) right_trans_matched = layer.right_trans.to(original_dtype) - x_reshaped = x.view(-1, M, N) - batch_tokens = x_reshaped.shape[0] - if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE: - x_quantized_int4, activation_scale = torch_npu.npu_kronecker_quant( - x_reshaped, - left_trans_matched, - right_trans_matched, - clip_ratio=layer.aclnn_clip_ratio, - dst_dtype=torch.int32) - else: - x_quantized_int4_list = [] - activation_scale_list = [] - for start_idx in range(0, batch_tokens, - KRONECKER_QUANT_MAX_BATCH_SIZE): - end_idx = min(start_idx + KRONECKER_QUANT_MAX_BATCH_SIZE, - batch_tokens) - x_batch = x_reshaped[start_idx:end_idx] - x_quantized_batch, activation_scale_batch = torch_npu.npu_kronecker_quant( - x_batch, - left_trans_matched, - right_trans_matched, - clip_ratio=layer.aclnn_clip_ratio, - dst_dtype=torch.int32) - x_quantized_int4_list.append(x_quantized_batch) - activation_scale_list.append(activation_scale_batch) - x_quantized_int4 = torch.cat(x_quantized_int4_list, dim=0) - activation_scale = torch.cat(activation_scale_list, dim=0) - x_quantized_reshaped = x_quantized_int4.view(-1, M * N // 8) + x_reshaped = x.view(-1, left_dim, right_dim) + x_quantized_int4, activation_scale = batched_kronecker_quant( + x_reshaped, left_trans_matched, right_trans_matched, + layer.aclnn_clip_ratio) + x_quantized_reshaped = x_quantized_int4.view(-1, + left_dim * right_dim // 8) pertoken_scale = activation_scale.view(-1).to(torch.float32) output = torch_npu.npu_quant_matmul(x_quantized_reshaped, layer.weight_packed.t(),