Files
xc-llm-ascend/tests/ut/attention/test_attention_v1.py
Ronald c980e68d40 [Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR adds aclgraph support for model runner v2 (see RFC #5208). It
contains the following changes:
- Adapt to the latest commit of the vLLM main branch.
- Provide a unified interface for the extra forward context, shared by
model runner v1 and model runner v2.
- Implement graph mode for the main model.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
2026-03-13 09:11:46 +08:00

351 lines
14 KiB
Python

from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionBackendImpl,
AscendAttentionMetadataBuilder,
AscendAttentionState)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.utils import AscendDeviceType
class TestAscendAttentionBackend(TestBase):
    """Tests for the class-level API of ``AscendAttentionBackend``."""

    def setUp(self):
        # Mocked vLLM config with prefill and decode context-parallel sizes
        # both set to 1.
        self.mock_config = MagicMock()
        mock_parallel_config = MagicMock()
        mock_parallel_config.prefill_context_parallel_size = 1
        mock_parallel_config.decode_context_parallel_size = 1
        self.mock_config.parallel_config = mock_parallel_config

        self.utils_patcher = patch(
            'vllm_ascend.attention.utils.get_current_vllm_config',
            return_value=self.mock_config)
        self.utils_patcher.start()
        # Stop the patch once the test finishes so the mocked config does not
        # leak into test cases that run afterwards.
        self.addCleanup(self.utils_patcher.stop)

        from vllm_ascend.attention.utils import enable_cp

        # enable_cp is memoized (it exposes cache_clear()). Clear it before the
        # test so it is recomputed (presumably from the patched config), and
        # again afterwards so the mocked result is not reused by later tests.
        enable_cp.cache_clear()
        self.addCleanup(enable_cp.cache_clear)

    def test_get_name(self):
        self.assertEqual(AscendAttentionBackend.get_name(), "CUSTOM")

    def test_get_impl_cls(self):
        self.assertEqual(AscendAttentionBackend.get_impl_cls(),
                         AscendAttentionBackendImpl)

    def test_get_builder_cls(self):
        self.assertEqual(AscendAttentionBackend.get_builder_cls(),
                         AscendAttentionMetadataBuilder)

    def test_get_kv_cache_shape_not(self):
        # The expected shape is the four arguments with a leading dimension of
        # size 2 prepended.
        result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
        self.assertEqual(result, (2, 10, 20, 30, 40))

    def test_swap_blocks(self):
        src_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
        dst_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
        # Presumably each row is a (src_index, dst_index) pair.
        src_to_dst = torch.tensor([[0, 1], [2, 3]])
        AscendAttentionBackend.swap_blocks(src_kv_cache, dst_kv_cache,
                                           src_to_dst)
        # NOTE(review): every cache is zero-filled, so these equality checks
        # hold no matter what swap_blocks copies; distinct source values would
        # make the test meaningful.
        self.assertTrue(torch.all(dst_kv_cache[0][1] == src_kv_cache[0][0]))
        self.assertTrue(torch.all(dst_kv_cache[1][3] == src_kv_cache[1][2]))

    def test_copy_blocks(self):
        kv_caches = [torch.zeros((10, 20)), torch.zeros((10, 20))]
        src_to_dists = torch.tensor([[0, 1], [2, 3]])
        AscendAttentionBackend.copy_blocks(kv_caches, src_to_dists)
        # NOTE(review): as in test_swap_blocks, the zero-filled caches make
        # these assertions trivially true.
        self.assertTrue(torch.all(kv_caches[0][1] == kv_caches[0][0]))
        self.assertTrue(torch.all(kv_caches[1][3] == kv_caches[1][2]))
class TestAscendAttentionMetadataBuilder(TestBase):
    """Tests for ``AscendAttentionMetadataBuilder`` built from a mocked config."""

    def setUp(self):
        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.speculative_config = None
        self.mock_vllm_config.model_config.max_model_len = 640
        self.mock_vllm_config.model_config.hf_text_config.sliding_window = None
        self.mock_vllm_config.cache_config.block_size = 64
        self.mock_vllm_config.compilation_config.cudagraph_mode = None
        self.mock_vllm_config.scheduler_config.max_num_seqs = 10
        self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
        self.mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        self.mock_device = 'cpu:0'

        # Make Tensor.pin_memory an identity function while the test runs.
        # Presumably the builder pins host buffers, which would fail without
        # an accelerator. patch.object restores the real method on cleanup.
        # Assigning torch.Tensor.pin_memory directly would leave it overridden
        # for the rest of the test session.
        pin_memory_patcher = patch.object(torch.Tensor, 'pin_memory',
                                          lambda x: x)
        pin_memory_patcher.start()
        self.addCleanup(pin_memory_patcher.stop)

        self.builder = AscendAttentionMetadataBuilder(None, None,
                                                      self.mock_vllm_config,
                                                      self.mock_device)

    def test_reorder_batch(self):
        mock_input_batch = MagicMock()
        mock_scheduler_output = MagicMock()
        result = self.builder.reorder_batch(mock_input_batch,
                                            mock_scheduler_output)
        self.assertFalse(result)

    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
    def test_build(self, mock_ascend_metadata):
        """Smoke test: build() must not raise for a ChunkedPrefill batch."""
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 2, 5, 9]),
            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
            seq_lens_cpu=torch.tensor([4, 5, 6]),
            num_reqs=3,
            num_actual_tokens=15,
            max_query_len=6,
            decode_token_per_req=torch.tensor([1, 1, 1]),
            block_table_tensor=torch.zeros((10, 10)),
            slot_mapping=torch.tensor(range(20)),
            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
            positions=torch.tensor([10, 10]),
            attn_state=AscendAttentionState.ChunkedPrefill,
            num_computed_tokens_cpu=None,
            seq_lens=None,
            max_seq_len=6)
        mock_model = MagicMock()
        # AscendMetadata is patched at module level, so this only checks that
        # build() runs without error; the returned metadata is not inspected.
        self.builder.build(1, common_attn_metadata, mock_model)
