diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index 65e271c..735462f 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -96,7 +96,6 @@ class TestAscendAttentionMetadataBuilder(TestBase):
         num_reqs = 2
         num_actual_tokens = 10
         max_query_len = 5
-        common_prefix_len = 1
 
         self.mock_runner.input_batch.block_table = [MagicMock()]
         self.mock_runner.input_batch.block_table[
@@ -114,8 +113,11 @@ class TestAscendAttentionMetadataBuilder(TestBase):
         mock_nd_to_nz_2d.return_value = mock_nz_tensor
         mock_npu_format_cast.return_value = mock_nz_tensor
 
-        self.builder.build(num_reqs, num_actual_tokens, max_query_len,
-                           common_prefix_len)
+        self.builder.build(
+            num_reqs,
+            num_actual_tokens,
+            max_query_len,
+        )
 
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('torch_npu.npu_format_cast')
@@ -148,7 +150,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
         mock_nd_to_nz_spec.return_value = mock_nz_tensor
         mock_npu_format_cast.return_value = mock_nz_tensor
 
-        self.builder.build(num_reqs, num_actual_tokens, max_query_len, 0)
+        self.builder.build(num_reqs, num_actual_tokens, max_query_len)
 
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@@ -169,7 +171,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
         self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
         self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
 
-        self.builder.build(num_reqs, num_actual_tokens, max_query_len, 0)
+        self.builder.build(num_reqs, num_actual_tokens, max_query_len)
 
 
 class TestAscendAttentionBackendImpl(TestBase):
@@ -201,7 +203,9 @@ class TestAscendAttentionBackendImpl(TestBase):
             alibi_slopes=None,
             sliding_window=None,
             kv_cache_dtype="float16",
-            attn_type=self.attention_type.DECODER)
+            logits_soft_cap=None,
+            attn_type=self.attention_type.DECODER,
+            kv_sharing_target_layer_name=None)
 
         self.impl_192 = AscendAttentionBackendImpl(
             num_heads=8,
@@ -211,16 +215,21 @@ class TestAscendAttentionBackendImpl(TestBase):
             alibi_slopes=None,
             sliding_window=None,
             kv_cache_dtype="float16",
-            attn_type=self.attention_type.DECODER)
+            logits_soft_cap=None,
+            attn_type=self.attention_type.DECODER,
+            kv_sharing_target_layer_name=None)
 
-        self.impl_error = AscendAttentionBackendImpl(num_heads=8,
-                                                     head_size=192,
-                                                     scale=1.0,
-                                                     num_kv_heads=8,
-                                                     alibi_slopes=None,
-                                                     sliding_window=None,
-                                                     kv_cache_dtype="float16",
-                                                     attn_type=None)
+        self.impl_error = AscendAttentionBackendImpl(
+            num_heads=8,
+            head_size=192,
+            scale=1.0,
+            num_kv_heads=8,
+            alibi_slopes=None,
+            sliding_window=None,
+            kv_cache_dtype="float16",
+            logits_soft_cap=None,
+            attn_type=None,
+            kv_sharing_target_layer_name=None)
 
     @patch('torch.ops.vllm.unified_ascend_attention_with_output')
     def test_forward_trace_flag_true(self, mock_unified_attention):
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 9f2bd9b..fa2a528 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -130,10 +130,8 @@ class AscendMetadata:
     query_start_loc: torch.Tensor
     query_lens: torch.Tensor
     seq_lens: torch.Tensor
-    # max value of number of tokens across dp group
     max_num_tokens_across_dp: int = 0
 
-    # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
 
     # (num_tokens,). The indices of the token slots that input tokens will be
@@ -141,18 +139,9 @@ class AscendMetadata:
     # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
     # in block 0, and 1st slot in block 1, respectively.
     slot_mapping: torch.Tensor = None
 
-    # TODO: Indicates whether there are only prefill requests.
-    # FlashAttention can be used when there are only prefill requests.
-    # FlashAttention has better performance than PageAtttention,
-    # but it does not support decode requests.
-    is_only_prefill: bool = False
 
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     attn_mask: Optional[torch.Tensor] = None
-
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-    with_prefill_across_dp: bool = False
 
@@ -169,7 +158,6 @@ class AscendAttentionMetadataBuilder:
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              common_prefix_len,
               max_num_tokens_across_dp: int = 0,
               with_prefill_across_dp: bool = False):
 
@@ -224,10 +212,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
-        logits_soft_cap: Optional[float] = None,
-        attn_type: str = AttentionType.DECODER,
-        kv_sharing_target_layer_name: Optional[str] = None,
-        use_irope: bool = False,
+        logits_soft_cap: Optional[float],
+        attn_type: str,
+        kv_sharing_target_layer_name: Optional[str],
+        **kwargs,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py
index 0c50290..84cfbd0 100644
--- a/vllm_ascend/attention/attention_v1_torchair.py
+++ b/vllm_ascend/attention/attention_v1_torchair.py
@@ -37,7 +37,7 @@ class AscendAttentionTorchairBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "ASCEND"
+        return "ASCEND_TORCHAIR"
 
     @staticmethod
     def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
@@ -129,10 +129,8 @@ class AscendTorchairMetadata:
     query_start_loc: torch.Tensor
     query_lens: torch.Tensor
    seq_lens: torch.Tensor
-    # max value of number of tokens across dp group
     max_num_tokens_across_dp: int = 0
 
-    # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
 
     # (num_tokens,). The indices of the token slots that input tokens will be
@@ -140,20 +138,10 @@ class AscendTorchairMetadata:
     # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
     # in block 0, and 1st slot in block 1, respectively.
     slot_mapping: torch.Tensor = None
 
-    # TODO: Indicates whether there are only prefill requests.
-    # FlashAttention can be used when there are only prefill requests.
-    # FlashAttention has better performance than PageAtttention,
-    # but it does not support decode requests.
-    is_only_prefill: bool = False
 
    # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     attn_mask: Optional[torch.Tensor] = None
-
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-    with_prefill_across_dp: bool = False
-
     decode: Optional[AscendDecodeMetadata] = None
 
@@ -236,7 +224,6 @@ class AscendAttentionTorchairMetadataBuilder:
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              common_prefix_len,
               graph_pad_size: int = -1,
               max_num_tokens_across_dp: int = 0,
               with_prefill_across_dp: bool = False):
@@ -335,10 +322,10 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
-        logits_soft_cap: Optional[float] = None,
-        attn_type: str = AttentionType.DECODER,
-        kv_sharing_target_layer_name: Optional[str] = None,
-        use_irope: bool = False,
+        logits_soft_cap: Optional[float],
+        attn_type: str,
+        kv_sharing_target_layer_name: Optional[str],
+        **kwargs,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 3accb32..0645a78 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -31,27 +31,13 @@ if TYPE_CHECKING:
 _ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
 
 
-@dataclass
-class CommonAttentionMetadata:
-    """
-    Attention metadata attributes that can be shared by layers in different KV
-    cache groups and thus having different block table.
-    """
-
-    query_start_loc: torch.Tensor
-    """(batch_size + 1,), the start location of each request in query Tensor"""
-    seq_lens: torch.Tensor
-    """(batch_size,), the length of each request including both computed tokens
-    and newly scheduled tokens"""
-
-
 class AscendMLABackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
     @staticmethod
     def get_name() -> str:
-        return "VLLM_ASCEND_MLA"
+        return "ASCEND_MLA"
 
     @staticmethod
     def get_metadata_cls() -> type["AttentionMetadata"]:
@@ -368,11 +354,10 @@ class AscendMLAMetadataBuilder:
         num_reqs: int,
         num_actual_tokens: int,
         max_query_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
-        common_prefix_len: Optional[int] = None,
         graph_pad_size: int = -1,
         max_num_tokens_across_dp: int = 0,
         with_prefill_across_dp: bool = False,
+        query_start_loc: torch.Tensor = None,
     ) -> AscendMLAMetadata:
 
         assert self._num_decodes + self._num_prefills == num_reqs
@@ -394,7 +379,6 @@ class AscendMLAMetadataBuilder:
             seq_lens = seq_lens_cpu
             max_query_len = query_lens.max().item()
             max_seq_lens = seq_lens.max().item()
-            query_start_loc = common_attn_metadata.query_start_loc
 
         prefill_metadata = None
         chunked_context_metadata = None
@@ -403,7 +387,6 @@ class AscendMLAMetadataBuilder:
             tokens_start = self._num_decode_tokens
             max_query_len = query_lens[tokens_start:].max().item()
             max_seq_lens = seq_lens[tokens_start:].max().item()
-            query_start_loc = common_attn_metadata.query_start_loc
 
             prefill_query_start_loc = query_start_loc[
                 reqs_start:] - query_start_loc[reqs_start]
@@ -539,7 +522,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         kv_cache_dtype: str,
         logits_soft_cap: Optional[float],
         attn_type: str,
-        kv_sharing_target_layer_name: Optional[str] = None,
+        kv_sharing_target_layer_name: Optional[str],
         **kwargs,
     ) -> None:
         self.num_heads = num_heads
diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py
index fc074d5..3ce0d87 100644
--- a/vllm_ascend/worker/eagle_proposer_v1.py
+++ b/vllm_ascend/worker/eagle_proposer_v1.py
@@ -132,7 +132,6 @@ class EagleProposer:
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
-            common_prefix_len=0,
         )
         if self.use_cuda_graph and \
                 num_tokens <= self.cudagraph_batch_sizes[-1]:
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 8212c36..c934d22 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -75,8 +75,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                 AscendMetadata)
 from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
-from vllm_ascend.attention.mla_v1 import (AscendMLAMetadata,
-                                          CommonAttentionMetadata)
+from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
@@ -694,15 +693,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
-
-            # Prepare for cascade attention if enabled & beneficial.
-            common_prefix_len = 0
-
             attn_metadata_i = self.attn_metadata_builder.build(
                 num_reqs=num_reqs,
                 num_actual_tokens=total_num_scheduled_tokens,
                 max_query_len=max_num_scheduled_tokens,
-                common_prefix_len=common_prefix_len,
             )
             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
@@ -1049,27 +1043,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
 
         if self.vllm_config.model_config.use_mla:
-            query_start_loc = self.query_start_loc[:num_reqs + 1]
-            seq_lens = self.seq_lens[:num_reqs]
-            common_attn_metadata = CommonAttentionMetadata(
-                query_start_loc=query_start_loc, seq_lens=seq_lens)
+            extra_builder_kwargs[
+                "query_start_loc"] = self.query_start_loc[:num_reqs + 1]
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                 num_reqs=num_reqs,
                 num_actual_tokens=total_num_scheduled_tokens,
                 max_query_len=max_num_scheduled_tokens,
-                common_attn_metadata=common_attn_metadata,
-                common_prefix_len=None,
                 **extra_builder_kwargs,
             )
+            attn_metadata.num_input_tokens = num_input_tokens
         else:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                 num_reqs=num_reqs,
                 num_actual_tokens=total_num_scheduled_tokens,
                 max_query_len=max_num_scheduled_tokens,
-                common_prefix_len=None,
                 **extra_builder_kwargs,
             )
-        attn_metadata.num_input_tokens = num_input_tokens
 
         # Prepare input_ids
         token_indices = (positions_np +
diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py
index 5b88e7e..6577bb8 100644
--- a/vllm_ascend/worker/mtp_proposer_v1.py
+++ b/vllm_ascend/worker/mtp_proposer_v1.py
@@ -8,7 +8,6 @@
 from vllm.model_executor.model_loader.utils import (
     process_weights_after_loading, set_default_torch_dtype)
 from vllm.v1.sample.metadata import SamplingMetadata
 
-from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
 
@@ -100,11 +99,6 @@ class MtpProposer:
         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
 
-        seq_lens = (target_positions[last_token_indices] + 1)
-
-        common_attn_metadata = CommonAttentionMetadata(
-            query_start_loc=cu_num_tokens, seq_lens=seq_lens)
-
         # FIXME: reorder_batch() needs to be called before build()
         # because fields of attn_metadata_builder needs to be updated.
         # However, currently reorder_batch() takes input_batch and
@@ -120,8 +114,7 @@ class MtpProposer:
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
-            common_prefix_len=0,
-            common_attn_metadata=common_attn_metadata,
+            query_start_loc=cu_num_tokens,
        )
 
         with set_forward_context(attn_metadata, self.vllm_config):
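For reference, a minimal sketch of what the call sites look like after this change: common_prefix_len is dropped from every build() signature, and the MLA builder receives query_start_loc directly instead of a CommonAttentionMetadata wrapper. The helper function below is hypothetical (it is not part of the patch); only the keyword names mirror the patched runner code.

# Hypothetical helper illustrating the simplified builder interface; the
# builder object and the query_start_loc tensor are assumed to come from
# the surrounding model runner.
def build_attn_metadata(builder, num_reqs, total_num_scheduled_tokens,
                        max_num_scheduled_tokens, use_mla, query_start_loc):
    extra_builder_kwargs = {}
    if use_mla:
        # MLA builders now take query_start_loc directly rather than a
        # CommonAttentionMetadata wrapper.
        extra_builder_kwargs["query_start_loc"] = query_start_loc[:num_reqs + 1]
    # common_prefix_len is no longer part of any build() signature.
    return builder.build(
        num_reqs=num_reqs,
        num_actual_tokens=total_num_scheduled_tokens,
        max_query_len=max_num_scheduled_tokens,
        **extra_builder_kwargs,
    )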