Fix W8A8 fused moe bug (#1529)
### What this PR does / why we need it? 1. drop some useless code for w8a8 fusedmoe 2. Add in8 kv cache check 3. Add more ut. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI passed with new added test. --------- Signed-off-by: zhuyilin <809721801@qq.com> Signed-off-by: tianyitang <tangtianyi4@huawei.com> Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
499
tests/ut/attention/test_attention_v1.py
Normal file
499
tests/ut/attention/test_attention_v1.py
Normal file
@@ -0,0 +1,499 @@
|
||||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
||||
AscendAttentionBackendImpl,
|
||||
AscendAttentionMetadataBuilder,
|
||||
AscendAttentionState,
|
||||
AscendMetadata,
|
||||
CommonAttentionState)
|
||||
|
||||
|
||||
class TestAscendAttentionBackend(TestBase):
|
||||
|
||||
def test_get_name(self):
|
||||
self.assertEqual(AscendAttentionBackend.get_name(), "ASCEND")
|
||||
|
||||
def test_get_impl_cls(self):
|
||||
self.assertEqual(AscendAttentionBackend.get_impl_cls(),
|
||||
AscendAttentionBackendImpl)
|
||||
|
||||
def test_get_metadata_cls(self):
|
||||
self.assertEqual(AscendAttentionBackend.get_metadata_cls(),
|
||||
AscendMetadata)
|
||||
|
||||
def test_get_state_cls(self):
|
||||
self.assertEqual(AscendAttentionBackend.get_state_cls(),
|
||||
CommonAttentionState)
|
||||
|
||||
def test_get_builder_cls(self):
|
||||
self.assertEqual(AscendAttentionBackend.get_builder_cls(),
|
||||
AscendAttentionMetadataBuilder)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p')
|
||||
def test_get_kv_cache_shape_310p(self, mock_is_310p):
|
||||
mock_is_310p.return_value = True
|
||||
result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
|
||||
self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
||||
def test_get_kv_cache_shape_not_310p(self, mock_is_310p):
|
||||
result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
|
||||
self.assertEqual(result, (2, 10, 20, 30, 40))
|
||||
|
||||
def test_get_bsh_kv_cache_shape(self):
|
||||
result = AscendAttentionBackend.get_bsh_kv_cache_shape(10, 20, 30, 40)
|
||||
self.assertEqual(result, (2, 10, 20, 30 * 40))
|
||||
|
||||
def test_swap_blocks(self):
|
||||
src_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
|
||||
dst_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
|
||||
src_to_dst = torch.tensor([[0, 1], [2, 3]])
|
||||
AscendAttentionBackend.swap_blocks(src_kv_cache, dst_kv_cache,
|
||||
src_to_dst)
|
||||
self.assertTrue(torch.all(dst_kv_cache[0][1] == src_kv_cache[0][0]))
|
||||
self.assertTrue(torch.all(dst_kv_cache[1][3] == src_kv_cache[1][2]))
|
||||
|
||||
def test_copy_blocks(self):
|
||||
kv_caches = [torch.zeros((10, 20)), torch.zeros((10, 20))]
|
||||
src_to_dists = torch.tensor([[0, 1], [2, 3]])
|
||||
AscendAttentionBackend.copy_blocks(kv_caches, src_to_dists)
|
||||
self.assertTrue(torch.all(kv_caches[0][1] == kv_caches[0][0]))
|
||||
self.assertTrue(torch.all(kv_caches[1][3] == kv_caches[1][2]))
|
||||
|
||||
|
||||
class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_runner = MagicMock()
|
||||
self.builder = AscendAttentionMetadataBuilder(self.mock_runner)
|
||||
|
||||
def test_reorder_batch(self):
|
||||
mock_input_batch = MagicMock()
|
||||
mock_scheduler_output = MagicMock()
|
||||
|
||||
result = self.builder.reorder_batch(mock_input_batch,
|
||||
mock_scheduler_output)
|
||||
|
||||
self.assertFalse(result)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
||||
@patch('torch_npu.npu_format_cast')
|
||||
@patch('vllm_ascend.utils.nd_to_nz_2d')
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
|
||||
def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
|
||||
mock_npu_format_cast,
|
||||
mock_ascend_metadata):
|
||||
num_reqs = 2
|
||||
num_actual_tokens = 10
|
||||
max_query_len = 5
|
||||
common_prefix_len = 1
|
||||
|
||||
self.mock_runner.input_batch.block_table = [MagicMock()]
|
||||
self.mock_runner.input_batch.block_table[
|
||||
0].get_device_tensor.return_value = torch.zeros((10, 10))
|
||||
self.mock_runner.max_num_blocks_per_req = 10
|
||||
self.mock_runner.query_lens = torch.tensor([3, 4])
|
||||
self.mock_runner.seq_lens_cpu = torch.tensor([5, 6])
|
||||
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
|
||||
self.mock_runner.device = 'cpu:0'
|
||||
self.mock_runner.attn_mask = torch.ones((10, 10))
|
||||
self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache
|
||||
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7])
|
||||
|
||||
mock_nz_tensor = MagicMock()
|
||||
mock_nd_to_nz_2d.return_value = mock_nz_tensor
|
||||
mock_npu_format_cast.return_value = mock_nz_tensor
|
||||
|
||||
self.builder.build(num_reqs, num_actual_tokens, max_query_len,
|
||||
common_prefix_len)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
||||
@patch('torch_npu.npu_format_cast')
|
||||
@patch('vllm_ascend.utils.nd_to_nz_spec')
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
|
||||
@patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
|
||||
def test_build_chunked_prefill(self, mock_ascend_attention_state,
|
||||
mock_is_310p, mock_nd_to_nz_spec,
|
||||
mock_npu_format_cast, mock_ascend_metadata):
|
||||
num_reqs = 3
|
||||
num_actual_tokens = 15
|
||||
max_query_len = 6
|
||||
|
||||
self.mock_runner.input_batch.block_table = [MagicMock()]
|
||||
self.mock_runner.input_batch.block_table[
|
||||
0].get_device_tensor.return_value = torch.zeros((10, 10))
|
||||
self.mock_runner.max_num_blocks_per_req = 10
|
||||
self.mock_runner.query_lens = torch.tensor([2, 3, 4])
|
||||
self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
|
||||
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
|
||||
self.mock_runner.device = 'cpu:0'
|
||||
self.mock_runner.attn_mask = torch.ones((15, 15))
|
||||
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
|
||||
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
|
||||
|
||||
mock_ascend_attention_state = MagicMock()
|
||||
mock_ascend_attention_state.PrefillNoCache = 0
|
||||
|
||||
mock_nz_tensor = MagicMock()
|
||||
mock_nd_to_nz_spec.return_value = mock_nz_tensor
|
||||
mock_npu_format_cast.return_value = mock_nz_tensor
|
||||
|
||||
self.builder.build(num_reqs, num_actual_tokens, max_query_len, 0)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
||||
def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
|
||||
num_reqs = 3
|
||||
num_actual_tokens = 15
|
||||
max_query_len = 6
|
||||
|
||||
self.mock_runner.input_batch.block_table = [MagicMock()]
|
||||
self.mock_runner.input_batch.block_table[
|
||||
0].get_device_tensor.return_value = torch.zeros((10, 10))
|
||||
self.mock_runner.max_num_blocks_per_req = 10
|
||||
self.mock_runner.query_lens = torch.tensor([2, 3, 4])
|
||||
self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
|
||||
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
|
||||
self.mock_runner.device = 'cpu:0'
|
||||
self.mock_runner.attn_mask = torch.ones((15, 15))
|
||||
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
|
||||
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
|
||||
|
||||
self.builder.build(num_reqs, num_actual_tokens, max_query_len, 0)
|
||||
|
||||
|
||||
class TestAscendAttentionBackendImpl(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.layer = MagicMock()
|
||||
self.layer.layer_name = "test_layer"
|
||||
self.layer._k_scale_float = 1.0
|
||||
self.layer._v_scale_float = 1.0
|
||||
|
||||
self.attention_type = MagicMock()
|
||||
self.attention_type.DECODER = "decoder"
|
||||
self.attention_type.ENCODER = "encoder"
|
||||
|
||||
self.attn_metadata = MagicMock()
|
||||
self.attn_metadata.return_value = "1"
|
||||
|
||||
self.layer_no_quant = MagicMock(
|
||||
spec=['layer_name', '_k_scale_float', '_v_scale_float'])
|
||||
self.layer_no_quant.layer_name = "test_layer"
|
||||
self.layer_no_quant._k_scale_float = 1.0
|
||||
self.layer_no_quant._v_scale_float = 1.0
|
||||
|
||||
self.impl = AscendAttentionBackendImpl(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
num_kv_heads=8,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype="float16",
|
||||
attn_type=self.attention_type.DECODER)
|
||||
|
||||
self.impl_192 = AscendAttentionBackendImpl(
|
||||
num_heads=8,
|
||||
head_size=192,
|
||||
scale=1.0,
|
||||
num_kv_heads=8,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype="float16",
|
||||
attn_type=self.attention_type.DECODER)
|
||||
|
||||
self.impl_error = AscendAttentionBackendImpl(num_heads=8,
|
||||
head_size=192,
|
||||
scale=1.0,
|
||||
num_kv_heads=8,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype="float16",
|
||||
attn_type=None)
|
||||
|
||||
@patch('torch.ops.vllm.unified_ascend_attention_with_output')
|
||||
def test_forward_trace_flag_true(self, mock_unified_attention):
|
||||
"""Test forward pass when trace_flag is True"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 0, 0, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
layer = self.layer
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=True)
|
||||
|
||||
mock_unified_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_paged_attention_splitfuse')
|
||||
def test_forward_with_quant_method(self, mock_paged_attention):
|
||||
"""Test forward pass when layer has quant_method"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.num_actual_tokens = torch.randn(10, 8 * 64)
|
||||
metadata.block_tables = torch.randn(10, 8 * 64)
|
||||
metadata.seq_lens = torch.randn(10, 8 * 64)
|
||||
metadata.attn_mask = torch.randn(10, 8 * 64)
|
||||
metadata.query_lens = torch.randn(10, 8 * 64)
|
||||
layer = self.layer
|
||||
layer.quant_method = MagicMock()
|
||||
layer.quant_method.apply.return_value = kv_cache
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
layer.quant_method.apply.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
def test_forward_no_attn_metadata(self):
|
||||
"""Test forward pass when attn_metadata is None"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 0, 0, 8, 64)
|
||||
layer = self.layer_no_quant
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
None,
|
||||
trace_flag=False)
|
||||
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_flash_attention')
|
||||
def test_forward_prefill_no_cache(self, mock_flash_attention,
|
||||
mock_reshape_cache):
|
||||
"""Test forward pass in PrefillNoCache state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.PrefillNoCache
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
# layer.quant_method.apply.return_value = metadata
|
||||
print(self.layer_no_quant._v_scale_float)
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_reshape_cache.assert_called_once()
|
||||
mock_flash_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_flash_attention_qlens')
|
||||
def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
|
||||
mock_npu_reshape_and_cache):
|
||||
"""Test forward pass in PrefillCacheHit state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_flash_attention_qlens.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
def test_forward_decode_only(self, mock_paged_attention,
|
||||
mock_npu_reshape_and_cache):
|
||||
"""Test forward pass in DecodeOnly state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
|
||||
def test_forward_head_size_192(self, mock_vanilla_prefill,
|
||||
mock_npu_reshape_and_cache, mock_is_310p):
|
||||
"""Test forward pass when head_size is 192"""
|
||||
|
||||
self.impl.head_size = 192
|
||||
query = torch.randn(10, 8 * 192)
|
||||
key = torch.randn(10, 8 * 192)
|
||||
value = torch.randn(10, 8 * 192)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 192)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
mock_vanilla_prefill.return_value = MagicMock()
|
||||
|
||||
def mock_tensor(data, device=None, **kwargs):
|
||||
if device == "npu":
|
||||
return metadata.attn_mask
|
||||
return torch.tensor(data, **kwargs)
|
||||
|
||||
with patch("torch.tensor", side_effect=mock_tensor):
|
||||
output = self.impl_192.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_vanilla_prefill.assert_called_once()
|
||||
assert output.shape == (10, 8 * 192)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention_splitfuse')
|
||||
def test_forward_normal_v1_situation(self, mock_paged_attention,
|
||||
mock_npu_reshape_and_cache):
|
||||
"""Test forward pass in normal V1 situation"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu.npu_format_cast')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention_splitfuse')
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
|
||||
def test_forward_310p_device(self, mock_is_310p, mock_paged_attention,
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_npu_format_cast):
|
||||
"""Test forward pass on 310P device"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
|
||||
mock_npu_format_cast.return_value = metadata.attn_mask
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
def test_forward_raise_error(self, mock_paged_attention):
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.impl_error.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
232
tests/ut/quantization/test_quant_config.py
Normal file
232
tests/ut/quantization/test_quant_config.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
|
||||
AscendQuantConfig)
|
||||
|
||||
ASCEND_QUATIZATION_METHOD = "ascend"
|
||||
|
||||
|
||||
class TestAscendQuantConfig(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.sample_config = {
|
||||
"weight": "INT8",
|
||||
"fa_quant_type": "C8",
|
||||
"kv_quant_type": "C8",
|
||||
"layer1.weight": "INT8",
|
||||
"layer2.weight": "FLOAT",
|
||||
"fused_layer.weight": "FLOAT",
|
||||
"fused_layer.shard1.weight": "FLOAT",
|
||||
"fused_layer.shard2.weight": "FLOAT",
|
||||
"shard1.weight": "FLOAT",
|
||||
"shard2.weight": "FLOAT",
|
||||
}
|
||||
self.ascend_config = AscendQuantConfig(self.sample_config)
|
||||
self.ascend_config.packed_modules_mapping = None
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.ascend_config.quant_description,
|
||||
self.sample_config)
|
||||
|
||||
def test_repr(self):
|
||||
repr_str = repr(self.ascend_config)
|
||||
self.assertTrue(repr_str.startswith("AscendQuantConfig:\n"))
|
||||
|
||||
def test_get_name(self):
|
||||
self.assertEqual(AscendQuantConfig.get_name(),
|
||||
ASCEND_QUATIZATION_METHOD)
|
||||
|
||||
def test_get_supported_act_dtypes(self):
|
||||
supported_dtypes = AscendQuantConfig.get_supported_act_dtypes()
|
||||
self.assertEqual(len(supported_dtypes), 3)
|
||||
|
||||
def test_get_min_capability(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
AscendQuantConfig.get_min_capability()
|
||||
|
||||
def test_get_config_filenames(self):
|
||||
filenames = AscendQuantConfig.get_config_filenames()
|
||||
self.assertEqual(filenames, ["quant_model_description.json"])
|
||||
|
||||
def test_from_config(self):
|
||||
config = AscendQuantConfig.from_config(self.sample_config)
|
||||
self.assertIsInstance(config, AscendQuantConfig)
|
||||
self.assertEqual(config.quant_description, self.sample_config)
|
||||
|
||||
@patch('torch.npu.is_available')
|
||||
def test_override_quantization_method(self, mock_is_available):
|
||||
# Test when NPU is available
|
||||
mock_is_available.return_value = True
|
||||
result = AscendQuantConfig.override_quantization_method(None, None)
|
||||
self.assertEqual(result, ASCEND_QUATIZATION_METHOD)
|
||||
|
||||
# Test when NPU is not available
|
||||
mock_is_available.return_value = False
|
||||
result = AscendQuantConfig.override_quantization_method(None, None)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_get_quant_method_for_linear(self):
|
||||
linear_layer = MagicMock(spec=LinearBase)
|
||||
# Test skipped layer
|
||||
with patch.object(self.ascend_config,
|
||||
'is_layer_skipped_ascend',
|
||||
return_value=True):
|
||||
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
|
||||
self.assertIsInstance(method, UnquantizedLinearMethod)
|
||||
|
||||
# Test quantized layer
|
||||
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
|
||||
|
||||
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
|
||||
self.assertIs(method, mock_ascend_linear.return_value)
|
||||
mock_ascend_linear.assert_called_once_with(
|
||||
self.ascend_config, ".attn",
|
||||
self.ascend_config.packed_modules_mapping)
|
||||
|
||||
def test_get_quant_method_for_attention(self):
|
||||
attention_layer = MagicMock(spec=Attention)
|
||||
with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
|
||||
return_value=MagicMock()) as mock_ascend_kvcache:
|
||||
# Test with fa_quant_type
|
||||
method = self.ascend_config.get_quant_method(
|
||||
attention_layer, ".attn")
|
||||
self.assertIs(method, mock_ascend_kvcache.return_value)
|
||||
|
||||
with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
|
||||
return_value=MagicMock()) as mock_ascend_kvcache:
|
||||
# Test with kv_quant_type
|
||||
modified_config = {"kv_quant_type": "C8"}
|
||||
config = AscendQuantConfig(modified_config)
|
||||
config.packed_modules_mapping = None
|
||||
method = config.get_quant_method(attention_layer, "attn")
|
||||
self.assertIs(method, mock_ascend_kvcache.return_value)
|
||||
|
||||
def test_get_quant_method_for_fused_moe(self):
|
||||
fused_moe_layer = MagicMock(spec=FusedMoE)
|
||||
|
||||
# Test skipped layer
|
||||
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
|
||||
method = self.ascend_config.get_quant_method(
|
||||
fused_moe_layer, "moe_layer")
|
||||
self.assertIs(method, mock_ascend_moe.return_value)
|
||||
|
||||
# Test quantized layer
|
||||
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
|
||||
method = self.ascend_config.get_quant_method(
|
||||
fused_moe_layer, "moe_layer")
|
||||
self.assertIs(method, mock_ascend_moe.return_value)
|
||||
|
||||
def test_is_layer_skipped_ascend(self):
|
||||
# Test non-fused layer that should be quantized
|
||||
self.assertFalse(self.ascend_config.is_layer_skipped_ascend("layer1"))
|
||||
|
||||
# Test non-fused layer that should be skipped
|
||||
self.assertTrue(self.ascend_config.is_layer_skipped_ascend("layer2"))
|
||||
|
||||
# Test fused layer
|
||||
fused_mapping = {"fused_layer": ["shard1", "shard2"]}
|
||||
self.assertTrue(
|
||||
self.ascend_config.is_layer_skipped_ascend("fused_layer",
|
||||
fused_mapping))
|
||||
|
||||
# Test inconsistent fused layer shards
|
||||
bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"}
|
||||
config = AscendQuantConfig(bad_config)
|
||||
with self.assertRaises(ValueError):
|
||||
config.is_layer_skipped_ascend("fused_layer", fused_mapping)
|
||||
|
||||
def test_get_scaled_act_names(self):
|
||||
self.assertEqual(self.ascend_config.get_scaled_act_names(), [])
|
||||
|
||||
|
||||
class TestAscendKVCacheMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Setup common test fixtures
|
||||
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
|
||||
self.mock_quant_config.quant_description = {"some_config": "value"}
|
||||
self.prefix = "attention_layer"
|
||||
|
||||
# Mock the quantizer and quant_method
|
||||
self.mock_quantizer = MagicMock()
|
||||
self.mock_quant_method = MagicMock()
|
||||
|
||||
# Patch the AscendQuantizer
|
||||
self.quantizer_patcher = patch(
|
||||
'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer',
|
||||
return_value=self.mock_quantizer)
|
||||
self.mock_get_quantizer = self.quantizer_patcher.start()
|
||||
|
||||
self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method
|
||||
|
||||
# Create instance
|
||||
self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
|
||||
self.prefix)
|
||||
|
||||
def tearDown(self):
|
||||
self.quantizer_patcher.stop()
|
||||
|
||||
def test_init(self):
|
||||
"""Test initialization with proper quantizer setup."""
|
||||
self.mock_get_quantizer.assert_called_once_with(
|
||||
self.mock_quant_config.quant_description, self.prefix)
|
||||
self.mock_quantizer.build_attention_method.assert_called_once()
|
||||
|
||||
def test_create_weights(self):
|
||||
"""Test create_weights delegates to quant_method."""
|
||||
mock_layer = MagicMock()
|
||||
self.kv_cache_method.create_weights(mock_layer)
|
||||
self.mock_quant_method.create_weights.assert_called_once_with(
|
||||
mock_layer)
|
||||
|
||||
def test_process_weights_after_loading_with_method(self):
|
||||
"""Test process_weights when quant_method has the method."""
|
||||
mock_layer = MagicMock()
|
||||
self.kv_cache_method.process_weights_after_loading(mock_layer)
|
||||
self.mock_quant_method.process_weights_after_loading.assert_called_once_with(
|
||||
mock_layer)
|
||||
|
||||
def test_process_weights_after_loading_without_method(self):
|
||||
"""Test process_weights when quant_method lacks the method."""
|
||||
# Reset mock to remove the method
|
||||
del self.mock_quant_method.process_weights_after_loading
|
||||
mock_layer = MagicMock()
|
||||
|
||||
# Should not raise exception
|
||||
self.kv_cache_method.process_weights_after_loading(mock_layer)
|
||||
|
||||
def test_apply_delegation(self):
|
||||
"""Test apply properly delegates to quant_method."""
|
||||
mock_layer = MagicMock()
|
||||
mock_query = torch.randn(1, 32, 128)
|
||||
mock_key = torch.randn(1, 32, 128)
|
||||
mock_value = torch.randn(1, 32, 128)
|
||||
mock_kv_cache = MagicMock()
|
||||
mock_attn_metadata = MagicMock()
|
||||
mock_scale = 1.0
|
||||
mock_output = torch.zeros(1, 32, 128)
|
||||
mock_attn_type = MagicMock()
|
||||
expected_result = torch.randn(1, 32, 128)
|
||||
self.mock_quant_method.apply.return_value = expected_result
|
||||
|
||||
result = self.kv_cache_method.apply(mock_layer, mock_query, mock_key,
|
||||
mock_value, mock_kv_cache,
|
||||
mock_attn_metadata, mock_attn_type,
|
||||
mock_scale, mock_output)
|
||||
|
||||
self.mock_quant_method.apply.assert_called_once_with(
|
||||
mock_layer, mock_query, mock_key, mock_value, mock_kv_cache,
|
||||
mock_attn_metadata, mock_attn_type, mock_scale, mock_output)
|
||||
self.assertTrue(torch.equal(result, expected_result))
|
||||
124
tests/ut/quantization/test_quantizer.py
Normal file
124
tests/ut/quantization/test_quantizer.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
|
||||
W8A8Quantizer)
|
||||
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
|
||||
|
||||
|
||||
class TestGetQuantizer(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Setup common test fixtures
|
||||
self.supported_types = {
|
||||
'INT8': MagicMock(_instance=None),
|
||||
'FP16': MagicMock(_instance=None),
|
||||
'C8': MagicMock(_instance=None)
|
||||
}
|
||||
self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy()
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types)
|
||||
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
|
||||
self.mock_quant_config.quant_description = {"some_config": "value"}
|
||||
|
||||
def tearDown(self):
|
||||
# Restore original supported types
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE.clear()
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types)
|
||||
|
||||
def test_get_quantizer_fa(self):
|
||||
"""Test successful quantizer retrieval for different cases."""
|
||||
# Setup
|
||||
quant_description = {'fa_quant_type': 'C8'}
|
||||
prefix = '.attn'
|
||||
expected_type = 'C8'
|
||||
with patch.dict(
|
||||
'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE):
|
||||
|
||||
result = VLLMAscendQuantizer.get_quantizer(
|
||||
quant_description,
|
||||
prefix,
|
||||
packed_modules_mapping={"some": "mapping"})
|
||||
|
||||
# Verify
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result,
|
||||
self.supported_types[expected_type]._instance)
|
||||
self.supported_types[expected_type].assert_called_once_with(
|
||||
quant_description)
|
||||
|
||||
def test_get_quantizer_kv(self):
|
||||
"""Test successful quantizer retrieval for different cases."""
|
||||
# Setup
|
||||
quant_description = {'kv_quant_type': 'C8'}
|
||||
prefix = '.attn'
|
||||
expected_type = 'C8'
|
||||
with patch.dict(
|
||||
'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE):
|
||||
|
||||
result = VLLMAscendQuantizer.get_quantizer(
|
||||
quant_description,
|
||||
prefix,
|
||||
packed_modules_mapping={"some": "mapping"})
|
||||
|
||||
# Verify
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result,
|
||||
self.supported_types[expected_type]._instance)
|
||||
self.supported_types[expected_type].assert_called_once_with(
|
||||
quant_description)
|
||||
|
||||
def test_get_quantizer_linear(self):
|
||||
"""Test successful quantizer retrieval for different cases."""
|
||||
# Setup
|
||||
quant_description = {'linear_type': 'INT8'}
|
||||
prefix = 'nothing'
|
||||
expected_type = 'INT8'
|
||||
with patch('vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type',
|
||||
return_value=expected_type), \
|
||||
patch.dict('vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', SUPPORT_ASCEND_QUANTIZER_TYPE):
|
||||
|
||||
result = VLLMAscendQuantizer.get_quantizer(
|
||||
quant_description,
|
||||
prefix,
|
||||
packed_modules_mapping={"some": "mapping"})
|
||||
|
||||
# Verify
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result,
|
||||
self.supported_types[expected_type]._instance)
|
||||
self.supported_types[expected_type].assert_called_once_with(
|
||||
quant_description)
|
||||
|
||||
|
||||
class TestW8A8Quantizer(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.quantizer = W8A8Quantizer(quant_description={})
|
||||
|
||||
def test_build_linear_method(self):
|
||||
with patch('vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod',
|
||||
return_value=MagicMock()) as mock_linear:
|
||||
result = self.quantizer.build_linear_method()
|
||||
mock_linear.assert_called_once_with()
|
||||
self.assertIsInstance(result, MagicMock)
|
||||
|
||||
def test_build_moe_method(self):
|
||||
with patch(
|
||||
'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod',
|
||||
return_value=MagicMock()) as mock_linear:
|
||||
result = self.quantizer.build_moe_method()
|
||||
mock_linear.assert_called_once_with()
|
||||
self.assertIsInstance(result, MagicMock)
|
||||
|
||||
def test_build_attention_method(self):
|
||||
with patch('vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod',
|
||||
return_value=MagicMock()) as mock_linear:
|
||||
result = self.quantizer.build_attention_method()
|
||||
mock_linear.assert_called_once_with()
|
||||
self.assertIsInstance(result, MagicMock)
|
||||
730
tests/ut/quantization/test_w8a8.py
Normal file
730
tests/ut/quantization/test_w8a8.py
Normal file
@@ -0,0 +1,730 @@
|
||||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
|
||||
AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod,
|
||||
fused_experts, native_grouped_topk,
|
||||
quant_per_tensor, select_experts)
|
||||
|
||||
|
||||
class TestQuantPerTensor(TestBase):
|
||||
|
||||
@patch("torch_npu.npu_quantize")
|
||||
def test_quant_per_tensor(self, mock_npu_quantize):
|
||||
in_tensor = torch.randn(32, 128)
|
||||
input_scale = torch.tensor(0.1)
|
||||
input_offset = torch.tensor(0)
|
||||
|
||||
expected_output = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
|
||||
mock_npu_quantize.return_value = expected_output
|
||||
|
||||
output = quant_per_tensor(in_tensor, input_scale, input_offset)
|
||||
|
||||
mock_npu_quantize.assert_called_once_with(
|
||||
in_tensor,
|
||||
input_scale,
|
||||
input_offset,
|
||||
torch.qint8,
|
||||
-1,
|
||||
False,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.equal(output, expected_output))
|
||||
|
||||
|
||||
class TestAscendW8A8LinearMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.method = AscendW8A8LinearMethod()
|
||||
|
||||
def test_get_weight(self):
|
||||
weight = self.method.get_weight(10, 20)
|
||||
self.assertEqual(weight['weight'].dtype, torch.int8)
|
||||
self.assertEqual(weight['weight'].shape, (20, 10))
|
||||
|
||||
def test_get_pertensor_param(self):
|
||||
params = self.method.get_pertensor_param(torch.bfloat16)
|
||||
self.assertEqual(params['input_scale'].dtype, torch.bfloat16)
|
||||
self.assertEqual(params['input_offset'].dtype, torch.int8)
|
||||
self.assertEqual(params['input_scale'].shape, (1, ))
|
||||
self.assertEqual(params['input_offset'].shape, (1, ))
|
||||
|
||||
def test_get_perchannel_param(self):
|
||||
params = self.method.get_perchannel_param(10, torch.bfloat16)
|
||||
|
||||
self.assertEqual(params['quant_bias'].dtype, torch.int32)
|
||||
self.assertEqual(params['deq_scale'].dtype, torch.float32)
|
||||
self.assertEqual(params['weight_scale'].dtype, torch.bfloat16)
|
||||
self.assertEqual(params['weight_offset'].dtype, torch.bfloat16)
|
||||
self.assertEqual(params['quant_bias'].shape, (10, ))
|
||||
self.assertEqual(params['deq_scale'].shape, (10, ))
|
||||
self.assertEqual(params['weight_scale'].shape, (10, 1))
|
||||
self.assertEqual(params['weight_offset'].shape, (10, 1))
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
@patch("torch_npu.npu_quant_matmul")
|
||||
def test_apply_with_x_not_int8(self, mock_npu_quant_matmul,
|
||||
mock_quant_per_tensor):
|
||||
layer = MagicMock()
|
||||
layer.aclnn_input_scale = 0.1
|
||||
layer.aclnn_input_offset = 0.2
|
||||
layer.weight = torch.randn(128, 256)
|
||||
layer.deq_scale = 0.3
|
||||
|
||||
x = torch.randn(32, 128)
|
||||
bias = torch.randn(256)
|
||||
mock_quant_per_tensor.return_value = torch.randint(-128,
|
||||
127,
|
||||
x.shape,
|
||||
dtype=torch.int8)
|
||||
|
||||
expected_y_output = torch.randn(32, 256)
|
||||
mock_npu_quant_matmul.return_value = expected_y_output
|
||||
|
||||
output = self.method.apply(layer, x, bias)
|
||||
|
||||
expected_y_output += bias
|
||||
self.assertTrue(torch.equal(output, expected_y_output))
|
||||
|
||||
@patch("torch_npu.npu_quant_matmul")
|
||||
def test_apply_with_x_is_int8(self, mock_npu_quant_matmul):
|
||||
layer = MagicMock()
|
||||
layer.aclnn_input_scale = 0.1
|
||||
layer.aclnn_input_offset = 0.2
|
||||
layer.weight = torch.randn(128, 256)
|
||||
layer.deq_scale = 0.3
|
||||
|
||||
x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
|
||||
bias = torch.randn(256)
|
||||
|
||||
expected_y_output = torch.randn(32, 256)
|
||||
mock_npu_quant_matmul.return_value = expected_y_output
|
||||
|
||||
output = self.method.apply(layer, x, bias)
|
||||
expected_y_output += bias
|
||||
self.assertTrue(torch.equal(output, expected_y_output))
|
||||
|
||||
@patch('torch_npu.npu_format_cast')
|
||||
def test_process_weights_after_loading(self, mock_npu_format_cast):
|
||||
layer = MagicMock()
|
||||
|
||||
layer.weight.data = torch.randn(128, 256)
|
||||
layer.input_scale.data = torch.tensor([0.1])
|
||||
layer.input_offset.data = torch.tensor([0])
|
||||
layer.deq_scale = torch.tensor([0.5])
|
||||
layer.weight_scale.data = torch.randn(128, 1)
|
||||
layer.weight_offset.data = torch.randn(128, 1)
|
||||
|
||||
mock_npu_format_cast.return_value = MagicMock
|
||||
self.method.process_weights_after_loading(layer)
|
||||
|
||||
expected_offset = torch.tensor([0]).repeat(256).to(torch.int8)
|
||||
self.assertTrue(
|
||||
torch.equal(layer.aclnn_input_offset.data, expected_offset))
|
||||
self.assertFalse(layer.aclnn_input_offset.requires_grad)
|
||||
|
||||
self.assertFalse(layer.deq_scale.requires_grad)
|
||||
|
||||
self.assertEqual(layer.weight_scale.data.shape, (128, ))
|
||||
self.assertEqual(layer.weight_offset.data.shape, (128, ))
|
||||
|
||||
|
||||
class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.moe_method = AscendW8A8FusedMoEMethod()
|
||||
self.num_experts = 4
|
||||
self.intermediate_size = 64
|
||||
self.hidden_size = 128
|
||||
self.dtype = torch.float32
|
||||
|
||||
def test_init(self):
|
||||
self.assertTrue(self.moe_method.transpose_weight)
|
||||
|
||||
def test_get_weight(self):
|
||||
weights = self.moe_method.get_weight(
|
||||
num_experts=self.num_experts,
|
||||
intermediate_size_per_partition=self.intermediate_size,
|
||||
hidden_sizes=self.hidden_size,
|
||||
params_dtype=self.dtype)
|
||||
|
||||
assert "w13_weight" in weights, f"w13_weight not in {weights}"
|
||||
assert "w2_weight" in weights, f"w2_weight not in {weights}"
|
||||
self.assertEqual(
|
||||
weights["w13_weight"].shape,
|
||||
(self.num_experts, 2 * self.intermediate_size, self.hidden_size))
|
||||
self.assertEqual(
|
||||
weights["w2_weight"].shape,
|
||||
(self.num_experts, self.hidden_size, self.intermediate_size))
|
||||
self.assertEqual(weights["w13_weight"].dtype, torch.int8)
|
||||
self.assertEqual(weights["w2_weight"].dtype, torch.int8)
|
||||
self.assertFalse(weights["w13_weight"].requires_grad)
|
||||
self.assertFalse(weights["w2_weight"].requires_grad)
|
||||
|
||||
def test_get_dynamic_quant_param(self):
|
||||
quant_params = self.moe_method.get_dynamic_quant_param(
|
||||
num_experts=self.num_experts,
|
||||
intermediate_size_per_partition=self.intermediate_size,
|
||||
hidden_sizes=self.hidden_size,
|
||||
params_dtype=self.dtype)
|
||||
|
||||
expected_params = [
|
||||
"w13_weight_scale", "w13_weight_offset", "w2_weight_scale",
|
||||
"w2_weight_offset", "w2_deq_scale", "w13_deq_scale",
|
||||
"w2_input_scale", "w13_input_scale", "w2_input_offset",
|
||||
"w13_input_offset", "quant_bias"
|
||||
]
|
||||
|
||||
for param in expected_params:
|
||||
assert param in quant_params, f"{param} not in {quant_params}"
|
||||
|
||||
# Check some sample shapes
|
||||
self.assertEqual(quant_params["w13_weight_scale"].shape,
|
||||
(self.num_experts, 2 * self.intermediate_size, 1))
|
||||
self.assertEqual(quant_params["w2_input_offset"].shape,
|
||||
(self.num_experts, 1))
|
||||
self.assertEqual(quant_params["quant_bias"].shape,
|
||||
(self.num_experts, self.hidden_size))
|
||||
|
||||
@patch('vllm_ascend.quantization.w8a8.select_experts')
|
||||
@patch('vllm_ascend.quantization.w8a8.fused_experts')
|
||||
def test_apply_with_other_expert_count(self, mock_fused_experts,
|
||||
mock_select_experts):
|
||||
# Setup
|
||||
mock_layer = MagicMock()
|
||||
x = torch.randn(32, self.hidden_size)
|
||||
router_logits = torch.randn(32, 128) # 128 experts
|
||||
top_k = 2
|
||||
|
||||
# Mock return values
|
||||
mock_select_experts.return_value = (torch.randn(32, top_k),
|
||||
torch.randint(0, 128, (32, top_k)))
|
||||
mock_fused_experts.return_value = torch.randn(32, self.hidden_size)
|
||||
|
||||
# Test
|
||||
result = self.moe_method.apply(layer=mock_layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=True,
|
||||
global_num_experts=128)
|
||||
|
||||
# Assertions
|
||||
mock_select_experts.assert_called_once()
|
||||
mock_fused_experts.assert_called_once()
|
||||
self.assertEqual(result.shape, (32, self.hidden_size))
|
||||
|
||||
|
||||
class TestAscendC8KVCacheMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.layer = MagicMock()
|
||||
self.layer.num_kv_heads = 4
|
||||
self.layer.head_size = 64
|
||||
self.layer.num_heads = 8
|
||||
self.layer._k_scale_float = 1.0
|
||||
self.layer._v_scale_float = 1.0
|
||||
self.method = AscendC8KVCacheMethod()
|
||||
|
||||
self.attention_type = MagicMock()
|
||||
self.attention_type.DECODER = "decoder"
|
||||
self.attention_type.ENCODER = "encoder"
|
||||
|
||||
def test_create_weights(self):
|
||||
"""测试 create_weights 是否正确注册参数"""
|
||||
AscendC8KVCacheMethod.create_weights(self.layer)
|
||||
|
||||
self.layer.register_parameter.assert_any_call("key_antiquant_scale",
|
||||
unittest.mock.ANY)
|
||||
self.layer.register_parameter.assert_any_call("value_antiquant_scale",
|
||||
unittest.mock.ANY)
|
||||
|
||||
calls = self.layer.register_parameter.call_args_list
|
||||
|
||||
for call in calls:
|
||||
args, kwargs = call
|
||||
param = kwargs.get('parameter', args[1] if len(args) > 1 else None)
|
||||
|
||||
expected_shape = (self.layer.num_kv_heads * self.layer.head_size, )
|
||||
self.assertEqual(param.shape, expected_shape)
|
||||
|
||||
def test_process_weights_after_loading(self):
|
||||
key_data = torch.ones(4 * 64)
|
||||
value_data = torch.ones(4 * 64) * 2
|
||||
|
||||
self.layer.key_antiquant_scale.data = key_data
|
||||
self.layer.value_antiquant_scale.data = value_data
|
||||
|
||||
self.method.process_weights_after_loading(self.layer)
|
||||
|
||||
self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
|
||||
self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
|
||||
self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))
|
||||
|
||||
@patch('torch_npu.npu_scatter_nd_update_')
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
def test_apply_decode_only(self, mock_quant, mock_scatter):
|
||||
|
||||
num_tokens = 2
|
||||
query = torch.randn(num_tokens,
|
||||
self.layer.num_heads * self.layer.head_size)
|
||||
key = torch.randn(num_tokens,
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
value = torch.randn(num_tokens,
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
attn_metadata = MagicMock()
|
||||
attn_metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
attn_metadata.seq_lens = [10, 10]
|
||||
attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
|
||||
attn_metadata.slot_mapping = torch.tensor([0, 1])
|
||||
attn_metadata.attn_mask = None
|
||||
|
||||
block_size = 16
|
||||
key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
|
||||
self.layer.head_size)
|
||||
value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
|
||||
self.layer.head_size)
|
||||
kv_cache = (key_cache, value_cache)
|
||||
|
||||
mock_quant.side_effect = [key, value]
|
||||
|
||||
self.layer.key_antiquant_scale.data = torch.ones(
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
self.layer.value_antiquant_scale.data = torch.ones(
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
self.method.process_weights_after_loading(self.layer)
|
||||
|
||||
expected_output = torch.randn(
|
||||
num_tokens, self.layer.num_heads * self.layer.head_size)
|
||||
with patch('torch_npu.npu_incre_flash_attention',
|
||||
return_value=expected_output):
|
||||
result = self.method.apply(self.layer, query, key, value, kv_cache,
|
||||
attn_metadata,
|
||||
self.attention_type.DECODER, 1.0,
|
||||
output)
|
||||
|
||||
self.assertEqual(mock_quant.call_count, 2)
|
||||
self.assertEqual(mock_scatter.call_count, 2)
|
||||
self.assertTrue(torch.equal(result, expected_output))
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
@patch('torch_npu._npu_flash_attention')
|
||||
def test_apply_prefill_no_cache(self, mock_flash, mock_quant):
|
||||
"""Test apply method in prefill no-cache mode"""
|
||||
|
||||
num_tokens = 2
|
||||
query = torch.randn(num_tokens,
|
||||
self.layer.num_heads * self.layer.head_size)
|
||||
key = torch.randn(num_tokens,
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
value = torch.randn(num_tokens,
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
attn_metadata = MagicMock()
|
||||
attn_metadata.attn_state = AscendAttentionState.PrefillNoCache
|
||||
attn_metadata.seq_lens = [10, 10]
|
||||
attn_metadata.attn_mask = torch.ones(2, 2)
|
||||
|
||||
kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
mock_quant.return_value = key
|
||||
|
||||
result = self.method.apply(self.layer, query, key, value, kv_cache,
|
||||
attn_metadata, self.attention_type.DECODER,
|
||||
1.0, output)
|
||||
|
||||
# Check that flash attention was called
|
||||
mock_flash.assert_called_once()
|
||||
|
||||
# Check output shape
|
||||
self.assertEqual(
|
||||
result.shape,
|
||||
(num_tokens, self.layer.num_heads * self.layer.head_size))
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
def test_apply_unsupported_attention_type(self, mock_quant):
|
||||
|
||||
query = torch.randn(1, self.layer.num_heads * self.layer.head_size)
|
||||
key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
|
||||
value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
mock_quant.return_value = key
|
||||
|
||||
attn_metadata = MagicMock()
|
||||
attn_metadata.attn_state = AscendAttentionState.PrefillNoCache
|
||||
|
||||
with self.assertRaises(NotImplementedError) as cm:
|
||||
self.method.apply(self.layer, query, key, value, (None, None),
|
||||
attn_metadata, self.attention_type.ENCODER, 1.0,
|
||||
output)
|
||||
|
||||
assert "Encoder self-attention" in str(
|
||||
cm.exception), f"Encoder self-attention not in {str(cm.exception)}"
|
||||
assert "not implemented" in str(
|
||||
cm.exception), f"not implemented not in{str(cm.exception)}"
|
||||
|
||||
mock_quant.assert_not_called()
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
def test_apply_unsupported_attention_state(self, mock_quant):
|
||||
"""Test apply with unsupported attention state"""
|
||||
query = torch.randn(1, self.layer.num_heads * self.layer.head_size)
|
||||
key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
|
||||
value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
attn_metadata = MagicMock()
|
||||
attn_metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
||||
mock_quant.return_value = key
|
||||
kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.method.apply(self.layer, query, key, value, kv_cache,
|
||||
attn_metadata, self.attention_type.DECODER, 1.0,
|
||||
output)
|
||||
|
||||
|
||||
class TestFusedExperts(TestBase):
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
|
||||
@patch('torch_npu.npu_moe_init_routing_v2')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_moe_finalize_routing')
|
||||
def test_fused_experts_with_expert_map(self, mock_finalize, mock_swiglu,
|
||||
mock_group_matmul,
|
||||
mock_init_routing,
|
||||
mock_get_ep_group,
|
||||
mock_quant_per_tensor):
|
||||
num_tokens = 32
|
||||
hidden_size = 128
|
||||
intermediate_size = 256
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
|
||||
hidden_states = torch.randn(num_tokens, hidden_size)
|
||||
|
||||
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
|
||||
w1_scale = torch.tensor([0.1])
|
||||
w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
|
||||
w1_input_offset = torch.tensor([0])
|
||||
|
||||
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
|
||||
w2_scale = torch.tensor([0.1])
|
||||
w2_input_scale = torch.tensor([0.2])
|
||||
w2_input_offset = torch.tensor([0])
|
||||
|
||||
topk_weights = torch.rand(num_tokens, top_k)
|
||||
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
|
||||
expert_map = torch.arange(num_experts)
|
||||
|
||||
mock_get_ep_group.return_value.world_size = 8
|
||||
|
||||
mock_quant_per_tensor.return_value = torch.randint(-128,
|
||||
127,
|
||||
hidden_states.shape,
|
||||
dtype=torch.int8)
|
||||
|
||||
mock_init_routing.return_value = (torch.randn(num_tokens * top_k,
|
||||
hidden_size),
|
||||
torch.arange(num_tokens * top_k),
|
||||
torch.tensor([num_tokens // 2] * 2),
|
||||
torch.tensor(1.0))
|
||||
|
||||
mock_group_matmul.side_effect = [[
|
||||
torch.randn(num_tokens * top_k, intermediate_size * 2)
|
||||
], [torch.randn(num_tokens * top_k, hidden_size)]]
|
||||
|
||||
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
|
||||
intermediate_size)
|
||||
|
||||
expected_output = torch.randn(num_tokens, hidden_size)
|
||||
mock_finalize.return_value = expected_output
|
||||
|
||||
output = fused_experts(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w1_input_scale=w1_input_scale,
|
||||
w1_input_offset=w1_input_offset,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
w2_input_offset=w2_input_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
mock_init_routing.assert_called_once()
|
||||
|
||||
self.assertEqual(mock_group_matmul.call_count, 2)
|
||||
|
||||
self.assertEqual(output.shape, (num_tokens, hidden_size))
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
def test_fused_experts_without_expert_map(self, mock_swiglu,
|
||||
mock_group_matmul,
|
||||
mock_get_ep_group,
|
||||
mock_quant_per_tensor):
|
||||
num_tokens = 16
|
||||
hidden_size = 64
|
||||
intermediate_size = 128
|
||||
num_experts = 8
|
||||
top_k = 1
|
||||
|
||||
hidden_states = torch.randn(num_tokens, hidden_size)
|
||||
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
|
||||
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
|
||||
topk_weights = torch.rand(num_tokens, top_k)
|
||||
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
|
||||
|
||||
mock_get_ep_group.return_value.world_size = 8
|
||||
|
||||
mock_quant_per_tensor.return_value = torch.randint(-128,
|
||||
127,
|
||||
hidden_states.shape,
|
||||
dtype=torch.int8)
|
||||
mock_group_matmul.side_effect = [[
|
||||
torch.randn(num_tokens * top_k, intermediate_size * 2)
|
||||
], [torch.randn(num_tokens * top_k, hidden_size)]]
|
||||
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
|
||||
intermediate_size)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
fused_experts(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=torch.tensor([0.1]),
|
||||
w1_input_scale=torch.tensor([[0.2, 0.2], [0.2, 0.2]]),
|
||||
w1_input_offset=torch.tensor([0]),
|
||||
w2=w2,
|
||||
w2_scale=torch.tensor([0.1]),
|
||||
w2_input_scale=torch.tensor([0.1]),
|
||||
w2_input_offset=torch.tensor([0]),
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None,
|
||||
)
|
||||
|
||||
|
||||
class TestSelectExperts(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Common test data
|
||||
self.num_tokens = 10
|
||||
self.hidden_size = 32
|
||||
self.num_experts = 8
|
||||
self.top_k = 2
|
||||
|
||||
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
|
||||
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
|
||||
|
||||
def test_softmax_scoring(self):
|
||||
"""Test softmax scoring function"""
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="softmax")
|
||||
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
|
||||
def test_sigmoid_scoring(self):
|
||||
"""Test sigmoid scoring function"""
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="sigmoid")
|
||||
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
|
||||
def test_invalid_scoring_func(self):
|
||||
"""Test invalid scoring function raises ValueError"""
|
||||
with self.assertRaises(ValueError):
|
||||
select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="invalid_func")
|
||||
|
||||
@patch('torch.topk')
|
||||
def test_grouped_topk(self, mock_topk):
|
||||
"""Test grouped topk functionality"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2)
|
||||
|
||||
mock_topk.assert_called()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('vllm_ascend.quantization.w8a8.native_grouped_topk')
|
||||
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
|
||||
"""Test grouped topk with expert score correction bias"""
|
||||
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
|
||||
self.num_experts)
|
||||
|
||||
e_score_correction_bias = torch.randn(self.num_experts)
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
mock_grouped_topk.assert_called_once()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
|
||||
def test_custom_routing_function(self):
|
||||
"""Test custom routing function"""
|
||||
mock_custom_routing = MagicMock()
|
||||
mock_custom_routing.return_value = (torch.ones(self.num_tokens,
|
||||
self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.int32))
|
||||
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
custom_routing_function=mock_custom_routing)
|
||||
|
||||
mock_custom_routing.assert_called_once()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('torch.topk')
|
||||
def test_renormalize(self, mock_topk):
|
||||
"""Test weight renormalization"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
|
||||
weights, _ = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=True,
|
||||
)
|
||||
|
||||
# Check if weights are normalized (sum to 1 for each token)
|
||||
sums = weights.sum(dim=-1)
|
||||
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
|
||||
|
||||
@patch('torch.topk')
|
||||
def test_output_dtypes(self, mock_topk):
|
||||
"""Test output dtypes"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
)
|
||||
|
||||
self.assertEqual(weights.dtype, self.hidden_states.dtype)
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
|
||||
class TestNativeGroupedTopkPartialMock(TestBase):
|
||||
|
||||
def test_basic_group_selection(self):
|
||||
topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
|
||||
[0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1],
|
||||
[0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
|
||||
[0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4]],
|
||||
dtype=torch.float32)
|
||||
|
||||
expected_topk_indices = torch.tensor([[0, 1], [1, 0], [0, 1], [0, 1]])
|
||||
|
||||
with patch('torch.topk',
|
||||
return_value=(None, expected_topk_indices)) as mock_topk:
|
||||
result = native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=2)
|
||||
|
||||
mock_topk.assert_called_once()
|
||||
|
||||
expected_result = topk_weights
|
||||
self.assertTrue(torch.allclose(result, expected_result))
|
||||
|
||||
def test_partial_group_selection(self):
|
||||
|
||||
topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
|
||||
[0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1]])
|
||||
|
||||
expected_topk_indices = torch.tensor([[0], [1]])
|
||||
|
||||
with patch('torch.topk', return_value=(None, expected_topk_indices)):
|
||||
result = native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=1)
|
||||
|
||||
expected_result = torch.tensor(
|
||||
[[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.8, 0.2, 0.9, 0.1]])
|
||||
self.assertTrue(torch.allclose(result, expected_result))
|
||||
|
||||
def test_single_group(self):
|
||||
topk_weights = torch.tensor([[0.1, 0.9, 0.2], [0.8, 0.3, 0.7]])
|
||||
|
||||
expected_topk_indices = torch.tensor([[0], [0]])
|
||||
|
||||
with patch('torch.topk', return_value=(None, expected_topk_indices)):
|
||||
result = native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=1,
|
||||
topk_group=1)
|
||||
self.assertTrue(result.numel() > 0)
|
||||
@@ -274,6 +274,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
shape = [batch_size * seq_len, num_heads, head_size]
|
||||
"""
|
||||
num_tokens = query.shape[0]
|
||||
use_kv_cache_int8 = kv_cache.numel(
|
||||
) > 0 and kv_cache[0].dtype == torch.int8
|
||||
if output is None:
|
||||
output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
@@ -289,7 +291,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
output=output,
|
||||
layer_name=layer.layer_name)
|
||||
|
||||
elif hasattr(layer, 'quant_method'):
|
||||
elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
||||
output = layer.quant_method.apply(layer, query, key, value,
|
||||
kv_cache, attn_metadata,
|
||||
self.attn_type, self.scale,
|
||||
@@ -429,7 +431,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
out=output)
|
||||
|
||||
# to make in-place change to the output tensor
|
||||
if hasattr(layer, 'quant_method'):
|
||||
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
ori_output[:, :, :] = output[:num_tokens, :, :]
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
@@ -251,7 +251,8 @@ class VLLMAscendQuantizer:
|
||||
# Attention
|
||||
if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
|
||||
quant_type = quant_description['fa_quant_type']
|
||||
if '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
|
||||
# Use KVCache int8
|
||||
elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
|
||||
quant_type = quant_description['kv_quant_type']
|
||||
# Linear
|
||||
else:
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
@@ -219,53 +218,34 @@ class AscendW8A8FusedMoEMethod:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k,
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group,
|
||||
group_count=num_expert_group,
|
||||
group_select_mode=1,
|
||||
renorm=0,
|
||||
norm_type=1,
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
else:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
|
||||
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
|
||||
raise NotImplementedError("W8A8FusedMoe are not "
|
||||
"mplemented for VLLM_ENABLE_MC2")
|
||||
|
||||
else:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w1_input_scale=layer.w13_input_scale,
|
||||
w1_input_offset=layer.w13_input_offset,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w2_input_scale=layer.w2_input_scale,
|
||||
w2_input_offset=layer.w2_input_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w1_input_scale=layer.w13_input_scale,
|
||||
w1_input_offset=layer.w13_input_offset,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w2_input_scale=layer.w2_input_scale,
|
||||
w2_input_offset=layer.w2_input_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
# torch.npu.config.allow_internal_format = True
|
||||
@@ -299,8 +279,10 @@ class AscendW8A8FusedMoEMethod:
|
||||
torch.int8)
|
||||
|
||||
# NZ
|
||||
# layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, 29).contiguous()
|
||||
# layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, 29).contiguous()
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, 29).contiguous()
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, 29).contiguous()
|
||||
|
||||
|
||||
class AscendC8KVCacheMethod:
|
||||
|
||||
@@ -2194,7 +2194,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_size=self.block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
|
||||
Reference in New Issue
Block a user