v0.10.1rc1
This commit is contained in:
0
vllm_ascend/attention/__init__.py
Normal file
0
vllm_ascend/attention/__init__.py
Normal file
93
vllm_ascend/attention/attention_mask.py
Normal file
93
vllm_ascend/attention/attention_mask.py
Normal file
@@ -0,0 +1,93 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
|
||||
def _generate_attn_mask(max_seq_len, dtype):
|
||||
# Construct lower triangle matrix.
|
||||
mask_flag = torch.tril(
|
||||
torch.ones((max_seq_len, max_seq_len),
|
||||
dtype=torch.bool)).view(max_seq_len, max_seq_len)
|
||||
# Create upper triangle matrix used to mark mask positions.
|
||||
mask_flag = ~mask_flag
|
||||
# Currently for fp16 dtype, the mask value should be set to -inf.
|
||||
# TODO: Eliminate this part in the future.
|
||||
if dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
|
||||
mask_flag, mask_value).to(dtype)
|
||||
return attn_mask
|
||||
|
||||
|
||||
class AttentionMaskBuilder:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
attn_mask = _generate_attn_mask(max_seq_len, dtype)
|
||||
|
||||
self._seq_len_cached = attn_mask.shape[0]
|
||||
self.attn_mask_cache = attn_mask
|
||||
|
||||
@staticmethod
|
||||
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
|
||||
if dtype == torch.float16:
|
||||
mask_scale_factor = 1
|
||||
elif dtype == torch.bfloat16:
|
||||
mask_scale_factor = -10000
|
||||
else:
|
||||
raise ValueError(
|
||||
"The current operation now only supports data types: torch.float16 and "
|
||||
"torch.bfloat16. Please ensure the input is of one of these types."
|
||||
)
|
||||
return mask_scale_factor
|
||||
|
||||
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
|
||||
).to(device)
|
||||
|
||||
def get_splitfuse_attn_mask(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
position: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
if dtype not in [torch.float16, torch.bfloat16]:
|
||||
raise ValueError(
|
||||
"splitfuse_attn_mask now only supports bf16 and fp16")
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
||||
# is not the same. Fix this in the future when kernel is ready.
|
||||
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
|
||||
attn_mask = torch.index_select(self.attn_mask_cache,
|
||||
dim=0,
|
||||
index=position)[:, :max_seq_len]
|
||||
attn_mask *= mask_scale_factor
|
||||
return attn_mask.contiguous().to(device, non_blocking=True)
|
||||
|
||||
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
|
||||
if seqlen > self._seq_len_cached:
|
||||
self._seq_len_cached = seqlen
|
||||
self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
|
||||
if self.attn_mask_cache.dtype != dtype:
|
||||
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
|
||||
604
vllm_ascend/attention/attention_v1.py
Normal file
604
vllm_ascend/attention/attention_v1.py
Normal file
@@ -0,0 +1,604 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d, nd_to_nz_spec)
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
|
||||
class AscendAttentionBackend(AttentionBackend):
    """Backend descriptor tying together the Ascend attention classes.

    Exposes the implementation, metadata, state and builder classes, the
    KV-cache layouts, and block swap/copy helpers.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "ASCEND"

    @staticmethod
    def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
        """Return the attention implementation class."""
        return AscendAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AscendMetadata"]:
        """Return the per-step attention metadata class."""
        return AscendMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return AscendAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape; leading dim 2 is key/value.

        When ``is_310p()`` is true the per-token hidden size
        (``num_kv_heads * head_size``) is split into groups of 16 and a
        5-D shape is returned; otherwise the shape is
        ``(2, num_blocks, block_size, num_kv_heads, head_size)``.
        """
        if is_310p():
            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
                    16)
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_bsh_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape with heads and head size fused."""
        return (2, num_blocks, block_size, num_kv_heads * head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy selected blocks from a source cache into a destination cache.

        Args:
            src_kv_cache: Indexable pair ``[key_cache, value_cache]`` to
                read from.
            dst_kv_cache: Indexable pair ``[key_cache, value_cache]`` to
                write into.
            src_to_dst: 2-D index tensor; column 0 holds source block
                indices, column 1 the matching destination indices.
        """
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]

        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
            dst_key_cache.device)
        # Move value blocks to the *value* cache's device; the original used
        # the key cache's device, which is only right when both coincide.
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
            dst_value_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Duplicate blocks in place inside every cache of ``kv_caches``.

        Args:
            kv_caches: Caches whose index 0 is the key cache and index 1 the
                value cache.
            src_to_dists: 2-D index tensor; column 0 holds source block
                indices, column 1 the destination indices.
        """
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]

        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]
|
||||
|
||||
|
||||
class AscendAttentionState(Enum):
    """Kinds of attention run that the Ascend backend distinguishes."""

    # Prefill-style states.
    PrefillNoCache = 0
    PrefillCacheHit = 1
    # Decode-style state.
    DecodeOnly = 2
    # Remaining states.
    ChunkedPrefill = 3
    SpecDecoding = 4
|
||||
|
||||
|
||||
@dataclass
class AscendMetadata:
    """Per-step attention metadata consumed by the Ascend attention kernels.

    Every field has a default, so an instance can be created empty and
    filled in by keyword.
    """

    # ---------------------------- Basic properties ---------------------------
    # Attention mask handed to the kernels, if any.
    attn_mask: Optional[torch.Tensor] = None
    # Which kind of attention run this step is.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    # Token count with padding excluded.
    num_actual_tokens: int = 0

    # Per-sequence length, i.e. already-computed tokens plus new tokens
    # (is None if it is a decoding).
    # Shape: (batch_size,)
    seq_lens: torch.Tensor = None

    query_start_loc: torch.Tensor = None
    query_lens: torch.Tensor = None
    # Longest query in the batch (None for decoding).
    max_query_len: Optional[int] = None

    # ----------------------- KV-cache related properties ---------------------
    # Physical block ids per sequence (sequence id -> list of blocks).
    # Shape: (batch_size, max_blocks_per_seq)
    block_tables: torch.Tensor = None

    # Slot index each input token is written to.  Example: with
    # `slot_mapping` = [35, 2, 17] and a block size of 16, the three tokens
    # land in slot 3 of block 2, slot 2 of block 0 and slot 1 of block 1.
    # Shape: (num_tokens,)
    slot_mapping: torch.Tensor = None

    # ----------------------------- Other properties --------------------------
    enable_dbo_across_dp: bool = False
    is_only_prefill: bool = False
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder:
    """Converts batch-level common metadata into an ``AscendMetadata``."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.model_config = model_config
        self.device = device
        # Ceiling of max_model_len / block_size.
        self.max_num_blocks_per_req = cdiv(
            model_config.max_model_len, vllm_config.cache_config.block_size)

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Report whether the batch was reordered; this builder never does."""
        return False

    def build(
        self,
        common_attn_metadata: AscendCommonAttentionMetadata,
        model: nn.Module,
    ):
        """Assemble the ``AscendMetadata`` for the current batch.

        Args:
            common_attn_metadata: Batch-level metadata to slice and copy.
            model: Accepted for interface compatibility; not used here.

        Returns:
            A populated ``AscendMetadata``.  ``slot_mapping`` and
            ``query_start_loc`` are copied to ``self.device``; ``seq_lens``
            and ``query_lens`` are derived from the ``*_cpu`` inputs without
            a device copy.
        """
        common = common_attn_metadata
        num_reqs = common.num_reqs
        num_actual_tokens = common.num_actual_tokens

        block_table = common.block_table_tensor
        # NOTE(review): source and destination are the same tensor, so this
        # looks like a self-copy; kept as-is — confirm intent.
        block_table[:num_reqs, :self.max_num_blocks_per_req] = (
            block_table[:num_reqs])

        # One start offset per request plus the trailing end offset.
        query_start_loc_cpu = common.query_start_loc_cpu[:num_reqs + 1]
        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        seq_lens = common.seq_lens_cpu[:num_reqs]

        slot_mapping = common.slot_mapping_cpu[:num_actual_tokens].to(
            self.device, non_blocking=True)
        query_start_loc = query_start_loc_cpu.to(self.device,
                                                 non_blocking=True)

        attn_mask = common.attn_mask
        attn_state = common.attn_state

        if is_310p():
            # Pick the ND -> NZ converter matching the attention state; the
            # other states keep the mask untouched.
            if attn_state == AscendAttentionState.PrefillNoCache:
                to_nz = nd_to_nz_2d
            elif attn_state == AscendAttentionState.ChunkedPrefill:
                to_nz = nd_to_nz_spec
            else:
                to_nz = None
            if to_nz is not None:
                attn_mask = torch_npu.npu_format_cast(
                    to_nz(attn_mask).contiguous(), ACL_FORMAT_FRACTAL_NZ)

        return AscendMetadata(
            num_actual_tokens=num_actual_tokens,
            block_tables=block_table,
            query_start_loc=query_start_loc,
            query_lens=query_lens,
            seq_lens=seq_lens,
            max_query_len=common.max_query_len,
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            attn_state=attn_state,
            enable_dbo_across_dp=common.enable_dbo_across_dp,
            is_only_prefill=common.is_only_prefill)
|
||||
|
||||
|
||||
class AscendAttentionBackendImpl(AttentionImpl):
    """Attention implementation that dispatches to ``torch_npu`` kernels.

    ``forward`` writes the new key/value tokens into the paged KV cache and
    then selects a kernel path from ``attn_metadata.attn_state``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        **kwargs,
    ) -> None:
        """Store the attention hyper-parameters.

        ``num_kv_heads`` falls back to ``num_heads`` when ``None``.
        ``alibi_slopes``, when given, is converted to a float32 tensor on the
        ``"npu"`` device.  ``logits_soft_cap``,
        ``kv_sharing_target_layer_name`` and ``**kwargs`` are accepted but
        not used here.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # Bound lazily from `kv_cache` on the first `forward` call.
        self.key_cache = None
        self.value_cache = None

    def _repeat_kv(self, hidden_states: torch.Tensor,
                   n_rep: int) -> torch.Tensor:
        """Repeat every slice along dim 0 of a 3-D tensor ``n_rep`` times.

        Equivalent to ``torch.repeat_interleave(x, repeats=n_rep, dim=0)``:
        an input of shape ``(d0, d1, d2)`` becomes ``(d0 * n_rep, d1, d2)``.
        The input is returned unchanged when ``n_rep == 1``.
        """
        num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, None, :, :].expand(
            num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(num_key_value_heads * n_rep, slen,
                                     head_dim)

    def _forward_prefill_no_cache(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
        num_tokens=0,
    ) -> torch.Tensor:
        """Run attention on freshly computed q/k/v without reading the cache.

        Uses ``npu_fused_infer_attention_score`` when a sliding window is set
        and the mask has more rows than the window; otherwise
        ``_npu_flash_attention``, which writes into ``output``.

        Returns:
            The first ``num_tokens`` rows of the attention output.
        """
        assert attn_metadata is not None
        assert attn_metadata.attn_mask is not None

        mask = attn_metadata.attn_mask

        if is_310p():
            # align q k v output tensors
            query = aligned_16(query)
            key = aligned_16(key)
            value = aligned_16(value)
            output = aligned_16(output)
            # do reformat in case of broadcasted tensors
            mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
            mask = torch_npu.npu_format_cast(mask.contiguous(),
                                             ACL_FORMAT_FRACTAL_NZ)

        if self.sliding_window is not None and \
            attn_metadata.attn_mask.shape[0] > self.sliding_window:

            key = self._repeat_kv(key, self.num_heads // self.num_kv_heads)
            value = self._repeat_kv(value, self.num_heads // self.num_kv_heads)

            output, _ = torch_npu.npu_fused_infer_attention_score(
                query,
                key,
                value,
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout="TND",
                pre_tokens=self.sliding_window,
                scale=self.scale,
                actual_seq_lengths=attn_metadata.seq_lens,
                actual_seq_lengths_kv=attn_metadata.seq_lens)
            output = output.view(num_tokens, self.num_heads, self.head_size)
        else:
            torch_npu._npu_flash_attention(query=query,
                                           key=key,
                                           value=value,
                                           mask=mask,
                                           seq_len=attn_metadata.seq_lens,
                                           scale_value=self.scale,
                                           num_heads=self.num_heads,
                                           num_kv_heads=self.num_kv_heads,
                                           out=output)
        assert output is not None
        return output[:num_tokens, :, :]

    def _forward_prefill_cache_hit(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run prefill attention against the paged KV cache.

        Calls ``_npu_flash_attention_qlens`` with the block table limited to
        the first ``len(query_lens)`` rows; the result is written to
        ``output``, which is returned.
        """
        assert attn_metadata is not None
        assert attn_metadata.attn_mask is not None

        compress_mask = attn_metadata.attn_mask
        batch_size = attn_metadata.query_lens.shape[0]
        block_table = attn_metadata.block_tables[:batch_size, :]

        torch_npu._npu_flash_attention_qlens(
            query=query,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            block_table=block_table,
            mask=compress_mask,
            seq_len=attn_metadata.query_lens,
            context_lens=attn_metadata.seq_lens,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            out=output)
        return output

    def _forward_decode_only(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run decode attention against the paged KV cache.

        With a sliding window the fused infer-attention kernel is used in
        ``"BSH"`` layout with one query token per sequence; otherwise
        ``_npu_paged_attention`` writes into ``output``.  On 310P,
        ``attn_metadata.seq_lens`` is moved to the query's device in place.
        """
        if is_310p():
            # seq_lens_tensor needs to be transferred to the device for 310P.
            attn_metadata.seq_lens = \
                attn_metadata.seq_lens.to(device=query.device)
        if self.sliding_window is not None:
            batch_size = attn_metadata.seq_lens.shape[0]
            # Fallback block size, used only while no KV cache is bound.
            block_size = 128
            query = query.view(batch_size, 1, self.num_heads * self.head_size)
            key = self.key_cache
            value = self.value_cache
            if self.key_cache is not None and self.value_cache is not None:
                block_size = self.key_cache.shape[1]
                # Merge dims 2 and 3 of the caches for the BSH layout.
                key = self.key_cache.flatten(2, 3).contiguous()
                value = self.value_cache.flatten(2, 3).contiguous()

            output, _ = torch_npu.npu_fused_infer_attention_score(
                query,
                key,
                value,
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout="BSH",
                block_size=block_size,
                pre_tokens=self.sliding_window,
                scale=self.scale,
                block_table=attn_metadata.block_tables,
                actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
                actual_seq_lengths_kv=attn_metadata.seq_lens)

            output = output.view(batch_size, self.num_heads, self.head_size)
        else:
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output)
        return output

    def _forward_v1_style(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention for the remaining (mixed / chunked) states.

        Head size 192 takes the ``vanilla_chunked_prefill`` path; every
        other head size uses ``_npu_paged_attention_splitfuse``.  On 310P the
        mask and ``seq_lens`` in ``attn_metadata`` are replaced in place.
        """
        # Use chunked prefill for head size 192 scenario, like deepseek
        # paged_attention_splitfuse maybe crash at such scenario.
        # TODO: vanilla path will be removed after the kernel support
        # head_size 192 scenario.
        if self.head_size == 192:
            # Cumulative lengths with a leading zero: [0, l0, l0+l1, ...].
            cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
            cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
            cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
            cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
            cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
            cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
            max_seqlen_q = torch.max(attn_metadata.query_lens)
            max_seqlen_k = torch.max(attn_metadata.seq_lens)
            vanilla_chunked_prefill(output, query, self.key_cache,
                                    self.value_cache,
                                    attn_metadata.block_tables, cu_seqlen_q,
                                    cu_seqlen_k, max_seqlen_q, max_seqlen_k,
                                    self.scale, None, True)
            return output

        # Use paged attention.
        assert attn_metadata is not None
        assert attn_metadata.attn_mask is not None

        if is_310p():
            # Do reformat in case of broadcasted tensors.
            attn_metadata.attn_mask = \
                torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
                                          ACL_FORMAT_FRACTAL_NZ)
            attn_metadata.seq_lens = \
                attn_metadata.seq_lens.to(device=query.device)

        torch_npu._npu_paged_attention_splitfuse(
            query=query,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            mask=attn_metadata.attn_mask,
            block_table=attn_metadata.block_tables,
            seq_len=attn_metadata.query_lens,
            context_lens=attn_metadata.seq_lens,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            out=output)
        return output

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
        trace_flag: bool = True,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.

        Args:
            layer: Attention layer; ``layer_name``, ``quant_method`` and the
                k/v scale attributes are read from it.
            query: Query tensor; dim 0 is taken as the token count and the
                rest is viewed as ``(num_heads, head_size)``.
            key: Key tensor, viewed as ``(-1, num_kv_heads, head_size)``.
            value: Value tensor, viewed like ``key``.
            kv_cache: ``[key_cache, value_cache]``; may be empty.
            attn_metadata: Metadata for attention.  ``None`` returns the
                (possibly uninitialised) output buffer immediately.
            output: Optional pre-allocated output buffer, updated in place.
            trace_flag: When true, dispatch through the registered custom op
                ``torch.ops.vllm.unified_ascend_attention_with_output``
                instead of running the kernels directly.

        Returns:
            The output viewed as ``[num_tokens, num_heads * head_size]``.

        Raises:
            NotImplementedError: If ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        num_tokens = query.shape[0]
        use_kv_cache_int8 = len(
            kv_cache) > 0 and kv_cache[0].dtype == torch.int8
        if output is None:
            output = torch.empty(num_tokens,
                                 self.num_heads,
                                 self.head_size,
                                 dtype=query.dtype,
                                 device=query.device)
        # Keep a handle on the caller-visible buffer: the kernel paths below
        # may rebind `output` to a different tensor.
        ori_output = output
        if trace_flag:
            torch.ops.vllm.unified_ascend_attention_with_output(
                query=query,
                key=key,
                value=value,
                output=output,
                layer_name=layer.layer_name)

        elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
            output = layer.quant_method.apply(layer, query, key, value,
                                              kv_cache, attn_metadata,
                                              self.attn_type, self.scale,
                                              output)

        else:
            if attn_metadata is None:
                return output.view(num_tokens, self.hidden_size)
            num_actual_tokens = attn_metadata.num_actual_tokens
            assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
            attn_type = self.attn_type
            if attn_type != AttentionType.DECODER:
                raise NotImplementedError("Encoder self-attention and "
                                          "encoder/decoder cross-attention "
                                          "are not implemented for "
                                          "AscendAttentionBackendImpl")
            # View q k v to BSH.
            query = query.view(-1, self.num_heads, self.head_size)
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
            # TODO: Remove this contiguous in the future.
            value = value.contiguous()

            if len(kv_cache) > 1:
                if self.key_cache is None:
                    self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
                slots = attn_metadata.slot_mapping
                # Store only the real (non-padding) tokens in the cache.
                torch_npu._npu_reshape_and_cache(
                    key=key[:num_actual_tokens],
                    value=value[:num_actual_tokens],
                    key_cache=self.key_cache,
                    value_cache=self.value_cache,
                    slot_indices=slots)

            # V0-Style scheduler situation.
            if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                output = self._forward_prefill_no_cache(
                    query, key, value, attn_metadata, output, num_tokens)
            elif attn_metadata.attn_state == \
                AscendAttentionState.PrefillCacheHit:
                output = self._forward_prefill_cache_hit(
                    query, attn_metadata, output)
            elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
                output = self._forward_decode_only(query, attn_metadata,
                                                   output)
            # Normal V1 situation.
            else:
                output = self._forward_v1_style(query, attn_metadata, output)

        # to make in-place change to the output tensor
        if hasattr(layer, 'quant_method') and use_kv_cache_int8:
            output = output.view(num_tokens, self.num_heads, self.head_size)
        ori_output[:, :, :] = output[:num_tokens, :, :]
        return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