class TestAscendAttentionBackendImpl(TestBase):
    """Tests for ``AscendAttentionBackendImpl.forward`` with NPU kernels mocked."""

    def setUp(self):
        # Replace NPU events/streams with mocks. Each patcher is stopped on
        # cleanup so the mocks do not leak into other test cases.
        self.mock_event = MagicMock()
        self.mock_event.record.return_value = None
        self.mock_event.wait.return_value = None
        self.mock_stream = MagicMock()
        self.event_patcher = patch('torch_npu.npu.Event',
                                   return_value=self.mock_event)
        self.stream_patcher = patch('torch_npu.npu.current_stream',
                                    return_value=self.mock_stream)
        self.event_patcher.start()
        self.addCleanup(self.event_patcher.stop)
        self.stream_patcher.start()
        self.addCleanup(self.stream_patcher.stop)

        # Layer mock with the scale attributes the impl reads.
        self.layer = MagicMock()
        self.layer.layer_name = "test_layer"
        self.layer._k_scale_float = 1.0
        self.layer._v_scale_float = 1.0

        self.attention_type = MagicMock()
        self.attention_type.DECODER = "decoder"
        self.attention_type.ENCODER = "encoder"

        self.attn_metadata = MagicMock()
        self.attn_metadata.return_value = "1"

        # A spec-restricted mock: it exposes only the listed attributes, so any
        # other attribute access (e.g. quantization fields) raises
        # AttributeError.
        self.layer_no_quant = MagicMock(
            spec=['layer_name', '_k_scale_float', '_v_scale_float'])
        self.layer_no_quant.layer_name = "test_layer"
        self.layer_no_quant._k_scale_float = 1.0
        self.layer_no_quant._v_scale_float = 1.0

        # The impl constructors read the current vLLM config, so the patch must
        # be active before any AscendAttentionBackendImpl is created.
        self.mock_vllm_config = MagicMock()
        self.config_patcher = patch(
            'vllm_ascend.attention.attention_v1.get_current_vllm_config',
            return_value=self.mock_vllm_config)
        self.config_patcher.start()
        self.addCleanup(self.config_patcher.stop)

        # Default decoder impl: 8 heads of size 64, no sliding window.
        self.impl = AscendAttentionBackendImpl(
            num_heads=8,
            head_size=64,
            scale=1.0,
            num_kv_heads=8,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="float16",
            logits_soft_cap=None,
            attn_type=self.attention_type.DECODER,
            kv_sharing_target_layer_name=None)
        # Same as self.impl but with head_size=192.
        self.impl_192 = AscendAttentionBackendImpl(
            num_heads=8,
            head_size=192,
            scale=1.0,
            num_kv_heads=8,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="float16",
            logits_soft_cap=None,
            attn_type=self.attention_type.DECODER,
            kv_sharing_target_layer_name=None)
        # head_size=192 with attn_type=None.
        self.impl_error = AscendAttentionBackendImpl(
            num_heads=8,
            head_size=192,
            scale=1.0,
            num_kv_heads=8,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="float16",
            logits_soft_cap=None,
            attn_type=None,
            kv_sharing_target_layer_name=None)
        # Sliding-window attention (window of 1024).
        self.impl_swa = AscendAttentionBackendImpl(
            num_heads=8,
            head_size=64,
            scale=1.0,
            num_kv_heads=8,
            alibi_slopes=None,
            sliding_window=1024,
            kv_cache_dtype="float16",
            logits_soft_cap=None,
            attn_type=self.attention_type.DECODER,
            kv_sharing_target_layer_name=None)

    def test_forward_no_attn_metadata(self):
        """Test forward pass when attn_metadata is None"""
        query = torch.randn(10, 8 * 64)
        key = torch.randn(10, 8 * 64)
        value = torch.randn(10, 8 * 64)
        kv_cache = torch.empty(2, 0, 0, 8, 64)
        layer = self.layer_no_quant
        output = torch.empty_like(query)
        output = self.impl.forward(layer, query, key, value, kv_cache, None,
                                   output)
        assert output.shape == (10, 8 * 64)

    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu.npu_fused_infer_attention_score')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    def test_forward_fused_infer_attention(
            self, mock_get_forward_context,
            mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):
        """Test that the PrefillCacheHit state calls the fused infer-attention kernel."""
        query = torch.randn(10, 8, 64)
        key = torch.randn(10, 8, 64)
        value = torch.randn(10, 8, 64)
        kv_cache = torch.empty(2, 5, 128, 8, 64)
        output = torch.empty_like(query)

        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.PrefillCacheHit
        metadata.attn_mask = torch.randn(1, 1, 10, 10)
        metadata.query_lens = torch.tensor([10])
        metadata.seq_lens = torch.tensor([10])
        metadata.actual_seq_lengths_q = [10]
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 10
        metadata.num_decode_tokens = 0
        metadata.num_decodes = 0
        metadata.num_prefills = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        layer = self.layer_no_quant

        mock_get_forward_context.return_value = MagicMock(capturing=False)
        mock_npu_fused_infer_attention_score.return_value = (torch.ones(
            10, 8, 64), torch.ones(10, 8, 64))

        output = self.impl.forward(layer, query, key, value, kv_cache,
                                   metadata, output)

        mock_npu_fused_infer_attention_score.assert_called_once()
        assert output.shape == (10, 8, 64)

    @patch('vllm_ascend.attention.attention_v1.using_paged_attention')
    @patch('torch_npu._npu_paged_attention')
    @patch('torch_npu._npu_reshape_and_cache')
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    def test_forward_paged_attention(self, mock_get_forward_context,
                                     mock_npu_reshape_and_cache,
                                     mock_paged_attention,
                                     mock_using_paged_attention):
        """Test that DecodeOnly calls paged attention when using_paged_attention is True."""
        query = torch.randn(4, 8 * 64)
        key = torch.randn(4, 8 * 64)
        value = torch.randn(4, 8 * 64)
        kv_cache = torch.empty(2, 5, 128, 8, 64)
        output = torch.empty_like(query)

        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.DecodeOnly
        metadata.seq_lens = torch.tensor([4])
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 4
        metadata.slot_mapping = torch.zeros(4, dtype=torch.long)
        metadata.num_decodes = 4
        metadata.num_prefills = 0
        layer = self.layer_no_quant

        mock_using_paged_attention.return_value = True
        mock_get_forward_context.return_value = MagicMock(capturing=False)

        output = self.impl.forward(layer, query, key, value, kv_cache,
                                   metadata, output)

        mock_paged_attention.assert_called_once()
        assert output.shape == (4, 8 * 64)

    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('torch_npu.npu_fused_infer_attention_score')
    @patch('torch_npu._npu_reshape_and_cache')
    def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
                                     mock_fused_infer_attention_score,
                                     mock_get_forward_context):
        """Test that DecodeOnly with sliding window calls the fused kernel once."""
        query = torch.randn(10, 8 * 64)
        key = torch.randn(10, 8 * 64)
        value = torch.randn(10, 8 * 64)
        kv_cache = torch.empty(2, 5, 128, 8, 64)
        output = torch.empty(10, 8, 64)
        mock_get_forward_context.return_value = MagicMock(capturing=False)

        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.DecodeOnly
        metadata.seq_lens = torch.tensor([10] * 10)
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        # NOTE(review): 100 exceeds the 10 tokens actually supplied; confirm
        # this is intentional.
        metadata.num_actual_tokens = 100
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        metadata.num_decodes = 10
        metadata.num_prefills = 0
        layer = self.layer_no_quant

        mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
                                                                    64), 1)

        output = self.impl_swa.forward(layer, query, key, value, kv_cache,
                                       metadata, output)

        mock_fused_infer_attention_score.assert_called_once()
        assert output.shape == (10, 8, 64)

    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('torch_npu._npu_paged_attention')
    @patch('torch_npu.npu_fused_infer_attention_score')
    @patch('torch_npu._npu_reshape_and_cache')
    def test_forward_decode_only_swa_seq_len_mismatch(
            self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
            mock_paged_attention, mock_get_forward_context):
        """Test DecodeOnly with sliding window when len(seq_lens) != num tokens.

        The fused infer-attention kernel must be used and paged attention
        must not be called.
        """
        query = torch.randn(10, 8, 64)
        key = torch.randn(10, 8, 64)
        value = torch.randn(10, 8, 64)
        kv_cache = torch.empty(2, 5, 128, 8, 64)
        output = torch.empty_like(query)

        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.DecodeOnly
        metadata.seq_lens = torch.tensor([10])  # len == 1 != query.size(0)==10
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        layer = self.layer_no_quant
        metadata.num_decodes = 10
        metadata.num_prefills = 0
        metadata.actual_seq_lengths_q = [10]

        mock_get_forward_context.return_value = MagicMock(capturing=False)
        mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, 64),
                                                         torch.ones(10, 8, 64))

        output = self.impl_swa.forward(layer, query, key, value, kv_cache,
                                       metadata, output)

        mock_paged_attention.assert_not_called()
        mock_fused_infer_attention_score.assert_called_once()
        assert output.shape == (10, 8, 64)