[Feat][310P]: weight NZ support for quantized and unquantized linear layers (#6705)
NZ Format Support for Linear Layers: Implemented support for the NZ
(N-dimensional Z-order) format for linear layer weights on Ascend 310P,
enhancing performance for both quantized and unquantized layers.
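As a reference for what the cast looks like at the tensor level, here is a minimal sketch; it assumes torch_npu is available on an Ascend device, and ACL_FORMAT_FRACTAL_NZ = 29 is the ACL enum value conventionally used for the NZ fractal layout (named here for illustration):

    import torch
    import torch_npu

    # Assumed ACL enum value for the NZ fractal layout.
    ACL_FORMAT_FRACTAL_NZ = 29

    def cast_weight_to_nz(weight: torch.Tensor) -> torch.Tensor:
        # npu_format_cast repacks the tensor's on-device layout; the logical
        # shape and values are unchanged.
        return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)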
Unquantized Linear Method for Ascend 310P: Introduced AscendUnquantizedLinearMethod310, which applies NZ format casting to unquantized linear layer weights during weight loading.
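A hypothetical sketch of how such a method can hook the load step (not the actual vllm-ascend source; it assumes vLLM's UnquantizedLinearMethod base class and the cast_weight_to_nz helper sketched above):

    from vllm.model_executor.layers.linear import UnquantizedLinearMethod

    class AscendUnquantizedLinearMethod310(UnquantizedLinearMethod):
        def process_weights_after_loading(self, layer) -> None:
            super().process_weights_after_loading(layer)
            # Cast the fully loaded weight to NZ once, at load time.
            layer.weight.data = cast_weight_to_nz(layer.weight.data)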
MRotaryEmbedding Integration: Extended Rotary Embedding support by
adding AscendMRotaryEmbedding310 to provide an Ascend-specific
implementation for MRotaryEmbedding.
Quantization Method Updates: Updated the w8a8_static quantization method
to directly transpose weights and apply NZ format casting, ensuring
consistency with the new format.
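Based only on the description above, the post-load step for w8a8_static plausibly looks like the following sketch (method name and attribute layout are illustrative, not the actual implementation):

    def process_weights_after_loading(self, layer) -> None:
        # Transpose the int8 weight into the layout the quantized matmul
        # consumes, then cast the transposed weight to NZ.
        weight = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight.data = torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)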
- vLLM version: v0.15.0
- vLLM main: 9562912cea
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
@@ -21,8 +21,8 @@ from vllm.model_executor.layers.linear import LinearBase
 from tests.ut.base import TestBase
 from vllm_ascend._310p.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod310
+from vllm_ascend._310p.ops.linear import AscendUnquantizedLinearMethod310
 from vllm_ascend._310p.quantization.modelslim_config import AscendModelSlimConfig310
-from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 
 
 class TestAscendModelSlimConfig310(TestBase):
@@ -50,7 +50,7 @@ class TestAscendModelSlimConfig310(TestBase):
             patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=True),
         ):
             method = self.ascend_config.get_quant_method(linear_layer, ".attn")
-            self.assertIsInstance(method, AscendUnquantizedLinearMethod)
+            self.assertIsInstance(method, AscendUnquantizedLinearMethod310)
 
             # Test quantized layer
             mock_scheme = MagicMock()
@@ -44,6 +44,7 @@ class TestAscendW8A8LinearMethod310(TestBase):
         self.assertEqual(params["deq_scale"].dtype, torch.int64)
         self.assertEqual(params["weight_scale"].dtype, torch.float16)
+        self.assertEqual(params["weight_offset"].dtype, torch.float16)
 
         self.assertEqual(params["quant_bias"].shape, (10,))
         self.assertEqual(params["deq_scale"].shape, (10,))
         self.assertEqual(params["weight_scale"].shape, (10, 1))
@@ -71,12 +72,23 @@ class TestAscendW8A8LinearMethod310(TestBase):
         output = self.method.apply(layer, x, tp_rank=0)
 
         mock_quantize.assert_called_with(
-            x, layer.aclnn_input_scale, layer.aclnn_input_scale_reciprocal, layer.aclnn_input_offset
+            x,
+            layer.aclnn_input_scale,
+            layer.aclnn_input_scale_reciprocal,
+            layer.aclnn_input_offset,
         )
-        mock_npu_quant_matmul.assert_called_with(
-            expect_x_output, layer.weight, layer.deq_scale, bias=layer.quant_bias, output_dtype=layer.params_dtype
-        )
+        # The bias is added by the linear layer's forward pass, not the quant method.
+        mock_npu_quant_matmul.assert_called_once()
+        (args, kwargs) = mock_npu_quant_matmul.call_args
+
+        # positional args
+        self.assertTrue(torch.equal(args[0], expect_x_output))
+        self.assertTrue(torch.equal(args[1], layer.weight.data))
+        self.assertTrue(torch.equal(args[2], layer.deq_scale))
+
+        # kwargs
+        self.assertTrue(torch.equal(kwargs["bias"], layer.quant_bias))
+        self.assertEqual(kwargs["output_dtype"], layer.params_dtype)
 
         self.assertTrue(torch.equal(output, expected_y_output))
 
     @patch("torch.ops.vllm.quantize")
@@ -98,8 +110,41 @@ class TestAscendW8A8LinearMethod310(TestBase):
         output = self.method.apply(layer, x, tp_rank=0)
 
         mock_quantize.assert_not_called()
-        mock_npu_quant_matmul.assert_called_with(
-            x, layer.weight, layer.deq_scale, bias=layer.quant_bias, output_dtype=layer.params_dtype
-        )
+        # The bias is added by the linear layer's forward pass, not the quant method.
+        mock_npu_quant_matmul.assert_called_once()
+        (args, kwargs) = mock_npu_quant_matmul.call_args
+
+        self.assertTrue(torch.equal(args[0], x))
+        self.assertTrue(torch.equal(args[1], layer.weight.data))
+        self.assertTrue(torch.equal(args[2], layer.deq_scale))
+
+        self.assertTrue(torch.equal(kwargs["bias"], layer.quant_bias))
+        self.assertEqual(kwargs["output_dtype"], layer.params_dtype)
 
         self.assertTrue(torch.equal(output, expected_y_output))
+
+    @patch("torch_npu.npu_format_cast")
+    def test_process_weights_after_loading_calls_nz_format_cast_310p(self, mock_npu_format_cast):
+        mock_npu_format_cast.side_effect = lambda x, fmt: x
+
+        layer = MagicMock()
+
+        # Attributes used by process_weights_after_loading()
+        layer.weight = MagicMock()
+        layer.input_scale = MagicMock()
+        layer.input_offset = MagicMock()
+        layer.weight_scale = MagicMock()
+        layer.weight_offset = MagicMock()
+        layer.w2_weight_offset = MagicMock()
+
+        layer.weight.data = torch.randint(-127, 128, (128, 256), dtype=torch.int8)
+        layer.input_scale.data = torch.tensor([0.1], dtype=torch.float16)
+        layer.input_offset.data = torch.tensor([0], dtype=torch.int8)
+
+        layer.weight_scale.data = torch.randn(128, 1, dtype=torch.bfloat16)
+        layer.weight_offset.data = torch.randn(128, 1, dtype=torch.bfloat16)
+        # w2_weight_offset is reshaped to (N, -1); any (N, 1) is fine
+        layer.w2_weight_offset.data = torch.randn(128, 1, dtype=torch.bfloat16)
+
+        self.method.process_weights_after_loading(layer)
+
+        mock_npu_format_cast.assert_called_once()
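Two details of the rewritten assertions are worth noting. unittest.mock's assert_called_with compares arguments with ==, which for multi-element tensors produces an element-wise result whose truth value is ambiguous; unpacking call_args and comparing with torch.equal avoids that. And stubbing torch_npu.npu_format_cast with an identity side_effect lets the format-cast test run on hosts without an Ascend NPU while still verifying that the cast is invoked exactly once.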