|
||||
def unified_ascend_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the named layer's attention implementation, filling ``output``.

    The layer, its KV cache and the attention metadata are all looked up
    from the current forward context; the implementation is invoked with
    ``trace_flag=False`` so it executes the kernels directly instead of
    dispatching back through the custom op.
    """
    ctx: ForwardContext = get_forward_context()
    attn_layer = ctx.no_compile_layers[layer_name]
    attn_layer.impl.forward(
        attn_layer,
        query,
        key,
        value,
        attn_layer.kv_cache[ctx.virtual_engine],
        ctx.attn_metadata,
        output,
        trace_flag=False,
    )
|
||||
|
||||
|
||||
def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op stand-in sharing the real op's signature.

    Performs no computation and leaves every argument untouched.
    """
    return None
|
||||
|
||||
|
||||
# Register the attention entry point as a custom op at import time.
# `output` is declared as the mutated argument, the no-op function serves as
# the fake implementation, and the op is bound to the "PrivateUse1" dispatch
# key.
direct_register_custom_op(
    op_name="unified_ascend_attention_with_output",
    op_func=unified_ascend_attention_with_output,
    mutates_args=["output"],
    fake_impl=unified_attention_with_output_fake,
    dispatch_key="PrivateUse1",
)
|
||||
1050
vllm_ascend/attention/mla_v1.py
Normal file
1050
vllm_ascend/attention/mla_v1.py
Normal file
File diff suppressed because it is too large
Load Diff
95
vllm_ascend/attention/utils.py
Normal file
95
vllm_ascend/attention/utils.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
class AscendCommonAttentionMetadata:
    """
    Batch-level attention metadata shared by all layers and backends.

    Attention metadata builders derive their per-layer metadata from it.
    Several tensors are carried in both a device and a CPU copy.
    """

    # (batch_size + 1,): start offset of every request in the query tensor.
    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor

    # (batch_size,): length of every request, counting both the tokens
    # already computed and the newly scheduled ones.
    seq_lens_cpu: torch.Tensor

    # Number of requests.
    num_reqs: int
    # Total number of tokens in the batch.
    num_actual_tokens: int

    # Largest per-request token count in the batch.
    max_query_len: int

    # Number of decode tokens per request.
    decode_token_per_req: int

    block_table_tensor: torch.Tensor

    slot_mapping_cpu: torch.Tensor

    actual_seq_lengths_q: list[int]

    # Optional fields; each has a default.
    positions: torch.Tensor = None

    attn_mask: torch.Tensor = None

    spec_attn_mask: torch.Tensor = None

    attn_state: Any = None

    enable_dbo_across_dp: bool = False

    is_only_prefill: bool = False

    graph_pad_size: int = -1
