[Quantization]300I Duo support w8a8 quantization (#1560)
### What this PR does / why we need it? This pr supports w8a8 on 300I Duo platform. The main change is to use `npu_quant_grouped_matmul_dequant` to replace `npu_grouped_matmul`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? offline inference on 310p runs normally. --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Signed-off-by: tianyitang <tangtianyi4@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com> Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
@@ -10,7 +10,8 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
|
||||
AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod,
|
||||
fused_experts, native_grouped_topk,
|
||||
fused_experts, fused_experts_310p,
|
||||
native_grouped_topk,
|
||||
quant_per_tensor, select_experts)
|
||||
|
||||
|
||||
@@ -111,6 +112,25 @@ class TestAscendW8A8LinearMethod(TestBase):
|
||||
expected_y_output += bias
|
||||
self.assertTrue(torch.equal(output, expected_y_output))
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
|
||||
@patch("torch_npu.npu_quant_matmul")
|
||||
def test_apply_with_x_is_310p(self, mock_npu_quant_matmul, mock_is_310p):
|
||||
layer = MagicMock()
|
||||
layer.aclnn_input_scale = 0.1
|
||||
layer.aclnn_input_offset = 0.2
|
||||
layer.weight = torch.randn(128, 256)
|
||||
layer.deq_scale = 0.3
|
||||
|
||||
x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
|
||||
bias = torch.randn(256)
|
||||
|
||||
expected_y_output = torch.randn(32, 256)
|
||||
mock_npu_quant_matmul.return_value = expected_y_output
|
||||
|
||||
output = self.method.apply(layer, x, bias)
|
||||
expected_y_output += bias
|
||||
self.assertTrue(torch.equal(output, expected_y_output))
|
||||
|
||||
@patch('torch_npu.npu_format_cast')
|
||||
def test_process_weights_after_loading(self, mock_npu_format_cast):
|
||||
layer = MagicMock()
|
||||
@@ -221,6 +241,36 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
mock_fused_experts.assert_called_once()
|
||||
self.assertEqual(result.shape, (32, self.hidden_size))
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
|
||||
@patch('vllm_ascend.quantization.w8a8.select_experts')
|
||||
@patch('vllm_ascend.quantization.w8a8.fused_experts_310p')
|
||||
def test_apply_is_310p(self, mock_fused_experts_310p, mock_select_experts,
|
||||
mock_is_310p):
|
||||
# Setup
|
||||
mock_layer = MagicMock()
|
||||
x = torch.randn(32, self.hidden_size)
|
||||
router_logits = torch.randn(32, 128) # 128 experts
|
||||
top_k = 2
|
||||
|
||||
# Mock return values
|
||||
mock_select_experts.return_value = (torch.randn(32, top_k),
|
||||
torch.randint(0, 128, (32, top_k)))
|
||||
mock_fused_experts_310p.return_value = torch.randn(
|
||||
32, self.hidden_size)
|
||||
|
||||
# Test
|
||||
result = self.moe_method.apply(layer=mock_layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=True,
|
||||
global_num_experts=128)
|
||||
|
||||
# Assertions
|
||||
mock_select_experts.assert_called_once()
|
||||
mock_fused_experts_310p.assert_called_once()
|
||||
self.assertEqual(result.shape, (32, self.hidden_size))
|
||||
|
||||
|
||||
class TestAscendC8KVCacheMethod(TestBase):
|
||||
|
||||
@@ -255,7 +305,22 @@ class TestAscendC8KVCacheMethod(TestBase):
|
||||
expected_shape = (self.layer.num_kv_heads * self.layer.head_size, )
|
||||
self.assertEqual(param.shape, expected_shape)
|
||||
|
||||
def test_process_weights_after_loading(self):
|
||||
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=False)
|
||||
def test_process_weights_after_loading_not_310p(self, mock_is_310p):
|
||||
key_data = torch.ones(4 * 64)
|
||||
value_data = torch.ones(4 * 64) * 2
|
||||
|
||||
self.layer.key_antiquant_scale.data = key_data
|
||||
self.layer.value_antiquant_scale.data = value_data
|
||||
|
||||
self.method.process_weights_after_loading(self.layer)
|
||||
|
||||
self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
|
||||
self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
|
||||
self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
|
||||
def test_process_weights_after_loading_is_310p(self, mock_is_310p):
|
||||
key_data = torch.ones(4 * 64)
|
||||
value_data = torch.ones(4 * 64) * 2
|
||||
|
||||
@@ -527,6 +592,67 @@ class TestFusedExperts(TestBase):
|
||||
)
|
||||
|
||||
|
||||
class TestFusedExperts310(TestBase):
|
||||
|
||||
@patch('torch_npu.npu_quant_grouped_matmul_dequant')
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
def test_fused_experts_310p_with_expert_map(self, mock_swiglu,
|
||||
mock_get_ep_group,
|
||||
mock_quant_per_tensor,
|
||||
mock_matmul_dequant):
|
||||
num_tokens = 32
|
||||
hidden_size = 128
|
||||
intermediate_size = 256
|
||||
num_experts = 4
|
||||
top_k = 1
|
||||
|
||||
hidden_states = torch.randn(num_tokens, hidden_size)
|
||||
|
||||
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
|
||||
w1_scale = torch.tensor([0.1])
|
||||
w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
|
||||
|
||||
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
|
||||
w2_scale = torch.tensor([0.1])
|
||||
w2_input_scale = torch.tensor([0.2])
|
||||
|
||||
topk_weights = torch.rand(num_tokens, top_k)
|
||||
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
|
||||
expert_map = torch.arange(num_experts)
|
||||
|
||||
mock_get_ep_group.return_value.world_size = 1
|
||||
|
||||
mock_quant_per_tensor.return_value = torch.randint(-128,
|
||||
127,
|
||||
hidden_states.shape,
|
||||
dtype=torch.int8)
|
||||
|
||||
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
|
||||
intermediate_size)
|
||||
|
||||
mock_matmul_dequant.return_value = hidden_states
|
||||
|
||||
output = fused_experts_310p(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w1_input_scale=w1_input_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
self.assertEqual(output.shape, (num_tokens, hidden_size))
|
||||
self.assertEqual(mock_matmul_dequant.call_count, 2)
|
||||
|
||||
|
||||
class TestSelectExperts(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
||||
|
||||
import math
|
||||
import os
|
||||
import unittest
|
||||
@@ -102,6 +104,79 @@ class TestUtils(unittest.TestCase):
|
||||
output_tensor = utils.aligned_16(input_tensor)
|
||||
self.assertEqual(output_tensor.shape[0], 32)
|
||||
|
||||
@mock.patch('torch_npu.get_npu_format')
|
||||
@mock.patch('torch_npu.npu_format_cast')
|
||||
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
|
||||
new=mock.MagicMock)
|
||||
@mock.patch('vllm_ascend.utils.is_310p')
|
||||
@mock.patch('vllm_ascend.utils.get_ascend_config')
|
||||
def test_maybe_converting_weight_acl_format(self, mock_get_config,
|
||||
mock_310p, mock_npu_cast,
|
||||
mock_get_format):
|
||||
ACL_FORMAT_FRACTAL_NZ = 29
|
||||
mock_310p.return_value = True
|
||||
|
||||
mock_config = mock.MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = True
|
||||
mock_get_config.return_value = mock_config
|
||||
mock_get_format.return_value = 1
|
||||
|
||||
mock_npu_cast.return_value = 1
|
||||
|
||||
fused_moe = mock.MagicMock()
|
||||
fused_moe.w13_weight = mock.MagicMock()
|
||||
fused_moe.w2_weight = mock.MagicMock()
|
||||
fused_moe.w13_weight.data = torch.randn(128, 256)
|
||||
fused_moe.w2_weight.data = torch.randn(256, 128)
|
||||
model = mock.MagicMock()
|
||||
model.modules.return_value = [fused_moe]
|
||||
|
||||
utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
|
||||
self.assertEqual(fused_moe.w13_weight.data, 1)
|
||||
|
||||
@mock.patch('torch_npu.get_npu_format')
|
||||
@mock.patch('torch_npu.npu_format_cast')
|
||||
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
|
||||
new=mock.MagicMock)
|
||||
@mock.patch('vllm_ascend.utils.is_310p')
|
||||
@mock.patch('vllm_ascend.utils.get_ascend_config')
|
||||
def test_maybe_converting_weight_acl_format_format_true(
|
||||
self, mock_get_config, mock_310p, mock_npu_cast, mock_get_format):
|
||||
ACL_FORMAT_FRACTAL_NZ = 29
|
||||
mock_310p.return_value = True
|
||||
|
||||
mock_config = mock.MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = True
|
||||
mock_get_config.return_value = mock_config
|
||||
mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
|
||||
|
||||
mock_npu_cast.return_value = 1
|
||||
|
||||
fused_moe = mock.MagicMock()
|
||||
fused_moe.w13_weight = mock.MagicMock()
|
||||
fused_moe.w2_weight = mock.MagicMock()
|
||||
fused_moe.w13_weight.data = torch.randn(128, 256)
|
||||
fused_moe.w2_weight.data = torch.randn(256, 128)
|
||||
model = mock.MagicMock()
|
||||
model.modules.return_value = [fused_moe]
|
||||
|
||||
mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
|
||||
|
||||
utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
@mock.patch('vllm_ascend.utils.get_ascend_config')
|
||||
@mock.patch('vllm_ascend.utils.is_310p', return_value=False)
|
||||
def test_maybe_converting_weight_acl_format_not_310_not_graph(
|
||||
self, mock_310p, mock_get_config):
|
||||
mock_config = mock.MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_config.return_value = mock_config
|
||||
|
||||
mock_constant = mock.MagicMock()
|
||||
|
||||
mock_model = mock.MagicMock()
|
||||
utils.maybe_converting_weight_acl_format(mock_model, mock_constant)
|
||||
|
||||
@mock.patch('importlib.util.find_spec')
|
||||
@mock.patch('importlib.import_module')
|
||||
def test_try_register_lib(self, mock_import_module, mock_find_spec):
|
||||
@@ -111,23 +186,17 @@ class TestUtils(unittest.TestCase):
|
||||
lib_name = "existing_lib"
|
||||
lib_info = "Library found and imported successfully"
|
||||
utils.try_register_lib(lib_name, lib_info)
|
||||
mock_find_spec.assert_called_once_with(lib_name)
|
||||
mock_import_module.assert_called_once_with(lib_name)
|
||||
|
||||
# Can't find lib
|
||||
mock_find_spec.return_value = None
|
||||
lib_name = "non_existing_lib"
|
||||
utils.try_register_lib(lib_name)
|
||||
self.assertEqual(2, mock_find_spec.call_count)
|
||||
self.assertEqual(1, mock_import_module.call_count)
|
||||
|
||||
# import error
|
||||
mock_find_spec.return_value = mock.MagicMock()
|
||||
mock_import_module.side_effect = ImportError("import error")
|
||||
lib_name = "error_lib"
|
||||
utils.try_register_lib(lib_name)
|
||||
self.assertEqual(3, mock_find_spec.call_count)
|
||||
self.assertEqual(2, mock_import_module.call_count)
|
||||
|
||||
def test_enable_custom_op(self):
|
||||
result = utils.enable_custom_op()
|
||||
|
||||
Reference in New Issue
Block a user