[Refactor] Clean up w4a4_flatquant_dynamic implementation (#3440)
Cleans up the initial implementation of `w4a4_flatquant_dynamic` for better readability and maintainability.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
|
||||
from vllm_ascend.quantization.w4a4_flatquant_dynamic import (
|
||||
AscendW4A4FlatQuantDynamicLinearMethod, get_decompose_dim,
|
||||
pack_int4_to_int32, pack_int4_weights)
|
||||
pack_int4_weights)
|
||||
|
||||
|
||||
class TestW4A4FlatQuantDynamic(unittest.TestCase):
|
||||
@@ -33,25 +33,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
|
||||
self.assertEqual(get_decompose_dim(100), (10, 10))
|
||||
self.assertEqual(get_decompose_dim(99), (9, 11))
|
||||
|
||||
def test_pack_int4_to_int32(self):
|
||||
"""
|
||||
Tests manual packing of an int4 tensor into an int32 tensor.
|
||||
"""
|
||||
int4_tensor = torch.arange(-8, 8, dtype=torch.int8).view(2, 8)
|
||||
expected_packed = torch.tensor([[1985229328], [-19088744]],
|
||||
dtype=torch.int32)
|
||||
packed_tensor = pack_int4_to_int32(int4_tensor)
|
||||
self.assertTrue(torch.equal(packed_tensor, expected_packed))
|
||||
|
||||
def test_pack_int4_to_int32_value_error(self):
|
||||
"""
|
||||
Tests that pack_int4_to_int32 raises ValueError for invalid input shapes.
|
||||
"""
|
||||
invalid_tensor = torch.zeros((1, 7), dtype=torch.int8)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The last dimension must be a multiple of 8."):
|
||||
pack_int4_to_int32(invalid_tensor)
|
||||
|
||||
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
|
||||
def test_pack_int4_weights_npu_success(self, mock_torch_npu):
|
||||
"""
|
||||
@@ -71,23 +52,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
|
||||
mock_torch_npu.npu_convert_weight_to_int4pack.assert_called_once()
|
||||
self.assertTrue(torch.equal(result, mock_packed_tensor))
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_to_int32')
|
||||
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
|
||||
def test_pack_int4_weights_fallback(self, mock_torch_npu,
|
||||
mock_pack_manual):
|
||||
"""
|
||||
Tests the fallback mechanism when the NPU kernel fails.
|
||||
"""
|
||||
with patch('torch.Tensor.npu',
|
||||
side_effect=Exception("NPU not available")):
|
||||
weight_tensor = torch.randn(self.output_size, self.input_size)
|
||||
mock_pack_manual.return_value = "fallback success"
|
||||
result = pack_int4_weights(weight_tensor)
|
||||
mock_torch_npu.npu_convert_weight_to_int4pack.assert_not_called()
|
||||
mock_pack_manual.assert_called_once_with(weight_tensor)
|
||||
self.assertEqual(result, "fallback success")
|
||||
|
||||
## Test AscendW4A4FlatQuantDynamicLinearMethod Class
|
||||
## --------------------------------------------------
|
||||
|
||||
@@ -101,8 +65,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
|
||||
self.assertEqual(params["weight"].dtype, torch.int8)
|
||||
self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.input_size,
|
||||
self.input_size)
|
||||
self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.output_size,
|
||||
self.output_size)
|
||||
|
||||
def test_get_weight_value_error(self):
|
||||
"""Tests that get_weight raises ValueError for invalid input_size."""
|
||||
|
||||
Reference in New Issue
Block a user