From 6b80c5acbad0224477b8de4e202f2853ce992eb1 Mon Sep 17 00:00:00 2001 From: Zhu Yi Lin <116337067+GDzhu01@users.noreply.github.com> Date: Wed, 2 Jul 2025 16:40:51 +0800 Subject: [PATCH] Fix W8A8 fused moe bug (#1529) ### What this PR does / why we need it? 1. drop some useless code for w8a8 fusedmoe 2. Add int8 kv cache check 3. Add more ut. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI passed with new added test. --------- Signed-off-by: zhuyilin <809721801@qq.com> Signed-off-by: tianyitang Co-authored-by: tianyitang --- tests/ut/attention/test_attention_v1.py | 499 ++++++++++++++ tests/ut/quantization/test_quant_config.py | 232 +++++++ tests/ut/quantization/test_quantizer.py | 124 ++++ tests/ut/quantization/test_w8a8.py | 730 +++++++++++++++++++++ vllm_ascend/attention/attention_v1.py | 6 +- vllm_ascend/quantization/quantizer.py | 3 +- vllm_ascend/quantization/w8a8.py | 80 +-- vllm_ascend/worker/model_runner_v1.py | 2 +- 8 files changed, 1623 insertions(+), 53 deletions(-) create mode 100644 tests/ut/attention/test_attention_v1.py create mode 100644 tests/ut/quantization/test_quant_config.py create mode 100644 tests/ut/quantization/test_quantizer.py create mode 100644 tests/ut/quantization/test_w8a8.py diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py new file mode 100644 index 0000000..3e0242b --- /dev/null +++ b/tests/ut/attention/test_attention_v1.py @@ -0,0 +1,499 @@ +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + +from unittest.mock import MagicMock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, + AscendAttentionBackendImpl, + AscendAttentionMetadataBuilder, + AscendAttentionState, + AscendMetadata, + CommonAttentionState) + + +class TestAscendAttentionBackend(TestBase): + + def test_get_name(self): + 
self.assertEqual(AscendAttentionBackend.get_name(), "ASCEND") + + def test_get_impl_cls(self): + self.assertEqual(AscendAttentionBackend.get_impl_cls(), + AscendAttentionBackendImpl) + + def test_get_metadata_cls(self): + self.assertEqual(AscendAttentionBackend.get_metadata_cls(), + AscendMetadata) + + def test_get_state_cls(self): + self.assertEqual(AscendAttentionBackend.get_state_cls(), + CommonAttentionState) + + def test_get_builder_cls(self): + self.assertEqual(AscendAttentionBackend.get_builder_cls(), + AscendAttentionMetadataBuilder) + + @patch('vllm_ascend.attention.attention_v1.is_310p') + def test_get_kv_cache_shape_310p(self, mock_is_310p): + mock_is_310p.return_value = True + result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40) + self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16)) + + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) + def test_get_kv_cache_shape_not_310p(self, mock_is_310p): + result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40) + self.assertEqual(result, (2, 10, 20, 30, 40)) + + def test_get_bsh_kv_cache_shape(self): + result = AscendAttentionBackend.get_bsh_kv_cache_shape(10, 20, 30, 40) + self.assertEqual(result, (2, 10, 20, 30 * 40)) + + def test_swap_blocks(self): + src_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))] + dst_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))] + src_to_dst = torch.tensor([[0, 1], [2, 3]]) + AscendAttentionBackend.swap_blocks(src_kv_cache, dst_kv_cache, + src_to_dst) + self.assertTrue(torch.all(dst_kv_cache[0][1] == src_kv_cache[0][0])) + self.assertTrue(torch.all(dst_kv_cache[1][3] == src_kv_cache[1][2])) + + def test_copy_blocks(self): + kv_caches = [torch.zeros((10, 20)), torch.zeros((10, 20))] + src_to_dists = torch.tensor([[0, 1], [2, 3]]) + AscendAttentionBackend.copy_blocks(kv_caches, src_to_dists) + self.assertTrue(torch.all(kv_caches[0][1] == kv_caches[0][0])) + self.assertTrue(torch.all(kv_caches[1][3] == 
kv_caches[1][2])) + + +class TestAscendAttentionMetadataBuilder(TestBase): + + def setUp(self): + self.mock_runner = MagicMock() + self.builder = AscendAttentionMetadataBuilder(self.mock_runner) + + def test_reorder_batch(self): + mock_input_batch = MagicMock() + mock_scheduler_output = MagicMock() + + result = self.builder.reorder_batch(mock_input_batch, + mock_scheduler_output) + + self.assertFalse(result) + + @patch('vllm_ascend.attention.attention_v1.AscendMetadata') + @patch('torch_npu.npu_format_cast') + @patch('vllm_ascend.utils.nd_to_nz_2d') + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True) + def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d, + mock_npu_format_cast, + mock_ascend_metadata): + num_reqs = 2 + num_actual_tokens = 10 + max_query_len = 5 + common_prefix_len = 1 + + self.mock_runner.input_batch.block_table = [MagicMock()] + self.mock_runner.input_batch.block_table[ + 0].get_device_tensor.return_value = torch.zeros((10, 10)) + self.mock_runner.max_num_blocks_per_req = 10 + self.mock_runner.query_lens = torch.tensor([3, 4]) + self.mock_runner.seq_lens_cpu = torch.tensor([5, 6]) + self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) + self.mock_runner.device = 'cpu:0' + self.mock_runner.attn_mask = torch.ones((10, 10)) + self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache + self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7]) + + mock_nz_tensor = MagicMock() + mock_nd_to_nz_2d.return_value = mock_nz_tensor + mock_npu_format_cast.return_value = mock_nz_tensor + + self.builder.build(num_reqs, num_actual_tokens, max_query_len, + common_prefix_len) + + @patch('vllm_ascend.attention.attention_v1.AscendMetadata') + @patch('torch_npu.npu_format_cast') + @patch('vllm_ascend.utils.nd_to_nz_spec') + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True) + @patch('vllm_ascend.attention.attention_v1.AscendAttentionState') + def test_build_chunked_prefill(self, 
mock_ascend_attention_state, + mock_is_310p, mock_nd_to_nz_spec, + mock_npu_format_cast, mock_ascend_metadata): + num_reqs = 3 + num_actual_tokens = 15 + max_query_len = 6 + + self.mock_runner.input_batch.block_table = [MagicMock()] + self.mock_runner.input_batch.block_table[ + 0].get_device_tensor.return_value = torch.zeros((10, 10)) + self.mock_runner.max_num_blocks_per_req = 10 + self.mock_runner.query_lens = torch.tensor([2, 3, 4]) + self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6]) + self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) + self.mock_runner.device = 'cpu:0' + self.mock_runner.attn_mask = torch.ones((15, 15)) + self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill + self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9]) + + mock_ascend_attention_state = MagicMock() + mock_ascend_attention_state.PrefillNoCache = 0 + + mock_nz_tensor = MagicMock() + mock_nd_to_nz_spec.return_value = mock_nz_tensor + mock_npu_format_cast.return_value = mock_nz_tensor + + self.builder.build(num_reqs, num_actual_tokens, max_query_len, 0) + + @patch('vllm_ascend.attention.attention_v1.AscendMetadata') + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) + def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata): + num_reqs = 3 + num_actual_tokens = 15 + max_query_len = 6 + + self.mock_runner.input_batch.block_table = [MagicMock()] + self.mock_runner.input_batch.block_table[ + 0].get_device_tensor.return_value = torch.zeros((10, 10)) + self.mock_runner.max_num_blocks_per_req = 10 + self.mock_runner.query_lens = torch.tensor([2, 3, 4]) + self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6]) + self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) + self.mock_runner.device = 'cpu:0' + self.mock_runner.attn_mask = torch.ones((15, 15)) + self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill + self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9]) + + self.builder.build(num_reqs, 
num_actual_tokens, max_query_len, 0) + + +class TestAscendAttentionBackendImpl(TestBase): + + def setUp(self): + self.layer = MagicMock() + self.layer.layer_name = "test_layer" + self.layer._k_scale_float = 1.0 + self.layer._v_scale_float = 1.0 + + self.attention_type = MagicMock() + self.attention_type.DECODER = "decoder" + self.attention_type.ENCODER = "encoder" + + self.attn_metadata = MagicMock() + self.attn_metadata.return_value = "1" + + self.layer_no_quant = MagicMock( + spec=['layer_name', '_k_scale_float', '_v_scale_float']) + self.layer_no_quant.layer_name = "test_layer" + self.layer_no_quant._k_scale_float = 1.0 + self.layer_no_quant._v_scale_float = 1.0 + + self.impl = AscendAttentionBackendImpl( + num_heads=8, + head_size=64, + scale=1.0, + num_kv_heads=8, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="float16", + attn_type=self.attention_type.DECODER) + + self.impl_192 = AscendAttentionBackendImpl( + num_heads=8, + head_size=192, + scale=1.0, + num_kv_heads=8, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="float16", + attn_type=self.attention_type.DECODER) + + self.impl_error = AscendAttentionBackendImpl(num_heads=8, + head_size=192, + scale=1.0, + num_kv_heads=8, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="float16", + attn_type=None) + + @patch('torch.ops.vllm.unified_ascend_attention_with_output') + def test_forward_trace_flag_true(self, mock_unified_attention): + """Test forward pass when trace_flag is True""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 0, 0, 8, 64) + metadata = self.attn_metadata + layer = self.layer + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=True) + + mock_unified_attention.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('torch_npu._npu_paged_attention_splitfuse') + def test_forward_with_quant_method(self, 
mock_paged_attention): + """Test forward pass when layer has quant_method""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8) + + metadata = MagicMock() + metadata.num_actual_tokens = torch.randn(10, 8 * 64) + metadata.block_tables = torch.randn(10, 8 * 64) + metadata.seq_lens = torch.randn(10, 8 * 64) + metadata.attn_mask = torch.randn(10, 8 * 64) + metadata.query_lens = torch.randn(10, 8 * 64) + layer = self.layer + layer.quant_method = MagicMock() + layer.quant_method.apply.return_value = kv_cache + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + layer.quant_method.apply.assert_called_once() + assert output.shape == (10, 8 * 64) + + def test_forward_no_attn_metadata(self): + """Test forward pass when attn_metadata is None""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 0, 0, 8, 64) + layer = self.layer_no_quant + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + None, + trace_flag=False) + + assert output.shape == (10, 8 * 64) + + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_flash_attention') + def test_forward_prefill_no_cache(self, mock_flash_attention, + mock_reshape_cache): + """Test forward pass in PrefillNoCache state""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_state = AscendAttentionState.PrefillNoCache + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.seq_lens = torch.tensor([10]) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + # layer.quant_method.apply.return_value = metadata + print(self.layer_no_quant._v_scale_float) + 
output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_reshape_cache.assert_called_once() + mock_flash_attention.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_flash_attention_qlens') + def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens, + mock_npu_reshape_and_cache): + """Test forward pass in PrefillCacheHit state""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_state = AscendAttentionState.PrefillCacheHit + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.query_lens = torch.tensor([10]) + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_flash_attention_qlens.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_paged_attention') + def test_forward_decode_only(self, mock_paged_attention, + mock_npu_reshape_and_cache): + """Test forward pass in DecodeOnly state""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_state = AscendAttentionState.DecodeOnly + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + 
metadata, + trace_flag=False) + + mock_paged_attention.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) + @patch('torch_npu._npu_reshape_and_cache') + @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill') + def test_forward_head_size_192(self, mock_vanilla_prefill, + mock_npu_reshape_and_cache, mock_is_310p): + """Test forward pass when head_size is 192""" + + self.impl.head_size = 192 + query = torch.randn(10, 8 * 192) + key = torch.randn(10, 8 * 192) + value = torch.randn(10, 8 * 192) + kv_cache = torch.empty(2, 5, 128, 8, 192) + metadata = self.attn_metadata + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.query_lens = torch.tensor([10]) + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + mock_vanilla_prefill.return_value = MagicMock() + + def mock_tensor(data, device=None, **kwargs): + if device == "npu": + return metadata.attn_mask + return torch.tensor(data, **kwargs) + + with patch("torch.tensor", side_effect=mock_tensor): + output = self.impl_192.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_vanilla_prefill.assert_called_once() + assert output.shape == (10, 8 * 192) + + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_paged_attention_splitfuse') + def test_forward_normal_v1_situation(self, mock_paged_attention, + mock_npu_reshape_and_cache): + """Test forward pass in normal V1 situation""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.query_lens = torch.tensor([10]) + metadata.seq_lens = torch.tensor([10]) + 
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_paged_attention.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('torch_npu.npu_format_cast') + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_paged_attention_splitfuse') + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True) + def test_forward_310p_device(self, mock_is_310p, mock_paged_attention, + mock_npu_reshape_and_cache, + mock_npu_format_cast): + """Test forward pass on 310P device""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.query_lens = torch.tensor([10]) + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + + mock_npu_format_cast.return_value = metadata.attn_mask + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_paged_attention.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('torch_npu._npu_reshape_and_cache') + def test_forward_raise_error(self, mock_paged_attention): + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.query_lens = torch.tensor([10]) + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + 
metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + + with self.assertRaises(NotImplementedError): + self.impl_error.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_quant_config.py new file mode 100644 index 0000000..3f20c01 --- /dev/null +++ b/tests/ut/quantization/test_quant_config.py @@ -0,0 +1,232 @@ +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + +from unittest.mock import MagicMock, patch + +import torch +from vllm.attention.layer import Attention +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) + +from tests.ut.base import TestBase +from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod, + AscendQuantConfig) + +ASCEND_QUATIZATION_METHOD = "ascend" + + +class TestAscendQuantConfig(TestBase): + + def setUp(self): + self.sample_config = { + "weight": "INT8", + "fa_quant_type": "C8", + "kv_quant_type": "C8", + "layer1.weight": "INT8", + "layer2.weight": "FLOAT", + "fused_layer.weight": "FLOAT", + "fused_layer.shard1.weight": "FLOAT", + "fused_layer.shard2.weight": "FLOAT", + "shard1.weight": "FLOAT", + "shard2.weight": "FLOAT", + } + self.ascend_config = AscendQuantConfig(self.sample_config) + self.ascend_config.packed_modules_mapping = None + + def test_init(self): + self.assertEqual(self.ascend_config.quant_description, + self.sample_config) + + def test_repr(self): + repr_str = repr(self.ascend_config) + self.assertTrue(repr_str.startswith("AscendQuantConfig:\n")) + + def test_get_name(self): + self.assertEqual(AscendQuantConfig.get_name(), + ASCEND_QUATIZATION_METHOD) + + def test_get_supported_act_dtypes(self): + supported_dtypes = AscendQuantConfig.get_supported_act_dtypes() + 
self.assertEqual(len(supported_dtypes), 3) + + def test_get_min_capability(self): + with self.assertRaises(NotImplementedError): + AscendQuantConfig.get_min_capability() + + def test_get_config_filenames(self): + filenames = AscendQuantConfig.get_config_filenames() + self.assertEqual(filenames, ["quant_model_description.json"]) + + def test_from_config(self): + config = AscendQuantConfig.from_config(self.sample_config) + self.assertIsInstance(config, AscendQuantConfig) + self.assertEqual(config.quant_description, self.sample_config) + + @patch('torch.npu.is_available') + def test_override_quantization_method(self, mock_is_available): + # Test when NPU is available + mock_is_available.return_value = True + result = AscendQuantConfig.override_quantization_method(None, None) + self.assertEqual(result, ASCEND_QUATIZATION_METHOD) + + # Test when NPU is not available + mock_is_available.return_value = False + result = AscendQuantConfig.override_quantization_method(None, None) + self.assertIsNone(result) + + def test_get_quant_method_for_linear(self): + linear_layer = MagicMock(spec=LinearBase) + # Test skipped layer + with patch.object(self.ascend_config, + 'is_layer_skipped_ascend', + return_value=True): + method = self.ascend_config.get_quant_method(linear_layer, ".attn") + self.assertIsInstance(method, UnquantizedLinearMethod) + + # Test quantized layer + with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ + patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear: + + method = self.ascend_config.get_quant_method(linear_layer, ".attn") + self.assertIs(method, mock_ascend_linear.return_value) + mock_ascend_linear.assert_called_once_with( + self.ascend_config, ".attn", + self.ascend_config.packed_modules_mapping) + + def test_get_quant_method_for_attention(self): + attention_layer = MagicMock(spec=Attention) + with 
patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', + return_value=MagicMock()) as mock_ascend_kvcache: + # Test with fa_quant_type + method = self.ascend_config.get_quant_method( + attention_layer, ".attn") + self.assertIs(method, mock_ascend_kvcache.return_value) + + with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', + return_value=MagicMock()) as mock_ascend_kvcache: + # Test with kv_quant_type + modified_config = {"kv_quant_type": "C8"} + config = AscendQuantConfig(modified_config) + config.packed_modules_mapping = None + method = config.get_quant_method(attention_layer, "attn") + self.assertIs(method, mock_ascend_kvcache.return_value) + + def test_get_quant_method_for_fused_moe(self): + fused_moe_layer = MagicMock(spec=FusedMoE) + + # Test skipped layer + with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \ + patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: + method = self.ascend_config.get_quant_method( + fused_moe_layer, "moe_layer") + self.assertIs(method, mock_ascend_moe.return_value) + + # Test quantized layer + with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ + patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: + method = self.ascend_config.get_quant_method( + fused_moe_layer, "moe_layer") + self.assertIs(method, mock_ascend_moe.return_value) + + def test_is_layer_skipped_ascend(self): + # Test non-fused layer that should be quantized + self.assertFalse(self.ascend_config.is_layer_skipped_ascend("layer1")) + + # Test non-fused layer that should be skipped + self.assertTrue(self.ascend_config.is_layer_skipped_ascend("layer2")) + + # Test fused layer + fused_mapping = {"fused_layer": ["shard1", "shard2"]} + self.assertTrue( + self.ascend_config.is_layer_skipped_ascend("fused_layer", + fused_mapping)) + + # Test 
inconsistent fused layer shards + bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"} + config = AscendQuantConfig(bad_config) + with self.assertRaises(ValueError): + config.is_layer_skipped_ascend("fused_layer", fused_mapping) + + def test_get_scaled_act_names(self): + self.assertEqual(self.ascend_config.get_scaled_act_names(), []) + + +class TestAscendKVCacheMethod(TestBase): + + def setUp(self): + # Setup common test fixtures + self.mock_quant_config = MagicMock(spec=AscendQuantConfig) + self.mock_quant_config.quant_description = {"some_config": "value"} + self.prefix = "attention_layer" + + # Mock the quantizer and quant_method + self.mock_quantizer = MagicMock() + self.mock_quant_method = MagicMock() + + # Patch the AscendQuantizer + self.quantizer_patcher = patch( + 'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer', + return_value=self.mock_quantizer) + self.mock_get_quantizer = self.quantizer_patcher.start() + + self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method + + # Create instance + self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config, + self.prefix) + + def tearDown(self): + self.quantizer_patcher.stop() + + def test_init(self): + """Test initialization with proper quantizer setup.""" + self.mock_get_quantizer.assert_called_once_with( + self.mock_quant_config.quant_description, self.prefix) + self.mock_quantizer.build_attention_method.assert_called_once() + + def test_create_weights(self): + """Test create_weights delegates to quant_method.""" + mock_layer = MagicMock() + self.kv_cache_method.create_weights(mock_layer) + self.mock_quant_method.create_weights.assert_called_once_with( + mock_layer) + + def test_process_weights_after_loading_with_method(self): + """Test process_weights when quant_method has the method.""" + mock_layer = MagicMock() + self.kv_cache_method.process_weights_after_loading(mock_layer) + 
self.mock_quant_method.process_weights_after_loading.assert_called_once_with( + mock_layer) + + def test_process_weights_after_loading_without_method(self): + """Test process_weights when quant_method lacks the method.""" + # Reset mock to remove the method + del self.mock_quant_method.process_weights_after_loading + mock_layer = MagicMock() + + # Should not raise exception + self.kv_cache_method.process_weights_after_loading(mock_layer) + + def test_apply_delegation(self): + """Test apply properly delegates to quant_method.""" + mock_layer = MagicMock() + mock_query = torch.randn(1, 32, 128) + mock_key = torch.randn(1, 32, 128) + mock_value = torch.randn(1, 32, 128) + mock_kv_cache = MagicMock() + mock_attn_metadata = MagicMock() + mock_scale = 1.0 + mock_output = torch.zeros(1, 32, 128) + mock_attn_type = MagicMock() + expected_result = torch.randn(1, 32, 128) + self.mock_quant_method.apply.return_value = expected_result + + result = self.kv_cache_method.apply(mock_layer, mock_query, mock_key, + mock_value, mock_kv_cache, + mock_attn_metadata, mock_attn_type, + mock_scale, mock_output) + + self.mock_quant_method.apply.assert_called_once_with( + mock_layer, mock_query, mock_key, mock_value, mock_kv_cache, + mock_attn_metadata, mock_attn_type, mock_scale, mock_output) + self.assertTrue(torch.equal(result, expected_result)) diff --git a/tests/ut/quantization/test_quantizer.py b/tests/ut/quantization/test_quantizer.py new file mode 100644 index 0000000..a827364 --- /dev/null +++ b/tests/ut/quantization/test_quantizer.py @@ -0,0 +1,124 @@ +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + +from unittest.mock import MagicMock, patch + +from tests.ut.base import TestBase +from vllm_ascend.quantization.quant_config import AscendQuantConfig +from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer, + W8A8Quantizer) + +SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"} + + +class TestGetQuantizer(TestBase): + + def 
setUp(self): + # Setup common test fixtures + self.supported_types = { + 'INT8': MagicMock(_instance=None), + 'FP16': MagicMock(_instance=None), + 'C8': MagicMock(_instance=None) + } + self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy() + SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types) + self.mock_quant_config = MagicMock(spec=AscendQuantConfig) + self.mock_quant_config.quant_description = {"some_config": "value"} + + def tearDown(self): + # Restore original supported types + SUPPORT_ASCEND_QUANTIZER_TYPE.clear() + SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types) + + def test_get_quantizer_fa(self): + """Test successful quantizer retrieval for different cases.""" + # Setup + quant_description = {'fa_quant_type': 'C8'} + prefix = '.attn' + expected_type = 'C8' + with patch.dict( + 'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', + SUPPORT_ASCEND_QUANTIZER_TYPE): + + result = VLLMAscendQuantizer.get_quantizer( + quant_description, + prefix, + packed_modules_mapping={"some": "mapping"}) + + # Verify + self.assertIsNotNone(result) + self.assertEqual(result, + self.supported_types[expected_type]._instance) + self.supported_types[expected_type].assert_called_once_with( + quant_description) + + def test_get_quantizer_kv(self): + """Test successful quantizer retrieval for different cases.""" + # Setup + quant_description = {'kv_quant_type': 'C8'} + prefix = '.attn' + expected_type = 'C8' + with patch.dict( + 'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', + SUPPORT_ASCEND_QUANTIZER_TYPE): + + result = VLLMAscendQuantizer.get_quantizer( + quant_description, + prefix, + packed_modules_mapping={"some": "mapping"}) + + # Verify + self.assertIsNotNone(result) + self.assertEqual(result, + self.supported_types[expected_type]._instance) + self.supported_types[expected_type].assert_called_once_with( + quant_description) + + def test_get_quantizer_linear(self): + """Test successful quantizer 
retrieval for different cases.""" + # Setup + quant_description = {'linear_type': 'INT8'} + prefix = 'nothing' + expected_type = 'INT8' + with patch('vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type', + return_value=expected_type), \ + patch.dict('vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', SUPPORT_ASCEND_QUANTIZER_TYPE): + + result = VLLMAscendQuantizer.get_quantizer( + quant_description, + prefix, + packed_modules_mapping={"some": "mapping"}) + + # Verify + self.assertIsNotNone(result) + self.assertEqual(result, + self.supported_types[expected_type]._instance) + self.supported_types[expected_type].assert_called_once_with( + quant_description) + + +class TestW8A8Quantizer(TestBase): + + def setUp(self): + self.quantizer = W8A8Quantizer(quant_description={}) + + def test_build_linear_method(self): + with patch('vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod', + return_value=MagicMock()) as mock_linear: + result = self.quantizer.build_linear_method() + mock_linear.assert_called_once_with() + self.assertIsInstance(result, MagicMock) + + def test_build_moe_method(self): + with patch( + 'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod', + return_value=MagicMock()) as mock_linear: + result = self.quantizer.build_moe_method() + mock_linear.assert_called_once_with() + self.assertIsInstance(result, MagicMock) + + def test_build_attention_method(self): + with patch('vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod', + return_value=MagicMock()) as mock_linear: + result = self.quantizer.build_attention_method() + mock_linear.assert_called_once_with() + self.assertIsInstance(result, MagicMock) diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py new file mode 100644 index 0000000..21e0626 --- /dev/null +++ b/tests/ut/quantization/test_w8a8.py @@ -0,0 +1,730 @@ +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + 
+import unittest +from unittest.mock import MagicMock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod, + AscendW8A8FusedMoEMethod, + AscendW8A8LinearMethod, + fused_experts, native_grouped_topk, + quant_per_tensor, select_experts) + + +class TestQuantPerTensor(TestBase): + + @patch("torch_npu.npu_quantize") + def test_quant_per_tensor(self, mock_npu_quantize): + in_tensor = torch.randn(32, 128) + input_scale = torch.tensor(0.1) + input_offset = torch.tensor(0) + + expected_output = torch.randint(-128, 127, (32, 128), dtype=torch.int8) + mock_npu_quantize.return_value = expected_output + + output = quant_per_tensor(in_tensor, input_scale, input_offset) + + mock_npu_quantize.assert_called_once_with( + in_tensor, + input_scale, + input_offset, + torch.qint8, + -1, + False, + ) + + self.assertTrue(torch.equal(output, expected_output)) + + +class TestAscendW8A8LinearMethod(TestBase): + + def setUp(self): + self.method = AscendW8A8LinearMethod() + + def test_get_weight(self): + weight = self.method.get_weight(10, 20) + self.assertEqual(weight['weight'].dtype, torch.int8) + self.assertEqual(weight['weight'].shape, (20, 10)) + + def test_get_pertensor_param(self): + params = self.method.get_pertensor_param(torch.bfloat16) + self.assertEqual(params['input_scale'].dtype, torch.bfloat16) + self.assertEqual(params['input_offset'].dtype, torch.int8) + self.assertEqual(params['input_scale'].shape, (1, )) + self.assertEqual(params['input_offset'].shape, (1, )) + + def test_get_perchannel_param(self): + params = self.method.get_perchannel_param(10, torch.bfloat16) + + self.assertEqual(params['quant_bias'].dtype, torch.int32) + self.assertEqual(params['deq_scale'].dtype, torch.float32) + self.assertEqual(params['weight_scale'].dtype, torch.bfloat16) + self.assertEqual(params['weight_offset'].dtype, torch.bfloat16) + 
self.assertEqual(params['quant_bias'].shape, (10, )) + self.assertEqual(params['deq_scale'].shape, (10, )) + self.assertEqual(params['weight_scale'].shape, (10, 1)) + self.assertEqual(params['weight_offset'].shape, (10, 1)) + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + @patch("torch_npu.npu_quant_matmul") + def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, + mock_quant_per_tensor): + layer = MagicMock() + layer.aclnn_input_scale = 0.1 + layer.aclnn_input_offset = 0.2 + layer.weight = torch.randn(128, 256) + layer.deq_scale = 0.3 + + x = torch.randn(32, 128) + bias = torch.randn(256) + mock_quant_per_tensor.return_value = torch.randint(-128, + 127, + x.shape, + dtype=torch.int8) + + expected_y_output = torch.randn(32, 256) + mock_npu_quant_matmul.return_value = expected_y_output + + output = self.method.apply(layer, x, bias) + + expected_y_output += bias + self.assertTrue(torch.equal(output, expected_y_output)) + + @patch("torch_npu.npu_quant_matmul") + def test_apply_with_x_is_int8(self, mock_npu_quant_matmul): + layer = MagicMock() + layer.aclnn_input_scale = 0.1 + layer.aclnn_input_offset = 0.2 + layer.weight = torch.randn(128, 256) + layer.deq_scale = 0.3 + + x = torch.randint(-128, 127, (32, 128), dtype=torch.int8) + bias = torch.randn(256) + + expected_y_output = torch.randn(32, 256) + mock_npu_quant_matmul.return_value = expected_y_output + + output = self.method.apply(layer, x, bias) + expected_y_output += bias + self.assertTrue(torch.equal(output, expected_y_output)) + + @patch('torch_npu.npu_format_cast') + def test_process_weights_after_loading(self, mock_npu_format_cast): + layer = MagicMock() + + layer.weight.data = torch.randn(128, 256) + layer.input_scale.data = torch.tensor([0.1]) + layer.input_offset.data = torch.tensor([0]) + layer.deq_scale = torch.tensor([0.5]) + layer.weight_scale.data = torch.randn(128, 1) + layer.weight_offset.data = torch.randn(128, 1) + + mock_npu_format_cast.return_value = MagicMock() + 
self.method.process_weights_after_loading(layer) + + expected_offset = torch.tensor([0]).repeat(256).to(torch.int8) + self.assertTrue( + torch.equal(layer.aclnn_input_offset.data, expected_offset)) + self.assertFalse(layer.aclnn_input_offset.requires_grad) + + self.assertFalse(layer.deq_scale.requires_grad) + + self.assertEqual(layer.weight_scale.data.shape, (128, )) + self.assertEqual(layer.weight_offset.data.shape, (128, )) + + +class TestAscendW8A8FusedMoEMethod(TestBase): + + def setUp(self): + self.moe_method = AscendW8A8FusedMoEMethod() + self.num_experts = 4 + self.intermediate_size = 64 + self.hidden_size = 128 + self.dtype = torch.float32 + + def test_init(self): + self.assertTrue(self.moe_method.transpose_weight) + + def test_get_weight(self): + weights = self.moe_method.get_weight( + num_experts=self.num_experts, + intermediate_size_per_partition=self.intermediate_size, + hidden_sizes=self.hidden_size, + params_dtype=self.dtype) + + assert "w13_weight" in weights, f"w13_weight not in {weights}" + assert "w2_weight" in weights, f"w2_weight not in {weights}" + self.assertEqual( + weights["w13_weight"].shape, + (self.num_experts, 2 * self.intermediate_size, self.hidden_size)) + self.assertEqual( + weights["w2_weight"].shape, + (self.num_experts, self.hidden_size, self.intermediate_size)) + self.assertEqual(weights["w13_weight"].dtype, torch.int8) + self.assertEqual(weights["w2_weight"].dtype, torch.int8) + self.assertFalse(weights["w13_weight"].requires_grad) + self.assertFalse(weights["w2_weight"].requires_grad) + + def test_get_dynamic_quant_param(self): + quant_params = self.moe_method.get_dynamic_quant_param( + num_experts=self.num_experts, + intermediate_size_per_partition=self.intermediate_size, + hidden_sizes=self.hidden_size, + params_dtype=self.dtype) + + expected_params = [ + "w13_weight_scale", "w13_weight_offset", "w2_weight_scale", + "w2_weight_offset", "w2_deq_scale", "w13_deq_scale", + "w2_input_scale", "w13_input_scale", "w2_input_offset", + 
"w13_input_offset", "quant_bias" + ] + + for param in expected_params: + assert param in quant_params, f"{param} not in {quant_params}" + + # Check some sample shapes + self.assertEqual(quant_params["w13_weight_scale"].shape, + (self.num_experts, 2 * self.intermediate_size, 1)) + self.assertEqual(quant_params["w2_input_offset"].shape, + (self.num_experts, 1)) + self.assertEqual(quant_params["quant_bias"].shape, + (self.num_experts, self.hidden_size)) + + @patch('vllm_ascend.quantization.w8a8.select_experts') + @patch('vllm_ascend.quantization.w8a8.fused_experts') + def test_apply_with_other_expert_count(self, mock_fused_experts, + mock_select_experts): + # Setup + mock_layer = MagicMock() + x = torch.randn(32, self.hidden_size) + router_logits = torch.randn(32, 128) # 128 experts + top_k = 2 + + # Mock return values + mock_select_experts.return_value = (torch.randn(32, top_k), + torch.randint(0, 128, (32, top_k))) + mock_fused_experts.return_value = torch.randn(32, self.hidden_size) + + # Test + result = self.moe_method.apply(layer=mock_layer, + x=x, + router_logits=router_logits, + top_k=top_k, + renormalize=True, + global_num_experts=128) + + # Assertions + mock_select_experts.assert_called_once() + mock_fused_experts.assert_called_once() + self.assertEqual(result.shape, (32, self.hidden_size)) + + +class TestAscendC8KVCacheMethod(TestBase): + + def setUp(self): + self.layer = MagicMock() + self.layer.num_kv_heads = 4 + self.layer.head_size = 64 + self.layer.num_heads = 8 + self.layer._k_scale_float = 1.0 + self.layer._v_scale_float = 1.0 + self.method = AscendC8KVCacheMethod() + + self.attention_type = MagicMock() + self.attention_type.DECODER = "decoder" + self.attention_type.ENCODER = "encoder" + + def test_create_weights(self): + """Test whether create_weights registers the parameters correctly""" + AscendC8KVCacheMethod.create_weights(self.layer) + + self.layer.register_parameter.assert_any_call("key_antiquant_scale", + unittest.mock.ANY) + 
self.layer.register_parameter.assert_any_call("value_antiquant_scale", + unittest.mock.ANY) + + calls = self.layer.register_parameter.call_args_list + + for call in calls: + args, kwargs = call + param = kwargs.get('parameter', args[1] if len(args) > 1 else None) + + expected_shape = (self.layer.num_kv_heads * self.layer.head_size, ) + self.assertEqual(param.shape, expected_shape) + + def test_process_weights_after_loading(self): + key_data = torch.ones(4 * 64) + value_data = torch.ones(4 * 64) * 2 + + self.layer.key_antiquant_scale.data = key_data + self.layer.value_antiquant_scale.data = value_data + + self.method.process_weights_after_loading(self.layer) + + self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256)) + self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1)) + self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2)) + + @patch('torch_npu.npu_scatter_nd_update_') + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + def test_apply_decode_only(self, mock_quant, mock_scatter): + + num_tokens = 2 + query = torch.randn(num_tokens, + self.layer.num_heads * self.layer.head_size) + key = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + value = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + output = torch.empty_like(query) + + attn_metadata = MagicMock() + attn_metadata.attn_state = AscendAttentionState.DecodeOnly + attn_metadata.seq_lens = [10, 10] + attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]]) + attn_metadata.slot_mapping = torch.tensor([0, 1]) + attn_metadata.attn_mask = None + + block_size = 16 + key_cache = torch.empty(2, block_size, self.layer.num_kv_heads, + self.layer.head_size) + value_cache = torch.empty(2, block_size, self.layer.num_kv_heads, + self.layer.head_size) + kv_cache = (key_cache, value_cache) + + mock_quant.side_effect = [key, value] + + self.layer.key_antiquant_scale.data = torch.ones( + self.layer.num_kv_heads * 
self.layer.head_size) + self.layer.value_antiquant_scale.data = torch.ones( + self.layer.num_kv_heads * self.layer.head_size) + self.method.process_weights_after_loading(self.layer) + + expected_output = torch.randn( + num_tokens, self.layer.num_heads * self.layer.head_size) + with patch('torch_npu.npu_incre_flash_attention', + return_value=expected_output): + result = self.method.apply(self.layer, query, key, value, kv_cache, + attn_metadata, + self.attention_type.DECODER, 1.0, + output) + + self.assertEqual(mock_quant.call_count, 2) + self.assertEqual(mock_scatter.call_count, 2) + self.assertTrue(torch.equal(result, expected_output)) + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + @patch('torch_npu._npu_flash_attention') + def test_apply_prefill_no_cache(self, mock_flash, mock_quant): + """Test apply method in prefill no-cache mode""" + + num_tokens = 2 + query = torch.randn(num_tokens, + self.layer.num_heads * self.layer.head_size) + key = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + value = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + output = torch.empty_like(query) + + attn_metadata = MagicMock() + attn_metadata.attn_state = AscendAttentionState.PrefillNoCache + attn_metadata.seq_lens = [10, 10] + attn_metadata.attn_mask = torch.ones(2, 2) + + kv_cache = (torch.tensor([]), torch.tensor([])) + mock_quant.return_value = key + + result = self.method.apply(self.layer, query, key, value, kv_cache, + attn_metadata, self.attention_type.DECODER, + 1.0, output) + + # Check that flash attention was called + mock_flash.assert_called_once() + + # Check output shape + self.assertEqual( + result.shape, + (num_tokens, self.layer.num_heads * self.layer.head_size)) + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + def test_apply_unsupported_attention_type(self, mock_quant): + + query = torch.randn(1, self.layer.num_heads * self.layer.head_size) + key = torch.randn(1, 
self.layer.num_kv_heads * self.layer.head_size) + value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size) + output = torch.empty_like(query) + + mock_quant.return_value = key + + attn_metadata = MagicMock() + attn_metadata.attn_state = AscendAttentionState.PrefillNoCache + + with self.assertRaises(NotImplementedError) as cm: + self.method.apply(self.layer, query, key, value, (None, None), + attn_metadata, self.attention_type.ENCODER, 1.0, + output) + + assert "Encoder self-attention" in str( + cm.exception), f"Encoder self-attention not in {str(cm.exception)}" + assert "not implemented" in str( + cm.exception), f"not implemented not in {str(cm.exception)}" + + mock_quant.assert_not_called() + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + def test_apply_unsupported_attention_state(self, mock_quant): + """Test apply with unsupported attention state""" + query = torch.randn(1, self.layer.num_heads * self.layer.head_size) + key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size) + value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size) + output = torch.empty_like(query) + + attn_metadata = MagicMock() + attn_metadata.attn_state = AscendAttentionState.PrefillCacheHit + mock_quant.return_value = key + kv_cache = (torch.tensor([]), torch.tensor([])) + + with self.assertRaises(NotImplementedError): + self.method.apply(self.layer, query, key, value, kv_cache, + attn_metadata, self.attention_type.DECODER, 1.0, + output) + + +class TestFusedExperts(TestBase): + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + @patch('vllm_ascend.quantization.w8a8.get_ep_group') + @patch('torch_npu.npu_moe_init_routing_v2') + @patch('torch_npu.npu_grouped_matmul') + @patch('torch_npu.npu_swiglu') + @patch('torch_npu.npu_moe_finalize_routing') + def test_fused_experts_with_expert_map(self, mock_finalize, mock_swiglu, + mock_group_matmul, + mock_init_routing, + mock_get_ep_group, + mock_quant_per_tensor): + num_tokens = 32 + 
hidden_size = 128 + intermediate_size = 256 + num_experts = 4 + top_k = 2 + + hidden_states = torch.randn(num_tokens, hidden_size) + + w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size) + w1_scale = torch.tensor([0.1]) + w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]]) + w1_input_offset = torch.tensor([0]) + + w2 = torch.randn(num_experts, hidden_size, intermediate_size) + w2_scale = torch.tensor([0.1]) + w2_input_scale = torch.tensor([0.2]) + w2_input_offset = torch.tensor([0]) + + topk_weights = torch.rand(num_tokens, top_k) + topk_ids = torch.randint(0, num_experts, (num_tokens, top_k)) + expert_map = torch.arange(num_experts) + + mock_get_ep_group.return_value.world_size = 8 + + mock_quant_per_tensor.return_value = torch.randint(-128, + 127, + hidden_states.shape, + dtype=torch.int8) + + mock_init_routing.return_value = (torch.randn(num_tokens * top_k, + hidden_size), + torch.arange(num_tokens * top_k), + torch.tensor([num_tokens // 2] * 2), + torch.tensor(1.0)) + + mock_group_matmul.side_effect = [[ + torch.randn(num_tokens * top_k, intermediate_size * 2) + ], [torch.randn(num_tokens * top_k, hidden_size)]] + + mock_swiglu.return_value = torch.randn(num_tokens * top_k, + intermediate_size) + + expected_output = torch.randn(num_tokens, hidden_size) + mock_finalize.return_value = expected_output + + output = fused_experts( + hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w1_input_scale=w1_input_scale, + w1_input_offset=w1_input_offset, + w2=w2, + w2_scale=w2_scale, + w2_input_scale=w2_input_scale, + w2_input_offset=w2_input_offset, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=num_experts, + expert_map=expert_map, + ) + + mock_init_routing.assert_called_once() + + self.assertEqual(mock_group_matmul.call_count, 2) + + self.assertEqual(output.shape, (num_tokens, hidden_size)) + + mock_finalize.assert_called_once() + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + 
@patch('vllm_ascend.quantization.w8a8.get_ep_group') + @patch('torch_npu.npu_grouped_matmul') + @patch('torch_npu.npu_swiglu') + def test_fused_experts_without_expert_map(self, mock_swiglu, + mock_group_matmul, + mock_get_ep_group, + mock_quant_per_tensor): + num_tokens = 16 + hidden_size = 64 + intermediate_size = 128 + num_experts = 8 + top_k = 1 + + hidden_states = torch.randn(num_tokens, hidden_size) + w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size) + w2 = torch.randn(num_experts, hidden_size, intermediate_size) + topk_weights = torch.rand(num_tokens, top_k) + topk_ids = torch.randint(0, num_experts, (num_tokens, top_k)) + + mock_get_ep_group.return_value.world_size = 8 + + mock_quant_per_tensor.return_value = torch.randint(-128, + 127, + hidden_states.shape, + dtype=torch.int8) + mock_group_matmul.side_effect = [[ + torch.randn(num_tokens * top_k, intermediate_size * 2) + ], [torch.randn(num_tokens * top_k, hidden_size)]] + mock_swiglu.return_value = torch.randn(num_tokens * top_k, + intermediate_size) + with self.assertRaises(NotImplementedError): + fused_experts( + hidden_states=hidden_states, + w1=w1, + w1_scale=torch.tensor([0.1]), + w1_input_scale=torch.tensor([[0.2, 0.2], [0.2, 0.2]]), + w1_input_offset=torch.tensor([0]), + w2=w2, + w2_scale=torch.tensor([0.1]), + w2_input_scale=torch.tensor([0.1]), + w2_input_offset=torch.tensor([0]), + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=num_experts, + expert_map=None, + ) + + +class TestSelectExperts(TestBase): + + def setUp(self): + # Common test data + self.num_tokens = 10 + self.hidden_size = 32 + self.num_experts = 8 + self.top_k = 2 + + self.hidden_states = torch.randn(self.num_tokens, self.hidden_size) + self.router_logits = torch.randn(self.num_tokens, self.num_experts) + + def test_softmax_scoring(self): + """Test softmax scoring function""" + + weights, ids = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + 
top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + scoring_func="softmax") + + self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) + + def test_sigmoid_scoring(self): + """Test sigmoid scoring function""" + + weights, ids = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + scoring_func="sigmoid") + + self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) + + def test_invalid_scoring_func(self): + """Test invalid scoring function raises ValueError""" + with self.assertRaises(ValueError): + select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + scoring_func="invalid_func") + + @patch('torch.topk') + def test_grouped_topk(self, mock_topk): + """Test grouped topk functionality""" + mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), + torch.zeros(self.num_tokens, + self.top_k, + dtype=torch.long)) + + weights, ids = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=False, + topk_group=4, + num_expert_group=2) + + mock_topk.assert_called() + self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.dtype, torch.int32) + + @patch('vllm_ascend.quantization.w8a8.native_grouped_topk') + def test_grouped_topk_with_correction_bias(self, mock_grouped_topk): + """Test grouped topk with expert score correction bias""" + mock_grouped_topk.return_value = torch.ones(self.num_tokens, + self.num_experts) + + e_score_correction_bias = torch.randn(self.num_experts) + weights, ids = select_experts( + hidden_states=self.hidden_states, + 
router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=False, + topk_group=4, + num_expert_group=2, + e_score_correction_bias=e_score_correction_bias) + + mock_grouped_topk.assert_called_once() + self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) + + def test_custom_routing_function(self): + """Test custom routing function""" + mock_custom_routing = MagicMock() + mock_custom_routing.return_value = (torch.ones(self.num_tokens, + self.top_k), + torch.zeros(self.num_tokens, + self.top_k, + dtype=torch.int32)) + + weights, ids = select_experts( + hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + custom_routing_function=mock_custom_routing) + + mock_custom_routing.assert_called_once() + self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.dtype, torch.int32) + + @patch('torch.topk') + def test_renormalize(self, mock_topk): + """Test weight renormalization""" + mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), + torch.zeros(self.num_tokens, + self.top_k, + dtype=torch.long)) + + weights, _ = select_experts( + hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=True, + ) + + # Check if weights are normalized (sum to 1 for each token) + sums = weights.sum(dim=-1) + self.assertTrue(torch.allclose(sums, torch.ones_like(sums))) + + @patch('torch.topk') + def test_output_dtypes(self, mock_topk): + """Test output dtypes""" + mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), + torch.zeros(self.num_tokens, + self.top_k, + dtype=torch.long)) + + weights, ids = select_experts( + hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + 
use_grouped_topk=False, + renormalize=False, + ) + + self.assertEqual(weights.dtype, self.hidden_states.dtype) + self.assertEqual(ids.dtype, torch.int32) + + +class TestNativeGroupedTopkPartialMock(TestBase): + + def test_basic_group_selection(self): + topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6], + [0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1], + [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3], + [0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4]], + dtype=torch.float32) + + expected_topk_indices = torch.tensor([[0, 1], [1, 0], [0, 1], [0, 1]]) + + with patch('torch.topk', + return_value=(None, expected_topk_indices)) as mock_topk: + result = native_grouped_topk(topk_weights=topk_weights, + num_expert_group=2, + topk_group=2) + + mock_topk.assert_called_once() + + expected_result = topk_weights + self.assertTrue(torch.allclose(result, expected_result)) + + def test_partial_group_selection(self): + + topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6], + [0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1]]) + + expected_topk_indices = torch.tensor([[0], [1]]) + + with patch('torch.topk', return_value=(None, expected_topk_indices)): + result = native_grouped_topk(topk_weights=topk_weights, + num_expert_group=2, + topk_group=1) + + expected_result = torch.tensor( + [[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.8, 0.2, 0.9, 0.1]]) + self.assertTrue(torch.allclose(result, expected_result)) + + def test_single_group(self): + topk_weights = torch.tensor([[0.1, 0.9, 0.2], [0.8, 0.3, 0.7]]) + + expected_topk_indices = torch.tensor([[0], [0]]) + + with patch('torch.topk', return_value=(None, expected_topk_indices)): + result = native_grouped_topk(topk_weights=topk_weights, + num_expert_group=1, + topk_group=1) + self.assertTrue(result.numel() > 0) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 0255c53..7d7f488 100644 --- a/vllm_ascend/attention/attention_v1.py +++ 
b/vllm_ascend/attention/attention_v1.py @@ -274,6 +274,8 @@ class AscendAttentionBackendImpl(AttentionImpl): shape = [batch_size * seq_len, num_heads, head_size] """ num_tokens = query.shape[0] + use_kv_cache_int8 = kv_cache.numel( + ) > 0 and kv_cache[0].dtype == torch.int8 if output is None: output = torch.empty(num_tokens, self.num_heads, @@ -289,7 +291,7 @@ class AscendAttentionBackendImpl(AttentionImpl): output=output, layer_name=layer.layer_name) - elif hasattr(layer, 'quant_method'): + elif hasattr(layer, 'quant_method') and use_kv_cache_int8: output = layer.quant_method.apply(layer, query, key, value, kv_cache, attn_metadata, self.attn_type, self.scale, @@ -429,7 +431,7 @@ class AscendAttentionBackendImpl(AttentionImpl): out=output) # to make in-place change to the output tensor - if hasattr(layer, 'quant_method'): + if hasattr(layer, 'quant_method') and use_kv_cache_int8: output = output.view(num_tokens, self.num_heads, self.head_size) ori_output[:, :, :] = output[:num_tokens, :, :] return output.view(num_tokens, self.hidden_size) diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index 81326fb..02f4486 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -251,7 +251,8 @@ class VLLMAscendQuantizer: # Attention if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): quant_type = quant_description['fa_quant_type'] - if '.attn' in prefix and 'kv_quant_type' in quant_description.keys(): + # Use KVCache int8 + elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys(): quant_type = quant_description['kv_quant_type'] # Linear else: diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 6a2f403..0746150 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -15,7 +15,6 @@ # limitations under the License. 
# -import os from typing import Any, Callable, Dict, Optional import torch @@ -219,53 +218,34 @@ class AscendW8A8FusedMoEMethod: assert router_logits.shape[ 1] == global_num_experts, "Number of global experts mismatch" - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if global_num_experts == 256: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, - bias=e_score_correction_bias, - k_group=topk_group, - group_count=num_expert_group, - group_select_mode=1, - renorm=0, - norm_type=1, - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts, - ) + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) - if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill: - raise NotImplementedError("W8A8FusedMoe are not " - "mplemented for VLLM_ENABLE_MC2") - - else: - return fused_experts(hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w1_input_scale=layer.w13_input_scale, - w1_input_offset=layer.w13_input_offset, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - w2_input_scale=layer.w2_input_scale, - w2_input_offset=layer.w2_input_offset, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - 
global_num_experts=global_num_experts, - expert_map=expert_map) + return fused_experts(hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w1_input_scale=layer.w13_input_scale, + w1_input_offset=layer.w13_input_offset, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + w2_input_scale=layer.w2_input_scale, + w2_input_offset=layer.w2_input_offset, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=global_num_experts, + expert_map=expert_map) def process_weights_after_loading(self, layer): # torch.npu.config.allow_internal_format = True @@ -299,8 +279,10 @@ class AscendW8A8FusedMoEMethod: torch.int8) # NZ - # layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, 29).contiguous() - # layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, 29).contiguous() + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, 29).contiguous() + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, 29).contiguous() class AscendC8KVCacheMethod: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 9cbb2c3..50d610e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2194,7 +2194,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): block_size=self.block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=attn_module.dtype, + dtype=self.kv_cache_dtype, use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY):