[Feat][UT] Support Deepseekv32 FULL_DECODE_ONLY mode and add unit test of sfa_v1 (#3763)
### What this PR does / why we need it?
- Add support for DeepSeek v3.2 in FULL_DECODE_ONLY mode.
- Add unit test for sfa_v1.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: 1Fire4 <wangdingyi2@huawei.com>
This commit is contained in:
185
tests/ut/attention/test_sfa_v1.py
Normal file
185
tests/ut/attention/test_sfa_v1.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
|
from vllm_ascend.attention.sfa_v1 import (AscendSFABackend, AscendSFAImpl,
|
||||||
|
AscendSFAMetadata,
|
||||||
|
AscendSFAMetadataBuilder)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendSFABackend(TestBase):
|
||||||
|
|
||||||
|
def test_get_name(self):
|
||||||
|
self.assertEqual(AscendSFABackend.get_name(), "ASCEND_SFA")
|
||||||
|
|
||||||
|
def test_get_metadata_cls(self):
|
||||||
|
self.assertEqual(AscendSFABackend.get_metadata_cls(),
|
||||||
|
AscendSFAMetadata)
|
||||||
|
|
||||||
|
def test_get_builder_cls(self):
|
||||||
|
self.assertEqual(AscendSFABackend.get_builder_cls(),
|
||||||
|
AscendSFAMetadataBuilder)
|
||||||
|
|
||||||
|
def test_get_kv_cache_shape(self):
|
||||||
|
result = AscendSFABackend.get_kv_cache_shape(2, 4, 8, 128)
|
||||||
|
self.assertEqual(result, (2, 4, 8, 128))
|
||||||
|
|
||||||
|
def test_get_impl_cls(self):
|
||||||
|
result = AscendSFABackend.get_impl_cls()
|
||||||
|
self.assertEqual(result, AscendSFAImpl)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendSFAMetadata(TestBase):
|
||||||
|
|
||||||
|
def test_ascend_sfa_metadata_default(self):
|
||||||
|
has_prefill = True
|
||||||
|
num_actual_tokens = 100
|
||||||
|
slot_mapping = torch.randn(100, 4, 1024)
|
||||||
|
seq_lens = torch.tensor([30, 50])
|
||||||
|
cum_query_lens = torch.tensor([0, 30, 80])
|
||||||
|
block_tables = torch.randint(0, 100, (100, 4))
|
||||||
|
|
||||||
|
rope_dim = 32
|
||||||
|
max_seq_len = int(seq_lens.max().item())
|
||||||
|
sin = torch.randn(max_seq_len, rope_dim)
|
||||||
|
cos = torch.randn(max_seq_len, rope_dim)
|
||||||
|
|
||||||
|
num_input_tokens = 2
|
||||||
|
head_dim = None
|
||||||
|
attn_mask = None
|
||||||
|
attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
|
|
||||||
|
metadata = AscendSFAMetadata(
|
||||||
|
has_prefill=has_prefill,
|
||||||
|
num_actual_tokens=num_actual_tokens,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
cum_query_lens=cum_query_lens,
|
||||||
|
block_tables=block_tables,
|
||||||
|
sin=sin,
|
||||||
|
cos=cos,
|
||||||
|
num_input_tokens=num_input_tokens,
|
||||||
|
head_dim=head_dim,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
attn_state=attn_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(metadata.has_prefill, has_prefill)
|
||||||
|
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
|
||||||
|
self.assertIs(metadata.slot_mapping, slot_mapping)
|
||||||
|
self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
|
||||||
|
self.assertTrue(torch.equal(metadata.cum_query_lens, cum_query_lens))
|
||||||
|
self.assertIs(metadata.block_tables, block_tables)
|
||||||
|
self.assertIs(metadata.sin, sin)
|
||||||
|
self.assertIs(metadata.cos, cos)
|
||||||
|
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
|
||||||
|
self.assertIs(metadata.head_dim, head_dim)
|
||||||
|
self.assertIs(metadata.attn_mask, attn_mask)
|
||||||
|
self.assertEqual(metadata.attn_state, attn_state)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendSFAMetadataBuilder(TestBase):
|
||||||
|
|
||||||
|
def test_ascend_sfa_metadata_builder_default(self):
|
||||||
|
kv_cache_spec = MagicMock()
|
||||||
|
layer_names = ["layer1", "layer2"]
|
||||||
|
vllm_config = MagicMock()
|
||||||
|
speculative_config = MagicMock()
|
||||||
|
speculative_config.num_speculative_tokens = 4
|
||||||
|
vllm_config.speculative_config = speculative_config
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
|
||||||
|
layer_names=layer_names,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
assert builder.aclgraph_support == AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
|
assert builder.device == device
|
||||||
|
assert builder.vllm_config == vllm_config
|
||||||
|
|
||||||
|
def test_ascend_sfa_metadata_builder_build(self):
|
||||||
|
kv_cache_spec = MagicMock()
|
||||||
|
layer_names = ["layer1", "layer2"]
|
||||||
|
vllm_config = MagicMock()
|
||||||
|
speculative_config = MagicMock()
|
||||||
|
speculative_config.num_speculative_tokens = 4
|
||||||
|
vllm_config.speculative_config = speculative_config
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
|
||||||
|
layer_names=layer_names,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
common_attn_metadata = MagicMock()
|
||||||
|
common_attn_metadata.num_reqs = 10
|
||||||
|
common_attn_metadata.num_actual_tokens = 100
|
||||||
|
common_attn_metadata.query_start_loc = torch.tensor(
|
||||||
|
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
|
||||||
|
common_attn_metadata.query_start_loc_cpu = torch.tensor(
|
||||||
|
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
|
||||||
|
common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
|
||||||
|
common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
|
||||||
|
common_attn_metadata.positions = torch.randn(100)
|
||||||
|
common_attn_metadata.attn_mask = None
|
||||||
|
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
|
common_attn_metadata.block_table_tensor = torch.randn(100, 4)
|
||||||
|
|
||||||
|
model = MagicMock()
|
||||||
|
model.model.layers = [MagicMock() for _ in range(10)]
|
||||||
|
model.model.start_layer = 0
|
||||||
|
|
||||||
|
metadata = builder.build(
|
||||||
|
common_prefix_len=10,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(metadata, AscendSFAMetadata)
|
||||||
|
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
|
||||||
|
assert metadata.slot_mapping.shape == (100, 4, 1024)
|
||||||
|
|
||||||
|
def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
|
||||||
|
kv_cache_spec = MagicMock()
|
||||||
|
layer_names = ["layer1", "layer2"]
|
||||||
|
vllm_config = MagicMock()
|
||||||
|
speculative_config = MagicMock()
|
||||||
|
speculative_config.num_speculative_tokens = 4
|
||||||
|
vllm_config.speculative_config = speculative_config
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
|
||||||
|
layer_names=layer_names,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
common_attn_metadata = MagicMock()
|
||||||
|
common_attn_metadata.num_reqs = 10
|
||||||
|
common_attn_metadata.num_actual_tokens = 100
|
||||||
|
common_attn_metadata.query_start_loc = torch.tensor(
|
||||||
|
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
|
||||||
|
common_attn_metadata.query_start_loc_cpu = torch.tensor(
|
||||||
|
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
|
||||||
|
common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
|
||||||
|
common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
|
||||||
|
common_attn_metadata.positions = torch.randn(100)
|
||||||
|
common_attn_metadata.attn_mask = None
|
||||||
|
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
|
common_attn_metadata.block_table_tensor = torch.randn(100, 4)
|
||||||
|
|
||||||
|
model = MagicMock()
|
||||||
|
model.model.layers = [MagicMock() for _ in range(10)]
|
||||||
|
model.model.start_layer = 0
|
||||||
|
|
||||||
|
attn_metadata = builder.build_for_graph_capture(
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
attn_state=AscendAttentionState.DecodeOnly,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(attn_metadata, AscendSFAMetadata)
|
||||||
|
assert attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||||
@@ -91,7 +91,7 @@ M = TypeVar("M", bound=AscendSFAMetadata)
|
|||||||
class AscendSFAMetadataBuilder:
|
class AscendSFAMetadataBuilder:
|
||||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||||
AttentionCGSupport.NEVER
|
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
"""
|
"""
|
||||||
NOTE: Please read the comment at the top of the file before trying to
|
NOTE: Please read the comment at the top of the file before trying to
|
||||||
understand this class
|
understand this class
|
||||||
@@ -189,6 +189,26 @@ class AscendSFAMetadataBuilder:
|
|||||||
sin=sin,
|
sin=sin,
|
||||||
cos=cos)
|
cos=cos)
|
||||||
|
|
||||||
|
def build_for_graph_capture(
|
||||||
|
self,
|
||||||
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||||
|
model: Optional[nn.Module] = None,
|
||||||
|
):
|
||||||
|
if attn_state == AscendAttentionState.DecodeOnly:
|
||||||
|
attn_metadata = self.build(
|
||||||
|
common_prefix_len=0,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Currently we only support building dummy metadata for DecodeOnly state"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_metadata.attn_state = attn_state
|
||||||
|
return attn_metadata
|
||||||
|
|
||||||
|
|
||||||
class AscendSFAImpl(MLAAttentionImpl):
|
class AscendSFAImpl(MLAAttentionImpl):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1894,7 +1894,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
|
||||||
|
and not self.use_sparse:
|
||||||
# TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
|
# TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
|
||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla:
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
@@ -2687,11 +2688,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
[0] * dcp_world_size for _ in range(pcp_world_size)
|
[0] * dcp_world_size for _ in range(pcp_world_size)
|
||||||
] for _ in range(num_tokens)]
|
] for _ in range(num_tokens)]
|
||||||
long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
|
long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
if self.speculative_config:
|
||||||
query_start_loc=torch.tensor(
|
query_start_loc = torch.tensor(
|
||||||
[0] + self.actual_seq_lengths_q[:num_reqs],
|
[0] + self.actual_seq_lengths_q[:num_reqs],
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.int32),
|
dtype=torch.int32)
|
||||||
|
else:
|
||||||
|
query_start_loc = self.query_start_loc[:num_reqs + 1]
|
||||||
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
|
query_start_loc=query_start_loc,
|
||||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
||||||
1],
|
1],
|
||||||
seq_lens_cpu=self.seq_lens_cpu,
|
seq_lens_cpu=self.seq_lens_cpu,
|
||||||
@@ -2737,7 +2742,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
assert forward_context is not None
|
assert forward_context is not None
|
||||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
||||||
not forward_context.capturing:
|
not forward_context.capturing and not self.use_sparse:
|
||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla:
|
||||||
# FIXME: Try using `auto_dispatch_capture=True`
|
# FIXME: Try using `auto_dispatch_capture=True`
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user