xc-llm-ascend/tests/ut/_310p/attention/test_attention_v1_310.py
Ronald c980e68d40 [Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2; please see RFC
#5208. The PR contains these modifications:
- adapt to the newest commit of the vLLM main branch.
- provide a unified interface for the extra forward context shared by model
runner v1 and model runner v2 (see the sketch after this list).
- implement graph mode for the main model.
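
As a rough illustration of the forward-context item above (a minimal sketch only; the class and function names below are assumptions inferred from the unit tests in this file, not the exact API added by this PR): both model runners install one shared context object before running the model, and layers such as attention read it through a single accessor, e.g. to learn whether an aclgraph capture is in progress.

```python
# Minimal sketch, not the real vllm_ascend.ascend_forward_context module.
# The `capturing` field mirrors the attribute that the unit tests below mock
# on the object returned by get_forward_context().
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class ExtraForwardContext:
    capturing: bool = False  # True while an aclgraph is being captured


_forward_context = ExtraForwardContext()


def get_forward_context() -> ExtraForwardContext:
    """Single accessor shared by model runner v1 and model runner v2."""
    return _forward_context


@contextmanager
def set_forward_context(capturing: bool = False):
    """Either runner installs its per-step context before the forward pass."""
    global _forward_context
    prev, _forward_context = _forward_context, ExtraForwardContext(capturing)
    try:
        yield _forward_context
    finally:
        _forward_context = prev
```

A layer would then branch on `get_forward_context().capturing` to decide whether its kernels are being recorded into an aclgraph or executed eagerly.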

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
2026-03-13 09:11:46 +08:00


#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend._310p.attention.attention_v1 import (
    AscendAttentionBackend310,
    AscendAttentionBackendImpl310,
    AscendAttentionMetadataBuilder310,
    AscendAttentionState,
)


class TestAscendAttentionBackend310(TestBase):

    def setUp(self):
        self.mock_config = MagicMock()
        self.utils_patcher = patch(
            "vllm_ascend.attention.utils.get_current_vllm_config",
            return_value=self.mock_config)
        self.utils_patcher.start()
        self.addCleanup(self.utils_patcher.stop)

    def test_get_impl_cls(self):
        self.assertEqual(AscendAttentionBackend310.get_impl_cls(), AscendAttentionBackendImpl310)

    def test_get_builder_cls(self):
        self.assertEqual(AscendAttentionBackend310.get_builder_cls(), AscendAttentionMetadataBuilder310)

    def test_get_kv_cache_shape_not(self):
        result = AscendAttentionBackend310.get_kv_cache_shape(10, 20, 30, 40)
        self.assertEqual(result, (2, 10, 75, 20, 16))


class TestAscendAttentionBackendImpl310(TestBase):

    def setUp(self):
        self.attention_type = MagicMock()
        self.attention_type.DECODER = "decoder"
        self.attention_type.ENCODER = "encoder"
        self.attn_metadata = MagicMock()
        self.attn_metadata.return_value = "1"
        self.mock_vllm_config = MagicMock()
        self.layer_no_quant = MagicMock(
            spec=["layer_name", "_k_scale_float", "_v_scale_float"])
        self.layer_no_quant.layer_name = "test_layer"
        self.layer_no_quant._k_scale_float = 1.0
        self.layer_no_quant._v_scale_float = 1.0
        self.config_patcher = patch(
            "vllm_ascend.attention.attention_v1.get_current_vllm_config",
            return_value=self.mock_vllm_config)
        self.config_patcher.start()
        self.addCleanup(self.config_patcher.stop)
        self.impl = AscendAttentionBackendImpl310(
            num_heads=8,
            head_size=128,
            scale=1.0,
            num_kv_heads=8,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="float16",
            logits_soft_cap=None,
            attn_type=self.attention_type.DECODER,
            kv_sharing_target_layer_name=None,
        )
@patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_flash_attention")
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_prefill_310(
self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
):
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8, 64)
key = torch.randn(10, 8, 64)
value = torch.randn(10, 8, 64)
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillNoCache
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.actual_seq_lengths_q = [10]
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.num_decode_tokens = 0
metadata.num_decodes = 0
metadata.num_prefills = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_npu_npu_flash_attention.assert_called_once()
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
@patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_chunked_prefill_310(
self,
mock_get_forward_context,
mock_npu_paged_attention_splitfuse,
mock_npu_reshape_and_cache,
mock_format_cast,
):
"""Test forward pass in ChunkedPrefill state"""
query = torch.randn(5, 8, 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.ChunkedPrefill
metadata.attn_mask = torch.randn(1, 128, 16, 16)
metadata.query_lens = torch.tensor([5])
metadata.seq_lens = torch.tensor([1, 4])
metadata.query_start_loc = torch.tensor([0, 1, 5])
metadata.actual_seq_lengths_q = [5]
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.num_decode_tokens = 0
metadata.num_decodes = 0
metadata.num_prefills = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_npu_paged_attention_splitfuse.assert_called_once()
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
@patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_prefill_cache_hit_310(
self,
mock_get_forward_context,
mock_npu_paged_attention_splitfuse,
mock_npu_reshape_and_cache,
mock_format_cast,
):
"""Test forward pass in PrefillCacheHit state"""
query = torch.randn(5, 8, 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillCacheHit
metadata.attn_mask = torch.randn(1, 128, 16, 16)
metadata.query_lens = torch.tensor([5])
metadata.seq_lens = torch.tensor([1, 4])
metadata.query_start_loc = torch.tensor([0, 1, 5])
metadata.actual_seq_lengths_q = [5]
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.num_decode_tokens = 0
metadata.num_decodes = 0
metadata.num_prefills = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_npu_paged_attention_splitfuse.assert_called_once()
@patch("vllm_ascend.attention.attention_v1.using_paged_attention")
@patch("torch_npu._npu_paged_attention")
@patch("torch_npu._npu_reshape_and_cache")
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_paged_attention_310(
self, mock_get_forward_context, mock_npu_reshape_and_cache, mock_paged_attention, mock_using_paged_attention
):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(4, 8 * 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([4])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 4
metadata.slot_mapping = torch.zeros(4, dtype=torch.long)
metadata.num_decodes = 4
metadata.num_prefills = 0
mock_using_paged_attention.return_value = True
mock_get_forward_context.return_value = MagicMock(capturing=False)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_paged_attention.assert_called_once()

    def test_forward_mtp_310(self):
        """SpecDecoding (MTP) is not implemented on 310P and must raise."""
        query = torch.randn(4, 8 * 64)
        key, value = None, None
        output = torch.empty_like(query)
        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.SpecDecoding
        with self.assertRaises(NotImplementedError):
            output = self.impl.forward_impl(query, key, value, None, metadata, output)
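

# ---------------------------------------------------------------------------
# Hypothetical addition (not part of the original file), for running this
# module directly instead of through the project's pytest harness. It assumes
# TestBase ultimately derives from unittest.TestCase so that unittest
# discovery picks up both test classes above.
if __name__ == "__main__":
    import unittest

    unittest.main()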