[3/N][Refactor] Move torchair_attention to torchair dir (#2017)
### What this PR does / why we need it?
1. Move `torchair_attention` to `torchair` dir.
2. Make `AscendAttentionTorchairBackend` extend `AscendAttentionBackend`
to reduce duplicate methods.
3. Make `AscendTorchairMetadata` extend `AscendMetadata` to reduce
duplicate properties.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main:
0933f9d518
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -444,7 +444,7 @@ class TestNPUPlatform(TestBase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result,
|
result,
|
||||||
"vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
|
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('vllm_ascend.platform.get_ascend_config')
|
@patch('vllm_ascend.platform.get_ascend_config')
|
||||||
|
|||||||
@@ -169,7 +169,9 @@ class AscendAttentionMetadataBuilder:
|
|||||||
num_actual_tokens,
|
num_actual_tokens,
|
||||||
max_query_len,
|
max_query_len,
|
||||||
enable_dbo_across_dp: bool = False,
|
enable_dbo_across_dp: bool = False,
|
||||||
is_only_prefill: bool = False):
|
is_only_prefill: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -218,7 +218,7 @@ class NPUPlatform(Platform):
|
|||||||
if use_mla:
|
if use_mla:
|
||||||
return "vllm_ascend.attention.mla_v1.AscendMLABackend"
|
return "vllm_ascend.attention.mla_v1.AscendMLABackend"
|
||||||
elif use_torchair:
|
elif use_torchair:
|
||||||
return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
|
return "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
|
||||||
else:
|
else:
|
||||||
return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
|
return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
|
||||||
|
|
||||||
|
|||||||
@@ -21,18 +21,19 @@ from typing import List, Optional, Tuple, Type
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
|
||||||
AttentionLayer, AttentionType)
|
AttentionType)
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
|
||||||
|
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
||||||
|
AscendAttentionMetadataBuilder,
|
||||||
|
AscendAttentionState,
|
||||||
|
AscendMetadata)
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||||
nd_to_nz_2d)
|
nd_to_nz_2d)
|
||||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
|
||||||
|
|
||||||
|
|
||||||
class AscendAttentionTorchairBackend(AttentionBackend):
|
class AscendAttentionTorchairBackend(AscendAttentionBackend):
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -47,10 +48,6 @@ class AscendAttentionTorchairBackend(AttentionBackend):
|
|||||||
def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
|
def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
|
||||||
return AscendTorchairMetadata
|
return AscendTorchairMetadata
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
|
||||||
return CommonAttentionState
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
|
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
|
||||||
return AscendAttentionTorchairMetadataBuilder
|
return AscendAttentionTorchairMetadataBuilder
|
||||||
@@ -73,36 +70,6 @@ class AscendAttentionTorchairBackend(AttentionBackend):
|
|||||||
) -> Tuple[int, ...]:
|
) -> Tuple[int, ...]:
|
||||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def swap_blocks(
|
|
||||||
src_kv_cache: List[torch.Tensor],
|
|
||||||
dst_kv_cache: List[torch.Tensor],
|
|
||||||
src_to_dst: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
|
|
||||||
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
|
|
||||||
src_indices = src_to_dst[:, 0]
|
|
||||||
dst_indices = src_to_dst[:, 1]
|
|
||||||
|
|
||||||
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
|
|
||||||
dst_key_cache.device)
|
|
||||||
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
|
|
||||||
dst_key_cache.device)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def copy_blocks(
|
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
src_to_dists: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
src_indices = src_to_dists[:, 0]
|
|
||||||
dst_indices = src_to_dists[:, 1]
|
|
||||||
|
|
||||||
for kv_cache in kv_caches:
|
|
||||||
key_caches = kv_cache[0]
|
|
||||||
value_caches = kv_cache[1]
|
|
||||||
key_caches[dst_indices] = key_caches[src_indices]
|
|
||||||
value_caches[dst_indices] = value_caches[src_indices]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AscendDecodeMetadata:
|
class AscendDecodeMetadata:
|
||||||
@@ -117,40 +84,15 @@ class AscendDecodeMetadata:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AscendTorchairMetadata:
|
class AscendTorchairMetadata(AscendMetadata):
|
||||||
num_actual_tokens: int # Number of tokens excluding padding.
|
|
||||||
# (batch_size, max_blocks_per_seq).
|
|
||||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
|
||||||
block_tables: torch.Tensor
|
|
||||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
|
||||||
# the computed tokens + new tokens None if it is a decoding.
|
|
||||||
query_start_loc: torch.Tensor
|
|
||||||
query_lens: torch.Tensor
|
|
||||||
seq_lens: torch.Tensor
|
|
||||||
# Maximum query length in the batch. None for decoding.
|
|
||||||
max_query_len: Optional[int] = None
|
|
||||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
|
||||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
|
||||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
|
||||||
# in block 0, and 1st slot in block 1, respectively.
|
|
||||||
slot_mapping: torch.Tensor = None
|
|
||||||
# Current state of this attention run.
|
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
|
||||||
attn_mask: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
decode: Optional[AscendDecodeMetadata] = None
|
decode: Optional[AscendDecodeMetadata] = None
|
||||||
|
|
||||||
enable_dbo_across_dp: bool = False
|
|
||||||
|
|
||||||
|
class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||||
class AscendAttentionTorchairMetadataBuilder:
|
|
||||||
|
|
||||||
def __init__(self, runner):
|
def __init__(self, runner):
|
||||||
self.runner = runner
|
super().__init__(runner)
|
||||||
|
|
||||||
def reorder_batch(self, input_batch: "InputBatch",
|
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _get_graph_runner_block_tables(
|
def _get_graph_runner_block_tables(
|
||||||
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -222,11 +164,16 @@ class AscendAttentionTorchairMetadataBuilder:
|
|||||||
num_reqs,
|
num_reqs,
|
||||||
num_actual_tokens,
|
num_actual_tokens,
|
||||||
max_query_len,
|
max_query_len,
|
||||||
graph_pad_size: int = -1,
|
|
||||||
enable_dbo_across_dp: bool = False,
|
enable_dbo_across_dp: bool = False,
|
||||||
|
is_only_prefill: bool = False,
|
||||||
*args,
|
*args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
|
if 'graph_pad_size' in kwargs:
|
||||||
|
graph_pad_size = kwargs['graph_pad_size']
|
||||||
|
else:
|
||||||
|
graph_pad_size = -1 # default value
|
||||||
|
|
||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
|
|
||||||
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
||||||
@@ -78,7 +78,6 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
|||||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||||
AscendMetadata)
|
AscendMetadata)
|
||||||
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
|
|
||||||
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||||
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
||||||
DummyCommImpl,
|
DummyCommImpl,
|
||||||
@@ -86,6 +85,7 @@ from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
|||||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||||
|
from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||||
ProfileExecuteDuration, is_310p,
|
ProfileExecuteDuration, is_310p,
|
||||||
maybe_converting_weight_acl_format)
|
maybe_converting_weight_acl_format)
|
||||||
|
|||||||
Reference in New Issue
Block a user