Files
xc-llm-ascend/tests/ut/attention/test_sfa_v1.py
Qiu cb7c419bc0 [Feat](sfa,dcp) support dcp for sfa (#6563)
### What this PR does / why we need it?
This PR adds DCP support to the SFA backend.

Please note that due to operator constraints, the current implementation
has to all-gather the entire KV cache and modify the block table to
satisfy the operator input requirements. This results in significantly
increased communication overhead and peak memory usage. Therefore, this
is only a temporary workaround and will be refactored once the operator
provides proper support.

Additionally, because of the above limitations,
`cp_kv_cache_interleave_size` is currently required to be equal to
`block_size`. This restriction will also be removed after the refactor.

#### Test
Accuracy test using DeepSeek-V3.2-Exp-W8A8 with dp2tp8dcp8 (DP=2, TP=8, DCP=8):

| dataset | version | metric | mode | vllm-api-general-stream |
|----- | ----- | ----- | ----- | -----|
| gsm8kdataset | - | accuracy | gen | 96.35 |

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
2026-02-09 18:52:25 +08:00

282 lines
12 KiB
Python

import sys
from unittest.mock import MagicMock, patch
import torch
from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm.distributed.parallel_state import GroupCoordinator
if 'torch_npu._inductor' not in sys.modules:
sys.modules['torch_npu._inductor'] = MagicMock()
from vllm_ascend.attention.sfa_v1 import (AscendSFABackend, AscendSFAImpl,
AscendSFAMetadata,
AscendSFAMetadataBuilder)
from vllm_ascend.utils import enable_dsa_cp, vllm_version_is
class TestAscendSFABackend(TestBase):
    """Checks the class-level accessors exposed by ``AscendSFABackend``."""

    def test_get_name(self):
        # The backend must identify itself with this exact name string.
        backend_name = AscendSFABackend.get_name()
        self.assertEqual(backend_name, "ASCEND_SFA")

    def test_get_builder_cls(self):
        # The metadata builder paired with this backend.
        builder_cls = AscendSFABackend.get_builder_cls()
        self.assertEqual(builder_cls, AscendSFAMetadataBuilder)

    def test_get_kv_cache_shape(self):
        # The four positional arguments are expected back, unchanged, as the
        # KV-cache shape tuple.
        expected_shape = (2, 4, 8, 128)
        actual_shape = AscendSFABackend.get_kv_cache_shape(*expected_shape)
        self.assertEqual(actual_shape, expected_shape)

    def test_get_impl_cls(self):
        # The attention implementation paired with this backend.
        impl_cls = AscendSFABackend.get_impl_cls()
        self.assertEqual(impl_cls, AscendSFAImpl)
class TestAscendSFAMetadata(TestBase):
    """Checks that ``AscendSFAMetadata`` stores its constructor arguments."""

    def test_ascend_sfa_metadata_default(self):
        # sin/cos tables are sized by the longest sequence in the batch.
        seq_lens = torch.tensor([30, 50])
        longest_seq = int(seq_lens.max().item())
        rope_dim = 32

        # Every keyword handed to the constructor, kept in one mapping so the
        # assertions below can compare each field against what was passed in.
        fields = dict(
            num_actual_tokens=100,
            slot_mapping=torch.randn(100, 4, 1024),
            seq_lens=seq_lens,
            cum_query_lens=torch.tensor([0, 30, 80]),
            block_table=torch.randint(0, 100, (100, 4)),
            sin=torch.randn(longest_seq, rope_dim),
            cos=torch.randn(longest_seq, rope_dim),
            num_input_tokens=2,
            head_dim=None,
            attn_mask=None,
            attn_state=AscendAttentionState.ChunkedPrefill,
        )
        metadata = AscendSFAMetadata(**fields)

        # Plain values: compared by equality.
        for name in ("num_actual_tokens", "num_input_tokens", "attn_state"):
            self.assertEqual(getattr(metadata, name), fields[name])

        # These must be the very same objects that were passed in.
        for name in ("slot_mapping", "block_table", "sin", "cos", "head_dim",
                     "attn_mask"):
            self.assertIs(getattr(metadata, name), fields[name])

        # Sequence-length tensors: compared element-wise.
        for name in ("seq_lens", "cum_query_lens"):
            self.assertTrue(torch.equal(getattr(metadata, name), fields[name]))
class TestAscendSFAMetadataBuilder(TestBase):
    """Tests for constructing ``AscendSFAMetadataBuilder`` and calling its
    ``build`` / ``build_for_graph_capture`` methods against mocked configs.

    ``setUp`` patches ``vllm.config.get_current_vllm_config`` and replaces
    ``MLACommonMetadataBuilder.__init__`` with a lightweight stand-in. Both
    patchers are undone through ``addCleanup`` so they are reverted even when
    ``setUp`` itself fails part-way (``tearDown`` is not run in that case).
    """

    @patch('vllm.distributed.parallel_state._TP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def setUp(self, mock_tp):
        # NOTE: this decorator-based patch of ``_TP`` is only active for the
        # duration of ``setUp``; it is reverted before the test body runs.
        mock_tp.world_size = 2
        mock_tp.rank_in_group = MagicMock()
        mock_tp.device_group = MagicMock()

        self.mock_cfg = MagicMock()
        self.mock_cfg.parallel_config = MagicMock()
        self.mock_cfg.parallel_config.tensor_parallel_size = 1
        self.mock_cfg.parallel_config.prefill_context_parallel_size = 1
        self.mock_cfg.parallel_config.decode_context_parallel_size = 1
        self.mock_cfg.compilation_config = MagicMock()
        self.mock_cfg.compilation_config.pass_config = MagicMock()
        self.mock_cfg.compilation_config.pass_config.enable_sp = False
        self.mock_cfg.speculative_config.num_speculative_tokens = 0

        self.patcher = patch("vllm.config.get_current_vllm_config",
                             return_value=self.mock_cfg)
        self.patcher.start()
        # Register the stop immediately: if anything below raises, unittest
        # skips tearDown but still runs cleanups, so the patch cannot leak.
        self.addCleanup(self.patcher.stop)

        # Mock parent class __init__ to avoid complex initialization,
        # but still set the essential attributes that child class needs
        def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
                             device, metadata_cls, supports_dcp_with_varlen):
            self.metadata_cls = metadata_cls
            self.kv_cache_spec = kv_cache_spec
            self.model_config = vllm_config.model_config
            self.vllm_config = vllm_config
            self.device = device
            self.chunked_prefill_workspace_size = 128 * 1024
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 vllm_config.model_config.get_head_size()),
                dtype=vllm_config.model_config.dtype,
                device=device,
            )

        self.parent_init_patcher = patch(
            "vllm.model_executor.layers.attention.mla_attention.MLACommonMetadataBuilder.__init__",
            mock_parent_init)
        self.parent_init_patcher.start()
        self.addCleanup(self.parent_init_patcher.stop)

        # Drop any memoised result so each test re-evaluates enable_dsa_cp
        # (only when the function actually exposes a cache).
        if hasattr(enable_dsa_cp, "cache_clear"):
            enable_dsa_cp.cache_clear()

    def tearDown(self):
        # Both patchers are stopped by the cleanups registered in setUp.
        # The override is kept, and still does not chain to the base class,
        # so the rest of the teardown sequence is unchanged.
        pass

    @staticmethod
    def _make_vllm_config():
        """Return the mocked vLLM config shared by every test in this class."""
        vllm_config = MagicMock()
        vllm_config.cache_config.block_size = 16
        vllm_config.model_config.max_model_len = 1024
        vllm_config.model_config.get_head_size.return_value = 64
        vllm_config.model_config.dtype = torch.float16
        vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        speculative_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
        return vllm_config

    @staticmethod
    def _make_builder(vllm_config, device):
        """Construct the builder under test with a mocked KV-cache spec."""
        return AscendSFAMetadataBuilder(kv_cache_spec=MagicMock(),
                                        layer_names=["layer1", "layer2"],
                                        vllm_config=vllm_config,
                                        device=device)

    @staticmethod
    def _stub_current_vllm_config(mock_get_current_vllm_config):
        """Make the patched ``sfa_v1.get_current_vllm_config`` return a mock."""
        cfg = MagicMock()
        cfg.model_config = MagicMock()
        cfg.model_config.hf_text_config = MagicMock()
        mock_get_current_vllm_config.return_value = cfg

    @staticmethod
    def _make_common_attn_metadata():
        """Return the mocked common attention metadata fed to ``build*``."""
        common_attn_metadata = MagicMock()
        common_attn_metadata.num_reqs = 10
        common_attn_metadata.num_actual_tokens = 100
        # NOTE(review): 10 entries for num_reqs == 10; a cumulative
        # start-location tensor would normally carry num_reqs + 1 entries.
        # Values are kept as-is - confirm whether this is intentional.
        common_attn_metadata.query_start_loc = torch.tensor(
            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
        common_attn_metadata.query_start_loc_cpu = torch.tensor(
            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
        common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
        common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
        common_attn_metadata.positions = torch.randn(100)
        common_attn_metadata.attn_mask = None
        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
        common_attn_metadata.block_table_tensor = torch.randn(100, 4)
        common_attn_metadata.cos = None
        common_attn_metadata.sin = None
        common_attn_metadata.num_input_tokens = 100
        return common_attn_metadata

    # patch_distributed_groups comes from tests.ut.attention.utils; presumably
    # it installs mocked DCP/PCP groups of the given sizes - confirm there.
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_ascend_sfa_metadata_builder_default(self):
        # The builder must keep the device and config it was constructed with.
        vllm_config = self._make_vllm_config()
        device = torch.device("cpu")

        builder = self._make_builder(vllm_config, device)

        self.assertEqual(builder.device, device)
        self.assertEqual(builder.vllm_config, vllm_config)

    @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
    @patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_ascend_sfa_metadata_builder_build(
        self,
        mock_enable_dsa_cp,
        mock_get_cos_and_sin_mla,
        mock_get_current_vllm_config,
    ):
        # build() with enable_dsa_cp forced off must return an
        # AscendSFAMetadata carrying the token count and slot mapping it was
        # given.
        mock_enable_dsa_cp.return_value = False
        self._stub_current_vllm_config(mock_get_current_vllm_config)

        builder = self._make_builder(self._make_vllm_config(),
                                     torch.device("cpu"))
        common_attn_metadata = self._make_common_attn_metadata()
        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
                                                 torch.randn(100))

        metadata = builder.build(
            common_prefix_len=10,
            common_attn_metadata=common_attn_metadata,
        )

        self.assertIsInstance(metadata, AscendSFAMetadata)
        self.assertEqual(metadata.num_actual_tokens,
                         common_attn_metadata.num_actual_tokens)
        self.assertEqual(tuple(metadata.slot_mapping.shape), (100, 4, 1024))

    @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_ascend_sfa_metadata_builder_build_for_graph_capture(
            self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
        # build_for_graph_capture() must return an AscendSFAMetadata whose
        # attn_state is the one requested by the caller.
        self._stub_current_vllm_config(mock_get_current_vllm_config)

        builder = self._make_builder(self._make_vllm_config(),
                                     torch.device("cpu"))
        common_attn_metadata = self._make_common_attn_metadata()
        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
                                                 torch.randn(100))

        attn_metadata = builder.build_for_graph_capture(
            common_attn_metadata=common_attn_metadata,
            attn_state=AscendAttentionState.DecodeOnly,
        )

        self.assertIsInstance(attn_metadata, AscendSFAMetadata)
        self.assertEqual(attn_metadata.attn_state,
                         AscendAttentionState.DecodeOnly)