[Feat][UT] Support Deepseekv32 FULL_DECODE_ONLY mode and add unit test of sfa_v1 (#3763)

### What this PR does / why we need it?
- Add support for DeepSeek v3.2 in FULL_DECODE_ONLY mode.
- Add unit test for sfa_v1.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: 1Fire4 <wangdingyi2@huawei.com>
This commit is contained in:
1Fire4
2025-11-03 10:02:47 +08:00
committed by GitHub
parent d4c75088a0
commit 0b9b6d79fe
3 changed files with 216 additions and 6 deletions

View File

@@ -0,0 +1,185 @@
from unittest.mock import MagicMock
import torch
from vllm.v1.attention.backends.utils import AttentionCGSupport
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.sfa_v1 import (AscendSFABackend, AscendSFAImpl,
AscendSFAMetadata,
AscendSFAMetadataBuilder)
class TestAscendSFABackend(TestBase):
def test_get_name(self):
self.assertEqual(AscendSFABackend.get_name(), "ASCEND_SFA")
def test_get_metadata_cls(self):
self.assertEqual(AscendSFABackend.get_metadata_cls(),
AscendSFAMetadata)
def test_get_builder_cls(self):
self.assertEqual(AscendSFABackend.get_builder_cls(),
AscendSFAMetadataBuilder)
def test_get_kv_cache_shape(self):
result = AscendSFABackend.get_kv_cache_shape(2, 4, 8, 128)
self.assertEqual(result, (2, 4, 8, 128))
def test_get_impl_cls(self):
result = AscendSFABackend.get_impl_cls()
self.assertEqual(result, AscendSFAImpl)
class TestAscendSFAMetadata(TestBase):
def test_ascend_sfa_metadata_default(self):
has_prefill = True
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
seq_lens = torch.tensor([30, 50])
cum_query_lens = torch.tensor([0, 30, 80])
block_tables = torch.randint(0, 100, (100, 4))
rope_dim = 32
max_seq_len = int(seq_lens.max().item())
sin = torch.randn(max_seq_len, rope_dim)
cos = torch.randn(max_seq_len, rope_dim)
num_input_tokens = 2
head_dim = None
attn_mask = None
attn_state = AscendAttentionState.ChunkedPrefill
metadata = AscendSFAMetadata(
has_prefill=has_prefill,
num_actual_tokens=num_actual_tokens,
slot_mapping=slot_mapping,
seq_lens=seq_lens,
cum_query_lens=cum_query_lens,
block_tables=block_tables,
sin=sin,
cos=cos,
num_input_tokens=num_input_tokens,
head_dim=head_dim,
attn_mask=attn_mask,
attn_state=attn_state,
)
self.assertEqual(metadata.has_prefill, has_prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
self.assertTrue(torch.equal(metadata.cum_query_lens, cum_query_lens))
self.assertIs(metadata.block_tables, block_tables)
self.assertIs(metadata.sin, sin)
self.assertIs(metadata.cos, cos)
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
self.assertIs(metadata.head_dim, head_dim)
self.assertIs(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.attn_state, attn_state)
class TestAscendSFAMetadataBuilder(TestBase):
def test_ascend_sfa_metadata_builder_default(self):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
device = torch.device("cpu")
builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
layer_names=layer_names,
vllm_config=vllm_config,
device=device)
assert builder.aclgraph_support == AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
assert builder.device == device
assert builder.vllm_config == vllm_config
def test_ascend_sfa_metadata_builder_build(self):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
device = torch.device("cpu")
builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
layer_names=layer_names,
vllm_config=vllm_config,
device=device)
common_attn_metadata = MagicMock()
common_attn_metadata.num_reqs = 10
common_attn_metadata.num_actual_tokens = 100
common_attn_metadata.query_start_loc = torch.tensor(
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
common_attn_metadata.query_start_loc_cpu = torch.tensor(
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
common_attn_metadata.positions = torch.randn(100)
common_attn_metadata.attn_mask = None
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
common_attn_metadata.block_table_tensor = torch.randn(100, 4)
model = MagicMock()
model.model.layers = [MagicMock() for _ in range(10)]
model.model.start_layer = 0
metadata = builder.build(
common_prefix_len=10,
common_attn_metadata=common_attn_metadata,
model=model,
)
assert isinstance(metadata, AscendSFAMetadata)
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
assert metadata.slot_mapping.shape == (100, 4, 1024)
def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
device = torch.device("cpu")
builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
layer_names=layer_names,
vllm_config=vllm_config,
device=device)
common_attn_metadata = MagicMock()
common_attn_metadata.num_reqs = 10
common_attn_metadata.num_actual_tokens = 100
common_attn_metadata.query_start_loc = torch.tensor(
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
common_attn_metadata.query_start_loc_cpu = torch.tensor(
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
common_attn_metadata.positions = torch.randn(100)
common_attn_metadata.attn_mask = None
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
common_attn_metadata.block_table_tensor = torch.randn(100, 4)
model = MagicMock()
model.model.layers = [MagicMock() for _ in range(10)]
model.model.start_layer = 0
attn_metadata = builder.build_for_graph_capture(
common_attn_metadata=common_attn_metadata,
attn_state=AscendAttentionState.DecodeOnly,
model=model,
)
assert isinstance(attn_metadata, AscendSFAMetadata)
assert attn_metadata.attn_state == AscendAttentionState.DecodeOnly

View File

@@ -91,7 +91,7 @@ M = TypeVar("M", bound=AscendSFAMetadata)
class AscendSFAMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@@ -189,6 +189,26 @@ class AscendSFAMetadataBuilder:
sin=sin,
cos=cos)
def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
model: Optional[nn.Module] = None,
):
if attn_state == AscendAttentionState.DecodeOnly:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
model=model,
)
else:
raise NotImplementedError(
"Currently we only support building dummy metadata for DecodeOnly state"
)
attn_metadata.attn_state = attn_state
return attn_metadata
class AscendSFAImpl(MLAAttentionImpl):
"""

View File

@@ -1894,7 +1894,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
and not self.use_sparse:
# TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
if self.vllm_config.model_config.use_mla:
if self.pcp_size * self.dcp_size > 1:
@@ -2687,11 +2688,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
[0] * dcp_world_size for _ in range(pcp_world_size)
] for _ in range(num_tokens)]
long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor(
if self.speculative_config:
query_start_loc = torch.tensor(
[0] + self.actual_seq_lengths_q[:num_reqs],
device=self.device,
dtype=torch.int32),
dtype=torch.int32)
else:
query_start_loc = self.query_start_loc[:num_reqs + 1]
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
1],
seq_lens_cpu=self.seq_lens_cpu,
@@ -2737,7 +2742,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
forward_context = get_forward_context()
assert forward_context is not None
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
not forward_context.capturing:
not forward_context.capturing and not self.use_sparse:
if self.vllm_config.model_config.use_mla:
# FIXME: Try using `auto_dispatch_capture=True`
if self.pcp_size * self.dcp_size > 1: