[Quantization] 300I Duo support w8a8 quantization (#1560)
### What this PR does / why we need it?
This PR supports w8a8 quantization on the 300I Duo platform. The main change is to use `npu_quant_grouped_matmul_dequant` to replace `npu_grouped_matmul`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Offline inference on 310P runs normally.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
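For context on the headline change, here is a minimal sketch of the kind of dispatch the PR describes: routing the w8a8 grouped matmul through the fused quant/dequant kernel on 310P-family hardware (300I Duo). Only the two operator names come from the PR text; the helper name `grouped_matmul_w8a8`, the argument lists, and the call pattern are assumptions for illustration.

```python
# Sketch only (assumed signatures): on 300I Duo (310P SoC), replace the
# plain grouped matmul with the fused quantize-matmul-dequantize kernel.
import torch_npu

from vllm_ascend.utils import is_310p


def grouped_matmul_w8a8(x, w_int8, w_scale, group_list):
    if is_310p():
        # Fused path: quantizes activations to int8, runs the grouped
        # matmul, and dequantizes the output in a single kernel launch.
        # Argument order here is illustrative, not a verified signature.
        return torch_npu.npu_quant_grouped_matmul_dequant(
            x, w_int8, w_scale, group_list)
    # Default path on other Ascend SoCs; call pattern likewise illustrative.
    return torch_npu.npu_grouped_matmul([x], [w_int8],
                                        group_list=group_list)[0]
```

The portion of the diff captured below adds unit tests for `vllm_ascend.utils`: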
```diff
@@ -1,3 +1,5 @@
import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import]  # isort: skip  # noqa

import math
import os
import unittest
```
```diff
@@ -102,6 +104,79 @@ class TestUtils(unittest.TestCase):
        output_tensor = utils.aligned_16(input_tensor)
        self.assertEqual(output_tensor.shape[0], 32)

    @mock.patch('torch_npu.get_npu_format')
    @mock.patch('torch_npu.npu_format_cast')
    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                new=mock.MagicMock)
    @mock.patch('vllm_ascend.utils.is_310p')
    @mock.patch('vllm_ascend.utils.get_ascend_config')
    def test_maybe_converting_weight_acl_format(self, mock_get_config,
                                                mock_310p, mock_npu_cast,
                                                mock_get_format):
        ACL_FORMAT_FRACTAL_NZ = 29
        mock_310p.return_value = True

        mock_config = mock.MagicMock()
        mock_config.torchair_graph_config.enabled = True
        mock_get_config.return_value = mock_config
        mock_get_format.return_value = 1

        mock_npu_cast.return_value = 1

        fused_moe = mock.MagicMock()
        fused_moe.w13_weight = mock.MagicMock()
        fused_moe.w2_weight = mock.MagicMock()
        fused_moe.w13_weight.data = torch.randn(128, 256)
        fused_moe.w2_weight.data = torch.randn(256, 128)
        model = mock.MagicMock()
        model.modules.return_value = [fused_moe]

        utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
        self.assertEqual(fused_moe.w13_weight.data, 1)

    @mock.patch('torch_npu.get_npu_format')
    @mock.patch('torch_npu.npu_format_cast')
    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                new=mock.MagicMock)
    @mock.patch('vllm_ascend.utils.is_310p')
    @mock.patch('vllm_ascend.utils.get_ascend_config')
    def test_maybe_converting_weight_acl_format_format_true(
            self, mock_get_config, mock_310p, mock_npu_cast, mock_get_format):
        ACL_FORMAT_FRACTAL_NZ = 29
        mock_310p.return_value = True

        mock_config = mock.MagicMock()
        mock_config.torchair_graph_config.enabled = True
        mock_get_config.return_value = mock_config
        mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ

        mock_npu_cast.return_value = 1

        fused_moe = mock.MagicMock()
        fused_moe.w13_weight = mock.MagicMock()
        fused_moe.w2_weight = mock.MagicMock()
        fused_moe.w13_weight.data = torch.randn(128, 256)
        fused_moe.w2_weight.data = torch.randn(256, 128)
        model = mock.MagicMock()
        model.modules.return_value = [fused_moe]

        utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)

    @mock.patch('vllm_ascend.utils.get_ascend_config')
    @mock.patch('vllm_ascend.utils.is_310p', return_value=False)
    def test_maybe_converting_weight_acl_format_not_310_not_graph(
            self, mock_310p, mock_get_config):
        mock_config = mock.MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_config.return_value = mock_config

        mock_constant = mock.MagicMock()

        mock_model = mock.MagicMock()
        utils.maybe_converting_weight_acl_format(mock_model, mock_constant)
```
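Read together, the three tests above pin down the helper's contract: conversion happens only on 310P with the torchair graph enabled, and weights already carrying the target ACL format are left alone. Below is a minimal sketch of logic consistent with these mocks; it is assumed, not copied from `vllm_ascend.utils`, and since every test sets both gates to the same value, the `and` combination of the two conditions is itself an assumption.

```python
import torch_npu
from vllm.model_executor.layers.fused_moe.layer import FusedMoE

# is_310p and get_ascend_config are patched under the vllm_ascend.utils
# namespace in the tests; these import paths are assumptions made so the
# sketch is self-contained.
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import is_310p


def maybe_converting_weight_acl_format(model, acl_format: int) -> None:
    # Third test: nothing happens when not on 310P / graph mode disabled.
    if not (is_310p() and get_ascend_config().torchair_graph_config.enabled):
        return
    for module in model.modules():
        if not isinstance(module, FusedMoE):
            continue
        # Second test: weights already in the target format are skipped.
        if torch_npu.get_npu_format(module.w13_weight.data) == acl_format:
            continue
        # First test: .data is replaced by npu_format_cast's return value,
        # which is why the mock's return value (1) ends up in w13_weight.data.
        module.w13_weight.data = torch_npu.npu_format_cast(
            module.w13_weight.data, acl_format)
        module.w2_weight.data = torch_npu.npu_format_cast(
            module.w2_weight.data, acl_format)
```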
```diff
    @mock.patch('importlib.util.find_spec')
    @mock.patch('importlib.import_module')
    def test_try_register_lib(self, mock_import_module, mock_find_spec):
@@ -111,23 +186,17 @@ class TestUtils(unittest.TestCase):
        lib_name = "existing_lib"
        lib_info = "Library found and imported successfully"
        utils.try_register_lib(lib_name, lib_info)
        mock_find_spec.assert_called_once_with(lib_name)
        mock_import_module.assert_called_once_with(lib_name)

        # Can't find lib
        mock_find_spec.return_value = None
        lib_name = "non_existing_lib"
        utils.try_register_lib(lib_name)
        self.assertEqual(2, mock_find_spec.call_count)
        self.assertEqual(1, mock_import_module.call_count)

        # import error
        mock_find_spec.return_value = mock.MagicMock()
        mock_import_module.side_effect = ImportError("import error")
        lib_name = "error_lib"
        utils.try_register_lib(lib_name)
        self.assertEqual(3, mock_find_spec.call_count)
        self.assertEqual(2, mock_import_module.call_count)

    def test_enable_custom_op(self):
        result = utils.enable_custom_op()
```
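The hunk above tightly constrains `try_register_lib`: `find_spec` runs on every call, `import_module` runs only when a spec is found, and an `ImportError` is swallowed rather than propagated. A minimal sketch consistent with those assertions, with the logging details assumed:

```python
import importlib
import importlib.util
import logging

logger = logging.getLogger(__name__)


def try_register_lib(lib_name: str, lib_info: str = "") -> None:
    """Best-effort import of an optional library; never raises."""
    try:
        # Called once per invocation, matching mock_find_spec.call_count.
        if importlib.util.find_spec(lib_name) is None:
            return
        # Reached only when the spec resolves, matching
        # mock_import_module.call_count in the tests.
        importlib.import_module(lib_name)
        if lib_info:
            logger.info(lib_info)
    except ImportError:
        # Third test case: the import error is swallowed, not re-raised.
        logger.warning("Failed to import optional library %s", lib_name)
```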