[3/N][Refactor] Move torchair_attention to torchair dir (#2017)
### What this PR does / why we need it?
1. Move `torchair_attention` to the `torchair` dir.
2. Make `AscendAttentionTorchairBackend` extend `AscendAttentionBackend`
to reduce duplicate methods.
3. Make `AscendTorchairMetadata` extend `AscendMetadata` to reduce
duplicate properties (see the sketch after this list).
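
The dedup in (2) and (3) is plain inheritance: shared fields and methods stay in the `attention_v1` base classes, and the torchair variants declare only what differs, so a fix in the base class propagates automatically. A minimal sketch of the pattern, with class bodies trimmed to a couple of illustrative fields (not the real definitions):

```python
# Minimal sketch of the inheritance pattern; the real classes in
# vllm_ascend/attention/attention_v1.py carry many more fields and methods.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class AscendMetadata:
    # Shared metadata lives once in the base class.
    num_actual_tokens: int = 0
    enable_dbo_across_dp: bool = False


@dataclass
class AscendTorchairMetadata(AscendMetadata):
    # The subclass adds only the torchair-specific piece.
    decode: Optional[Any] = None  # AscendDecodeMetadata in the real code


class AscendAttentionMetadataBuilder:
    def __init__(self, runner):
        self.runner = runner


class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
    def __init__(self, runner):
        # Reuse the base initializer instead of duplicating it.
        super().__init__(runner)


m = AscendTorchairMetadata(num_actual_tokens=4)
assert m.enable_dbo_across_dp is False and m.decode is None
```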
### Does this PR introduce _any_ user-facing change?
No, this is a pure refactor; no user-facing change.
### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main: 0933f9d518
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
@@ -444,7 +444,7 @@ class TestNPUPlatform(TestBase):
         )
         self.assertEqual(
             result,
-            "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
         )
 
     @patch('vllm_ascend.platform.get_ascend_config')
@@ -169,7 +169,9 @@ class AscendAttentionMetadataBuilder:
               num_actual_tokens,
               max_query_len,
               enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False):
+              is_only_prefill: bool = False,
+              *args,
+              **kwargs):
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
@@ -218,7 +218,7 @@ class NPUPlatform(Platform):
         if use_mla:
             return "vllm_ascend.attention.mla_v1.AscendMLABackend"
         elif use_torchair:
-            return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            return "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
         else:
             return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
 
@@ -21,18 +21,19 @@ from typing import List, Optional, Tuple, Type
 import numpy as np
 import torch
 import torch_npu
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
+                                              AttentionType)
+from vllm.attention.backends.utils import PAD_SLOT_ID
 
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
+                                                AscendAttentionMetadataBuilder,
+                                                AscendAttentionState,
+                                                AscendMetadata)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)
-from vllm_ascend.worker.npu_input_batch import InputBatch
 
 
-class AscendAttentionTorchairBackend(AttentionBackend):
+class AscendAttentionTorchairBackend(AscendAttentionBackend):
     accept_output_buffer: bool = True
 
     @staticmethod
@@ -47,10 +48,6 @@ class AscendAttentionTorchairBackend(AttentionBackend):
     def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
         return AscendTorchairMetadata
 
-    @staticmethod
-    def get_state_cls() -> Type["CommonAttentionState"]:
-        return CommonAttentionState
-
     @staticmethod
     def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
         return AscendAttentionTorchairMetadataBuilder
@@ -73,36 +70,6 @@ class AscendAttentionTorchairBackend(AttentionBackend):
     ) -> Tuple[int, ...]:
         return (2, num_blocks, block_size, num_kv_heads * head_size)
 
-    @staticmethod
-    def swap_blocks(
-        src_kv_cache: List[torch.Tensor],
-        dst_kv_cache: List[torch.Tensor],
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
-        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
-        src_indices = src_to_dst[:, 0]
-        dst_indices = src_to_dst[:, 1]
-
-        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
-            dst_key_cache.device)
-        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
-            dst_key_cache.device)
-
-    @staticmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        src_indices = src_to_dists[:, 0]
-        dst_indices = src_to_dists[:, 1]
-
-        for kv_cache in kv_caches:
-            key_caches = kv_cache[0]
-            value_caches = kv_cache[1]
-            key_caches[dst_indices] = key_caches[src_indices]
-            value_caches[dst_indices] = value_caches[src_indices]
-
 
 @dataclass
 class AscendDecodeMetadata:
@@ -117,40 +84,15 @@ class AscendDecodeMetadata:
 
 
 @dataclass
-class AscendTorchairMetadata:
-    num_actual_tokens: int  # Number of tokens excluding padding.
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    block_tables: torch.Tensor
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    query_start_loc: torch.Tensor
-    query_lens: torch.Tensor
-    seq_lens: torch.Tensor
-    # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int] = None
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor = None
-    # Current state of this attention run.
-    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-    attn_mask: Optional[torch.Tensor] = None
+class AscendTorchairMetadata(AscendMetadata):
 
     decode: Optional[AscendDecodeMetadata] = None
 
-    enable_dbo_across_dp: bool = False
-
 
-class AscendAttentionTorchairMetadataBuilder:
+class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
 
     def __init__(self, runner):
-        self.runner = runner
-
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
+        super().__init__(runner)
 
     def _get_graph_runner_block_tables(
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
@@ -222,11 +164,16 @@ class AscendAttentionTorchairMetadataBuilder:
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              graph_pad_size: int = -1,
               enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False):
+              is_only_prefill: bool = False,
+              *args,
+              **kwargs):
+        if 'graph_pad_size' in kwargs:
+            graph_pad_size = kwargs['graph_pad_size']
+        else:
+            graph_pad_size = -1  # default value
 
         device = self.runner.device
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
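A note on the builder signature above: adding `*args, **kwargs` keeps the base and torchair `build()` calls compatible, so one call site can drive either builder; the base silently ignores extras such as `graph_pad_size`, while the torchair builder pulls it out of `kwargs`. A hypothetical stripped-down illustration (names mirror the diff, bodies are stubs):

```python
# Hypothetical stubs illustrating the **kwargs compatibility trick; not the
# real builder classes, whose build() methods do far more work.
class BaseBuilder:
    def build(self, num_reqs, is_only_prefill=False, *args, **kwargs):
        # Extra keyword arguments such as graph_pad_size are accepted
        # and ignored here.
        return None


class TorchairBuilder(BaseBuilder):
    def build(self, num_reqs, is_only_prefill=False, *args, **kwargs):
        # Equivalent to the if/else in the diff above.
        graph_pad_size = kwargs.get('graph_pad_size', -1)
        return graph_pad_size


# The same call works against either builder.
assert BaseBuilder().build(2, graph_pad_size=8) is None
assert TorchairBuilder().build(2, graph_pad_size=8) == 8
```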
@@ -78,7 +78,6 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                 AscendMetadata)
-from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
 from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                      DummyCommImpl,
@@ -86,6 +85,7 @@ from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
                                maybe_converting_weight_acl_format)