diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py
index e6642b5..940d07f 100644
--- a/tests/ut/test_platform.py
+++ b/tests/ut/test_platform.py
@@ -444,7 +444,7 @@ class TestNPUPlatform(TestBase):
         )
         self.assertEqual(
             result,
-            "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
         )
 
     @patch('vllm_ascend.platform.get_ascend_config')
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 15a7759..81a3375 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -169,7 +169,9 @@ class AscendAttentionMetadataBuilder:
               num_actual_tokens,
               max_query_len,
               enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False):
+              is_only_prefill: bool = False,
+              *args,
+              **kwargs):
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 7f21f26..0299acc 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -218,7 +218,7 @@ class NPUPlatform(Platform):
         if use_mla:
             return "vllm_ascend.attention.mla_v1.AscendMLABackend"
         elif use_torchair:
-            return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            return "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
         else:
             return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
 
diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/torchair/torchair_attention.py
similarity index 85%
rename from vllm_ascend/attention/attention_v1_torchair.py
rename to vllm_ascend/torchair/torchair_attention.py
index 4d84bac..a3fda61 100644
--- a/vllm_ascend/attention/attention_v1_torchair.py
+++ b/vllm_ascend/torchair/torchair_attention.py
@@ -21,18 +21,19 @@ from typing import List, Optional, Tuple, Type
 import numpy as np
 import torch
 import torch_npu
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
+                                              AttentionType)
+from vllm.attention.backends.utils import PAD_SLOT_ID
 
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
+                                                AscendAttentionMetadataBuilder,
+                                                AscendAttentionState,
+                                                AscendMetadata)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)
-from vllm_ascend.worker.npu_input_batch import InputBatch
 
 
-class AscendAttentionTorchairBackend(AttentionBackend):
+class AscendAttentionTorchairBackend(AscendAttentionBackend):
     accept_output_buffer: bool = True
 
     @staticmethod
@@ -47,10 +48,6 @@ class AscendAttentionTorchairBackend(AscendAttentionBackend):
     def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
         return AscendTorchairMetadata
 
-    @staticmethod
-    def get_state_cls() -> Type["CommonAttentionState"]:
-        return CommonAttentionState
-
     @staticmethod
     def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
         return AscendAttentionTorchairMetadataBuilder
@@ -73,36 +70,6 @@ class AscendAttentionTorchairBackend(AscendAttentionBackend):
     ) -> Tuple[int, ...]:
         return (2, num_blocks, block_size, num_kv_heads * head_size)
 
-    @staticmethod
-    def swap_blocks(
-        src_kv_cache: List[torch.Tensor],
-        dst_kv_cache: List[torch.Tensor],
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
-        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
-        src_indices = src_to_dst[:, 0]
-        dst_indices = src_to_dst[:, 1]
-
-        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
-            dst_key_cache.device)
-        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
-            dst_key_cache.device)
-
-    @staticmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        src_indices = src_to_dists[:, 0]
-        dst_indices = src_to_dists[:, 1]
-
-        for kv_cache in kv_caches:
-            key_caches = kv_cache[0]
-            value_caches = kv_cache[1]
-            key_caches[dst_indices] = key_caches[src_indices]
-            value_caches[dst_indices] = value_caches[src_indices]
-
 
 @dataclass
 class AscendDecodeMetadata:
@@ -117,40 +84,15 @@ class AscendDecodeMetadata:
 
 
 @dataclass
-class AscendTorchairMetadata:
-    num_actual_tokens: int  # Number of tokens excluding padding.
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    block_tables: torch.Tensor
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    query_start_loc: torch.Tensor
-    query_lens: torch.Tensor
-    seq_lens: torch.Tensor
-    # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int] = None
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor = None
-    # Current state of this attention run.
-    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-    attn_mask: Optional[torch.Tensor] = None
+class AscendTorchairMetadata(AscendMetadata):
 
     decode: Optional[AscendDecodeMetadata] = None
 
-    enable_dbo_across_dp: bool = False
-
 
-class AscendAttentionTorchairMetadataBuilder:
+class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
 
     def __init__(self, runner):
-        self.runner = runner
-
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
+        super().__init__(runner)
 
     def _get_graph_runner_block_tables(
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
@@ -222,11 +164,16 @@
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              graph_pad_size: int = -1,
               enable_dbo_across_dp: bool = False,
+              is_only_prefill: bool = False,
               *args,
               **kwargs):
 
+        if 'graph_pad_size' in kwargs:
+            graph_pad_size = kwargs['graph_pad_size']
+        else:
+            graph_pad_size = -1  # default value
+
         device = self.runner.device
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ebf76eb..f4450d3 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -78,7 +78,6 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                 AscendMetadata)
-from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
 from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                      DummyCommImpl,
@@ -86,6 +85,7 @@ from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
                                maybe_converting_weight_acl_format)
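Note on the signature change: the `*args, **kwargs` tail added to both `build()` methods is what lets the model runner keep a single call site. The base `AscendAttentionMetadataBuilder` now silently accepts torchair-only keywords, while `AscendAttentionTorchairMetadataBuilder` reads `graph_pad_size` out of `kwargs` (defaulting to -1) instead of declaring it as a positional parameter. Below is a minimal, self-contained sketch of that pattern; `BaseBuilder` and `TorchairBuilder` are simplified stand-ins, not the real vllm-ascend classes.

```python
# Minimal sketch of the kwargs-forwarding pattern used in this patch.
# BaseBuilder / TorchairBuilder are hypothetical stand-ins for
# AscendAttentionMetadataBuilder / AscendAttentionTorchairMetadataBuilder.


class BaseBuilder:

    def build(self, num_reqs, num_actual_tokens, max_query_len,
              enable_dbo_across_dp=False, is_only_prefill=False,
              *args, **kwargs):
        # Extra keyword arguments (e.g. graph_pad_size) are accepted and
        # ignored here, so one call site can serve every backend.
        return {"num_reqs": num_reqs, "tokens": num_actual_tokens}


class TorchairBuilder(BaseBuilder):

    def build(self, num_reqs, num_actual_tokens, max_query_len,
              enable_dbo_across_dp=False, is_only_prefill=False,
              *args, **kwargs):
        # graph_pad_size travels through kwargs; -1 ("no padding") mirrors
        # the default the removed positional parameter used to declare.
        graph_pad_size = kwargs.get('graph_pad_size', -1)
        meta = super().build(num_reqs, num_actual_tokens, max_query_len,
                             enable_dbo_across_dp, is_only_prefill)
        meta["graph_pad_size"] = graph_pad_size
        return meta


# Either builder can now be driven through the same call:
for builder in (BaseBuilder(), TorchairBuilder()):
    print(builder.build(num_reqs=2, num_actual_tokens=8, max_query_len=4,
                        graph_pad_size=16))
```

The trade-off of this approach is that `graph_pad_size` no longer appears in the subclass signature, so callers must pass it by keyword; in exchange, the base class and every other subclass stay oblivious to torchair-specific arguments.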