[3/N][Refactor] Move torchair_attention to torchair dir (#2017)

### What this PR does / why we need it?

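1. Move `torchair_attention` to the `torchair` directory.
2. Make `AscendAttentionTorchairBackend` extend `AscendAttentionBackend`
to remove duplicated methods.
3. Make `AscendTorchairMetadata` extend `AscendMetadata` to remove
duplicated properties.
4. Make `AscendAttentionTorchairMetadataBuilder` extend
`AscendAttentionMetadataBuilder` so the shared builder logic is inherited.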

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.10.0
- vLLM main: 0933f9d518

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2025-08-19 10:25:22 +08:00
committed by GitHub
parent 2a763b8326
commit 83e0f41408
5 changed files with 23 additions and 74 deletions

View File

@@ -444,7 +444,7 @@ class TestNPUPlatform(TestBase):
)
self.assertEqual(
result,
"vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
)
@patch('vllm_ascend.platform.get_ascend_config')

View File

@@ -169,7 +169,9 @@ class AscendAttentionMetadataBuilder:
num_actual_tokens,
max_query_len,
enable_dbo_across_dp: bool = False,
is_only_prefill: bool = False):
is_only_prefill: bool = False,
*args,
**kwargs):
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
)

View File

@@ -218,7 +218,7 @@ class NPUPlatform(Platform):
if use_mla:
return "vllm_ascend.attention.mla_v1.AscendMLABackend"
elif use_torchair:
return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
return "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
else:
return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"

View File

@@ -21,18 +21,19 @@ from typing import List, Optional, Tuple, Type
import numpy as np
import torch
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
AttentionType)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionMetadataBuilder,
AscendAttentionState,
AscendMetadata)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d)
from vllm_ascend.worker.npu_input_batch import InputBatch
class AscendAttentionTorchairBackend(AttentionBackend):
class AscendAttentionTorchairBackend(AscendAttentionBackend):
accept_output_buffer: bool = True
@staticmethod
@@ -47,10 +48,6 @@ class AscendAttentionTorchairBackend(AttentionBackend):
def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
return AscendTorchairMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
return AscendAttentionTorchairMetadataBuilder
@@ -73,36 +70,6 @@ class AscendAttentionTorchairBackend(AttentionBackend):
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)
@staticmethod
def swap_blocks(
src_kv_cache: List[torch.Tensor],
dst_kv_cache: List[torch.Tensor],
src_to_dst: torch.Tensor,
) -> None:
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
src_indices = src_to_dst[:, 0]
dst_indices = src_to_dst[:, 1]
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
dst_key_cache.device)
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
dst_key_cache.device)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
src_indices = src_to_dists[:, 0]
dst_indices = src_to_dists[:, 1]
for kv_cache in kv_caches:
key_caches = kv_cache[0]
value_caches = kv_cache[1]
key_caches[dst_indices] = key_caches[src_indices]
value_caches[dst_indices] = value_caches[src_indices]
@dataclass
class AscendDecodeMetadata:
@@ -117,40 +84,15 @@ class AscendDecodeMetadata:
@dataclass
class AscendTorchairMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
block_tables: torch.Tensor
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
query_start_loc: torch.Tensor
query_lens: torch.Tensor
seq_lens: torch.Tensor
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor = None
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
attn_mask: Optional[torch.Tensor] = None
class AscendTorchairMetadata(AscendMetadata):
decode: Optional[AscendDecodeMetadata] = None
enable_dbo_across_dp: bool = False
class AscendAttentionTorchairMetadataBuilder:
class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
def __init__(self, runner):
self.runner = runner
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return False
super().__init__(runner)
def _get_graph_runner_block_tables(
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
@@ -222,11 +164,16 @@ class AscendAttentionTorchairMetadataBuilder:
num_reqs,
num_actual_tokens,
max_query_len,
graph_pad_size: int = -1,
enable_dbo_across_dp: bool = False,
is_only_prefill: bool = False,
*args,
**kwargs):
if 'graph_pad_size' in kwargs:
graph_pad_size = kwargs['graph_pad_size']
else:
graph_pad_size = -1 # default value
device = self.runner.device
block_table = self.runner.input_batch.block_table[0].get_device_tensor(

View File

@@ -78,7 +78,6 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
DummyCommImpl,
@@ -86,6 +85,7 @@ from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
ProfileExecuteDuration, is_310p,
maybe_converting_weight_acl_format)