[Refactor] Clean up w4a4_flatquant_dynamic implementation (#3440)

Cleans up the initial implementation of `w4a4_flatquant_dynamic` for
better readability and maintainability.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
Slightwind
2025-10-17 23:53:19 +08:00
committed by GitHub
parent 21769e8f44
commit 07ca1b9b78
2 changed files with 50 additions and 121 deletions

View File

@@ -6,7 +6,7 @@ import torch.nn as nn
from vllm_ascend.quantization.w4a4_flatquant_dynamic import (
AscendW4A4FlatQuantDynamicLinearMethod, get_decompose_dim,
pack_int4_to_int32, pack_int4_weights)
pack_int4_weights)
class TestW4A4FlatQuantDynamic(unittest.TestCase):
@@ -33,25 +33,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
self.assertEqual(get_decompose_dim(100), (10, 10))
self.assertEqual(get_decompose_dim(99), (9, 11))
def test_pack_int4_to_int32(self):
"""
Tests manual packing of an int4 tensor into an int32 tensor.
"""
int4_tensor = torch.arange(-8, 8, dtype=torch.int8).view(2, 8)
expected_packed = torch.tensor([[1985229328], [-19088744]],
dtype=torch.int32)
packed_tensor = pack_int4_to_int32(int4_tensor)
self.assertTrue(torch.equal(packed_tensor, expected_packed))
def test_pack_int4_to_int32_value_error(self):
"""
Tests that pack_int4_to_int32 raises ValueError for invalid input shapes.
"""
invalid_tensor = torch.zeros((1, 7), dtype=torch.int8)
with self.assertRaisesRegex(
ValueError, "The last dimension must be a multiple of 8."):
pack_int4_to_int32(invalid_tensor)
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
def test_pack_int4_weights_npu_success(self, mock_torch_npu):
"""
@@ -71,23 +52,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
mock_torch_npu.npu_convert_weight_to_int4pack.assert_called_once()
self.assertTrue(torch.equal(result, mock_packed_tensor))
@patch(
'vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_to_int32')
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
def test_pack_int4_weights_fallback(self, mock_torch_npu,
mock_pack_manual):
"""
Tests the fallback mechanism when the NPU kernel fails.
"""
with patch('torch.Tensor.npu',
side_effect=Exception("NPU not available")):
weight_tensor = torch.randn(self.output_size, self.input_size)
mock_pack_manual.return_value = "fallback success"
result = pack_int4_weights(weight_tensor)
mock_torch_npu.npu_convert_weight_to_int4pack.assert_not_called()
mock_pack_manual.assert_called_once_with(weight_tensor)
self.assertEqual(result, "fallback success")
## Test AscendW4A4FlatQuantDynamicLinearMethod Class
## --------------------------------------------------
@@ -101,8 +65,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
self.assertEqual(params["weight"].dtype, torch.int8)
self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.input_size,
self.input_size)
self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.output_size,
self.output_size)
def test_get_weight_value_error(self):
"""Tests that get_weight raises ValueError for invalid input_size."""