Files
xc-llm-ascend/tests/ut/attention/test_sfa_v1.py
zhangxinyuehfad f7b904641e [Main2Main] Upgrade vllm commit to 0109 (#5752)
### What this PR does / why we need it?
Upgrade vllm commit to 0109 (bde38c11df0ea066a740efe9b77fff5418be45df)

1. remove `init_cached_hf_modules` due to
https://github.com/vllm-project/vllm/pull/31786
2. fix spec_decode e2e test broken by
https://github.com/vllm-project/vllm/pull/29821
3. fix `vllm.v1.attention.backends.utils` due to
https://github.com/vllm-project/vllm/pull/31891
4. fix `self.seq_lens - query_lens` on same device due to
https://github.com/vllm-project/vllm/pull/31773
5. skip model_runner_v2 e2e test due to `'_OpNamespace' '_C' object has
no attribute 'get_cuda_view_from_cpu_tensor'`

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
2026-01-13 19:14:43 +08:00

233 lines
9.2 KiB
Python

import sys
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
# Register a MagicMock stand-in for torch_npu._inductor unless a real module
# is already present, so the sfa_v1 import below can resolve it.
sys.modules.setdefault('torch_npu._inductor', MagicMock())
from vllm_ascend.attention.sfa_v1 import (AscendSFABackend, AscendSFAImpl,
AscendSFAMetadata,
AscendSFAMetadataBuilder)
from vllm_ascend.utils import enable_dsa_cp
class TestAscendSFABackend(TestBase):
    """Checks the static accessors exposed by AscendSFABackend."""

    def test_get_name(self):
        backend_name = AscendSFABackend.get_name()
        self.assertEqual(backend_name, "ASCEND_SFA")

    def test_get_builder_cls(self):
        builder_cls = AscendSFABackend.get_builder_cls()
        self.assertEqual(builder_cls, AscendSFAMetadataBuilder)

    def test_get_kv_cache_shape(self):
        # The expected shape is exactly the dimensions passed in, in order.
        dims = (2, 4, 8, 128)
        self.assertEqual(AscendSFABackend.get_kv_cache_shape(*dims), dims)

    def test_get_impl_cls(self):
        impl_cls = AscendSFABackend.get_impl_cls()
        self.assertEqual(impl_cls, AscendSFAImpl)
class TestAscendSFAMetadata(TestBase):
    """Checks that AscendSFAMetadata keeps the values it is constructed with."""

    def test_ascend_sfa_metadata_default(self):
        seq_lens = torch.tensor([30, 50])
        rope_dim = 32
        max_seq_len = int(seq_lens.max().item())
        # Constructor arguments, kept in one mapping so each stored attribute
        # can be compared against the value that was passed in.
        fields = dict(
            has_prefill=True,
            num_actual_tokens=100,
            slot_mapping=torch.randn(100, 4, 1024),
            seq_lens=seq_lens,
            cum_query_lens=torch.tensor([0, 30, 80]),
            block_tables=torch.randint(0, 100, (100, 4)),
            sin=torch.randn(max_seq_len, rope_dim),
            cos=torch.randn(max_seq_len, rope_dim),
            num_input_tokens=2,
            head_dim=None,
            attn_mask=None,
            attn_state=AscendAttentionState.ChunkedPrefill,
        )

        metadata = AscendSFAMetadata(**fields)

        # Plain values and enums compare by equality.
        for name in ("has_prefill", "num_actual_tokens", "num_input_tokens",
                     "attn_state"):
            self.assertEqual(getattr(metadata, name), fields[name])
        # These must be stored as the very same objects (no copies).
        for name in ("slot_mapping", "block_tables", "sin", "cos",
                     "head_dim", "attn_mask"):
            self.assertIs(getattr(metadata, name), fields[name])
        # Length tensors only need element-wise equality.
        for name in ("seq_lens", "cum_query_lens"):
            self.assertTrue(torch.equal(getattr(metadata, name),
                                        fields[name]))
class TestAscendSFAMetadataBuilder(TestBase):
    """Tests construction and the build paths of AscendSFAMetadataBuilder."""

    def setUp(self):
        super().setUp()
        self.mock_cfg = MagicMock()
        self.mock_cfg.parallel_config = MagicMock()
        self.mock_cfg.parallel_config.tensor_parallel_size = 1
        self.mock_cfg.parallel_config.prefill_context_parallel_size = 1
        self.mock_cfg.parallel_config.decode_context_parallel_size = 1
        self.mock_cfg.compilation_config = MagicMock()
        self.mock_cfg.compilation_config.pass_config = MagicMock()
        self.mock_cfg.compilation_config.pass_config.enable_sp = False
        self.mock_cfg.speculative_config.num_speculative_tokens = 0
        self.patcher = patch("vllm.config.get_current_vllm_config",
                             return_value=self.mock_cfg)
        self.patcher.start()
        # Undo the patch after every test. Without this, each setUp stacks
        # another active patch and the mock config leaks into other tests.
        self.addCleanup(self.patcher.stop)
        if hasattr(enable_dsa_cp, "cache_clear"):
            # enable_dsa_cp is presumably a functools cache wrapper. Clear it
            # so it is recomputed against mock_cfg, and clear it again on
            # cleanup so a value derived from the mock does not outlive the
            # test.
            enable_dsa_cp.cache_clear()
            self.addCleanup(enable_dsa_cp.cache_clear)

    @staticmethod
    def _make_vllm_config():
        """Return a mocked vLLM config with 4 speculative tokens."""
        vllm_config = MagicMock()
        speculative_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
        return vllm_config

    @staticmethod
    def _make_builder(vllm_config, device):
        """Construct a builder for two layers with a mocked KV-cache spec."""
        return AscendSFAMetadataBuilder(kv_cache_spec=MagicMock(),
                                        layer_names=["layer1", "layer2"],
                                        vllm_config=vllm_config,
                                        device=device)

    @staticmethod
    def _set_current_vllm_config(mock_get_current_vllm_config):
        """Make the patched get_current_vllm_config return a mock model."""
        cfg = MagicMock()
        cfg.model_config = MagicMock()
        cfg.model_config.hf_text_config = MagicMock()
        mock_get_current_vllm_config.return_value = cfg

    @staticmethod
    def _make_common_attn_metadata():
        """Return mocked common attention metadata for 10 reqs x 10 tokens."""
        num_reqs = 10
        tokens_per_req = 10
        num_tokens = num_reqs * tokens_per_req
        md = MagicMock()
        md.num_reqs = num_reqs
        md.num_actual_tokens = num_tokens
        # query_start_loc holds num_reqs + 1 cumulative offsets whose last
        # entry equals the total token count: [0, 10, ..., 100].
        query_start_loc = torch.arange(0, num_tokens + 1, tokens_per_req)
        md.query_start_loc = query_start_loc
        md.query_start_loc_cpu = query_start_loc.clone()
        md.slot_mapping = torch.randn(num_tokens, 4, 1024)
        # A request's sequence length can never be shorter than its query.
        md.seq_lens_cpu = torch.full((num_reqs, ), tokens_per_req)
        md.positions = torch.randn(num_tokens)
        md.attn_mask = None
        md.attn_state = AscendAttentionState.ChunkedPrefill
        md.block_table_tensor = torch.randn(num_tokens, 4)
        md.cos = None
        md.sin = None
        md.num_input_tokens = num_tokens
        return md

    def test_ascend_sfa_metadata_builder_default(self):
        vllm_config = self._make_vllm_config()
        device = torch.device("cpu")

        builder = self._make_builder(vllm_config, device)

        self.assertEqual(builder.device, device)
        self.assertIs(builder.vllm_config, vllm_config)

    @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
    @patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
    def test_ascend_sfa_metadata_builder_build(
        self,
        mock_enable_dsa_cp,
        mock_get_cos_and_sin_mla,
        mock_get_current_vllm_config,
    ):
        mock_enable_dsa_cp.return_value = False
        self._set_current_vllm_config(mock_get_current_vllm_config)
        builder = self._make_builder(self._make_vllm_config(),
                                     torch.device("cpu"))
        common_attn_metadata = self._make_common_attn_metadata()
        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
                                                 torch.randn(100))

        metadata = builder.build(
            common_prefix_len=10,
            common_attn_metadata=common_attn_metadata,
        )

        self.assertIsInstance(metadata, AscendSFAMetadata)
        self.assertEqual(metadata.num_actual_tokens,
                         common_attn_metadata.num_actual_tokens)
        self.assertEqual(metadata.slot_mapping.shape, (100, 4, 1024))

    @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
    def test_ascend_sfa_metadata_builder_build_for_graph_capture(
            self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
        self._set_current_vllm_config(mock_get_current_vllm_config)
        builder = self._make_builder(self._make_vllm_config(),
                                     torch.device("cpu"))
        common_attn_metadata = self._make_common_attn_metadata()
        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
                                                 torch.randn(100))

        attn_metadata = builder.build_for_graph_capture(
            common_attn_metadata=common_attn_metadata,
            attn_state=AscendAttentionState.DecodeOnly,
        )

        self.assertIsInstance(attn_metadata, AscendSFAMetadata)
        self.assertEqual(attn_metadata.attn_state,
                         AscendAttentionState.DecodeOnly)