|
||||
|
||||
|
||||
def split_decodes_and_prefills(
    common_attn_metadata: "AscendCommonAttentionMetadata",
    decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
    """
    Locate the decode/prefill boundary in a batch that is already ordered
    with decode requests first.

    Args:
        common_attn_metadata: Batch metadata; ``max_query_len``,
            ``num_reqs``, ``num_actual_tokens`` and ``query_start_loc_cpu``
            are read from it.
        decode_threshold: Longest query length still treated as a decode.

    Returns:
        A tuple ``(num_decodes, num_prefills, num_decode_tokens,
        num_prefill_tokens)``: request counts followed by token counts for
        the decode and prefill parts of the batch.
    """
    total_reqs = common_attn_metadata.num_reqs
    total_tokens = common_attn_metadata.num_actual_tokens
    all_decode = (total_reqs, 0, total_tokens, 0)

    # Fast path: not even the longest query exceeds the threshold.
    if common_attn_metadata.max_query_len <= decode_threshold:
        return all_decode

    starts = common_attn_metadata.query_start_loc_cpu
    # Per-request query length = difference of consecutive start offsets.
    query_lens = starts[1:] - starts[:-1]
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        return all_decode

    # argmax over the 0/1 flags yields the index of the first prefill.
    boundary = is_prefill.int().argmax(dim=-1).item()
    # Sanity-check the ordering assumption on both sides of the boundary.
    assert torch.all(query_lens[boundary:] >= decode_threshold)
    assert torch.all(query_lens[:boundary] <= decode_threshold)

    decode_tokens = starts[boundary].item()
    return (boundary, total_reqs - boundary, decode_tokens,
            total_tokens - decode_tokens)
|
||||
Reference in New Issue
Block a user