diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py
new file mode 100644
index 00000000..88f39701
--- /dev/null
+++ b/tests/ut/attention/test_sfa_v1.py
@@ -0,0 +1,185 @@
+from unittest.mock import MagicMock
+
+import torch
+from vllm.v1.attention.backends.utils import AttentionCGSupport
+
+from tests.ut.base import TestBase
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.sfa_v1 import (AscendSFABackend, AscendSFAImpl,
+                                          AscendSFAMetadata,
+                                          AscendSFAMetadataBuilder)
+
+
+class TestAscendSFABackend(TestBase):
+
+    def test_get_name(self):
+        self.assertEqual(AscendSFABackend.get_name(), "ASCEND_SFA")
+
+    def test_get_metadata_cls(self):
+        self.assertEqual(AscendSFABackend.get_metadata_cls(),
+                         AscendSFAMetadata)
+
+    def test_get_builder_cls(self):
+        self.assertEqual(AscendSFABackend.get_builder_cls(),
+                         AscendSFAMetadataBuilder)
+
+    def test_get_kv_cache_shape(self):
+        result = AscendSFABackend.get_kv_cache_shape(2, 4, 8, 128)
+        self.assertEqual(result, (2, 4, 8, 128))
+
+    def test_get_impl_cls(self):
+        result = AscendSFABackend.get_impl_cls()
+        self.assertEqual(result, AscendSFAImpl)
+
+
+class TestAscendSFAMetadata(TestBase):
+
+    def test_ascend_sfa_metadata_default(self):
+        has_prefill = True
+        num_actual_tokens = 100
+        slot_mapping = torch.randn(100, 4, 1024)
+        seq_lens = torch.tensor([30, 50])
+        cum_query_lens = torch.tensor([0, 30, 80])
+        block_tables = torch.randint(0, 100, (100, 4))
+
+        rope_dim = 32
+        max_seq_len = int(seq_lens.max().item())
+        sin = torch.randn(max_seq_len, rope_dim)
+        cos = torch.randn(max_seq_len, rope_dim)
+
+        num_input_tokens = 2
+        head_dim = None
+        attn_mask = None
+        attn_state = AscendAttentionState.ChunkedPrefill
+
+        metadata = AscendSFAMetadata(
+            has_prefill=has_prefill,
+            num_actual_tokens=num_actual_tokens,
+            slot_mapping=slot_mapping,
+            seq_lens=seq_lens,
+            cum_query_lens=cum_query_lens,
+            block_tables=block_tables,
+            sin=sin,
+            cos=cos,
+            num_input_tokens=num_input_tokens,
+            head_dim=head_dim,
+            attn_mask=attn_mask,
+            attn_state=attn_state,
+        )
+
+        self.assertEqual(metadata.has_prefill, has_prefill)
+        self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
+        self.assertIs(metadata.slot_mapping, slot_mapping)
+        self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
+        self.assertTrue(torch.equal(metadata.cum_query_lens, cum_query_lens))
+        self.assertIs(metadata.block_tables, block_tables)
+        self.assertIs(metadata.sin, sin)
+        self.assertIs(metadata.cos, cos)
+        self.assertEqual(metadata.num_input_tokens, num_input_tokens)
+        self.assertIs(metadata.head_dim, head_dim)
+        self.assertIs(metadata.attn_mask, attn_mask)
+        self.assertEqual(metadata.attn_state, attn_state)
+
+
+class TestAscendSFAMetadataBuilder(TestBase):
+
+    def test_ascend_sfa_metadata_builder_default(self):
+        kv_cache_spec = MagicMock()
+        layer_names = ["layer1", "layer2"]
+        vllm_config = MagicMock()
+        speculative_config = MagicMock()
+        speculative_config.num_speculative_tokens = 4
+        vllm_config.speculative_config = speculative_config
+        device = torch.device("cpu")
+
+        builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
+                                           layer_names=layer_names,
+                                           vllm_config=vllm_config,
+                                           device=device)
+
+        assert builder.aclgraph_support == AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+        assert builder.device == device
+        assert builder.vllm_config == vllm_config
+
+    def test_ascend_sfa_metadata_builder_build(self):
+        kv_cache_spec = MagicMock()
+        layer_names = ["layer1", "layer2"]
+        vllm_config = MagicMock()
+        speculative_config = MagicMock()
+        speculative_config.num_speculative_tokens = 4
+        vllm_config.speculative_config = speculative_config
+        device = torch.device("cpu")
+
+        builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
+                                           layer_names=layer_names,
+                                           vllm_config=vllm_config,
+                                           device=device)
+
+        common_attn_metadata = MagicMock()
+        common_attn_metadata.num_reqs = 10
+        common_attn_metadata.num_actual_tokens = 100
+        common_attn_metadata.query_start_loc = torch.tensor(
+            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
+        common_attn_metadata.query_start_loc_cpu = torch.tensor(
+            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
+        common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
+        common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
+        common_attn_metadata.positions = torch.randn(100)
+        common_attn_metadata.attn_mask = None
+        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
+        common_attn_metadata.block_table_tensor = torch.randn(100, 4)
+
+        model = MagicMock()
+        model.model.layers = [MagicMock() for _ in range(10)]
+        model.model.start_layer = 0
+
+        metadata = builder.build(
+            common_prefix_len=10,
+            common_attn_metadata=common_attn_metadata,
+            model=model,
+        )
+
+        assert isinstance(metadata, AscendSFAMetadata)
+        assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
+        assert metadata.slot_mapping.shape == (100, 4, 1024)
+
+    def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
+        kv_cache_spec = MagicMock()
+        layer_names = ["layer1", "layer2"]
+        vllm_config = MagicMock()
+        speculative_config = MagicMock()
+        speculative_config.num_speculative_tokens = 4
+        vllm_config.speculative_config = speculative_config
+        device = torch.device("cpu")
+
+        builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
+                                           layer_names=layer_names,
+                                           vllm_config=vllm_config,
+                                           device=device)
+
+        common_attn_metadata = MagicMock()
+        common_attn_metadata.num_reqs = 10
+        common_attn_metadata.num_actual_tokens = 100
+        common_attn_metadata.query_start_loc = torch.tensor(
+            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
+        common_attn_metadata.query_start_loc_cpu = torch.tensor(
+            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
+        common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
+        common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
+        common_attn_metadata.positions = torch.randn(100)
+        common_attn_metadata.attn_mask = None
+        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
+        common_attn_metadata.block_table_tensor = torch.randn(100, 4)
+
+        model = MagicMock()
+        model.model.layers = [MagicMock() for _ in range(10)]
+        model.model.start_layer = 0
+
+        attn_metadata = builder.build_for_graph_capture(
+            common_attn_metadata=common_attn_metadata,
+            attn_state=AscendAttentionState.DecodeOnly,
+            model=model,
+        )
+
+        assert isinstance(attn_metadata, AscendSFAMetadata)
+        assert attn_metadata.attn_state == AscendAttentionState.DecodeOnly
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 0558b384..9747c2d1 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -91,7 +91,7 @@ M = TypeVar("M", bound=AscendSFAMetadata)
 class AscendSFAMetadataBuilder:
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.NEVER
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -189,6 +189,26 @@ class AscendSFAMetadataBuilder:
             sin=sin,
             cos=cos)
 
+    def build_for_graph_capture(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
+    ):
+        if attn_state == AscendAttentionState.DecodeOnly:
+            attn_metadata = self.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+                model=model,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+
 
 class AscendSFAImpl(MLAAttentionImpl):
     """
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a85d4cda..faa222dd 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1894,7 +1894,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         )
 
         forward_context = get_forward_context()
-        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
+            and not self.use_sparse:
             # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
             if self.vllm_config.model_config.use_mla:
                 if self.pcp_size * self.dcp_size > 1:
@@ -2687,11 +2688,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 [0] * dcp_world_size for _ in range(pcp_world_size)
             ] for _ in range(num_tokens)]
             long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor(
+        if self.speculative_config:
+            query_start_loc = torch.tensor(
                 [0] + self.actual_seq_lengths_q[:num_reqs],
                 device=self.device,
-                dtype=torch.int32),
+                dtype=torch.int32)
+        else:
+            query_start_loc = self.query_start_loc[:num_reqs + 1]
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=query_start_loc,
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                          1],
             seq_lens_cpu=self.seq_lens_cpu,
@@ -2737,7 +2742,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         forward_context = get_forward_context()
         assert forward_context is not None
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
-            not forward_context.capturing:
+            not forward_context.capturing and not self.use_sparse:
             if self.vllm_config.model_config.use_mla:
                 # FIXME: Try using `auto_dispatch_capture=True`
                 if self.pcp_size * self.dcp_size > 1:
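
Usage sketch (not part of the patch): the new build_for_graph_capture hook pairs with the UNIFORM_SINGLE_TOKEN_DECODE ACL graph support now declared on the builder, and it only accepts the DecodeOnly attention state; any other state raises NotImplementedError. A minimal sketch of the expected call follows; the builder, common metadata, and model arguments are placeholders for whatever dummy inputs the runner prepares before capture, not APIs added by this change.

# Sketch only, assuming a builder instance and runner-prepared dummy inputs.
from vllm_ascend.attention.attention_v1 import AscendAttentionState


def capture_decode_metadata(builder, common_attn_metadata, model):
    # Only the single-token DecodeOnly state is supported by the new hook;
    # other states raise NotImplementedError (see the guard in the patch).
    return builder.build_for_graph_capture(
        common_attn_metadata=common_attn_metadata,
        attn_state=AscendAttentionState.DecodeOnly,
        model=model,
    )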