#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from dataclasses import dataclass
from enum import Enum
from typing import ClassVar, List, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
from vllm.attention.backends.registry import (AttentionBackendEnum,
                                              register_backend)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                         split_decodes_and_prefills)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                               update_graph_params_workspaces)
from vllm_ascend.utils import weak_ref_tensors


@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
class AscendAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "CUSTOM"

    @staticmethod
    def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
        parallel_config = get_current_vllm_config().parallel_config
        if (parallel_config.prefill_context_parallel_size > 1
                or parallel_config.decode_context_parallel_size > 1):
            from vllm_ascend.attention.attention_cp import \
                AscendAttentionCPImpl
            return AscendAttentionCPImpl
        return AscendAttentionBackendImpl

    @staticmethod
    def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
        parallel_config = get_current_vllm_config().parallel_config
        if (parallel_config.prefill_context_parallel_size > 1
                or parallel_config.decode_context_parallel_size > 1):
            from vllm_ascend.attention.attention_cp import \
                AscendAttentionCPMetadataBuilder
            return AscendAttentionCPMetadataBuilder
        return AscendAttentionMetadataBuilder
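
    # NOTE: when context parallelism is enabled (prefill_context_parallel_size
    # or decode_context_parallel_size > 1), the CP-aware impl/builder pair is
    # swapped in above; the classes below assume no context parallelism.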

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_bsh_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads * head_size)
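
    # Worked example (informal): with num_blocks=4, block_size=128,
    # num_kv_heads=8 and head_size=64, get_kv_cache_shape returns
    # (2, 4, 128, 8, 64) and get_bsh_kv_cache_shape returns (2, 4, 128, 512);
    # index 0 of the leading dim holds the key cache, index 1 the value cache.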

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]

        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
            dst_key_cache.device)
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
            dst_value_cache.device)
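
    # `src_to_dst` is a (num_pairs, 2) tensor of (source_block, dest_block)
    # indices; the `.to(...)` calls above let swaps cross devices, e.g. when
    # moving blocks between CPU swap space and NPU memory.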

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]

        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]

    @staticmethod
    def get_supported_block_size() -> list[int]:
        return [128]
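
    # The metadata builder below divides max_model_len by this 128-token block
    # size (via cdiv) to bound the number of blocks per request, and the
    # prefill paths fall back to the same block_size of 128 when no KV cache
    # exists yet.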


class AscendAttentionState(Enum):
    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3
    SpecDecoding = 4
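
# Informal reading of the states: PrefillNoCache is a prefill with no cached
# blocks, PrefillCacheHit is a prefill that reuses cached prefix blocks,
# DecodeOnly is a pure decode batch, ChunkedPrefill is the general mixed case
# (and the metadata default), and SpecDecoding marks speculative-decoding
# batches.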


@dataclass
class AscendMetadataForPrefill:
    """Prefill-specific metadata for Ascend."""

    @dataclass
    class AscendPCPMetadata:
        q_head_idx: torch.Tensor = None
        q_tail_idx: torch.Tensor = None
        kv_with_q_head_nomask_idx: torch.Tensor = None
        kv_with_q_head_mask_idx: torch.Tensor = None
        kv_with_q_tail_nomask_idx: torch.Tensor = None
        kv_with_q_tail_mask_idx: torch.Tensor = None
        attn_mask_seqlens: torch.Tensor = None
        head_attn_nomask_seqlens: torch.Tensor = None
        tail_attn_nomask_seqlens: torch.Tensor = None
        q_full_idx: torch.Tensor = None
        pcp_prefill_mask: torch.Tensor = None

    @dataclass
    class ChunkedContextMetadata:
        actual_chunk_seq_lengths: torch.Tensor
        actual_seq_lengths_kv: torch.Tensor
        starts: torch.Tensor
        chunk_seq_mask_filtered_indices: torch.Tensor
        chunked_req_mask: Optional[list[bool]] = None
        local_context_lens_allranks: Optional[list[list[int]]] = None
        cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
        kv_inverse_idx_for_chunk: Optional[list[int]] = None
        batch_chunk_seq_mask: Optional[list[bool]] = None

    pcp_metadata: Optional[AscendPCPMetadata] = None
    pcp_allgather_restore_idx: Optional[List[int]] = None
    chunked_context: Optional[ChunkedContextMetadata] = None
    block_tables: torch.Tensor = None
    actual_seq_lengths_q: torch.Tensor = None


@dataclass
class AscendMetadataForDecode:
    """Decode-specific metadata for Ascend."""
    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
    batch_seq_mask: torch.Tensor = None
    block_tables: torch.Tensor = None


@dataclass
class AscendMetadata:
    # **************************** Basic Properties ************************** #
    attn_mask: Optional[torch.Tensor] = None
    # Current state of this attention run.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    # Number of tokens excluding padding.
    num_actual_tokens: int = 0
    # Number of tokens including pcp padding (equals num_actual_tokens when
    # prefill context parallelism is off).
    num_actual_tokens_pcp_padded: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    # The sequence length per sequence, i.e. the number of computed tokens
    # plus the new tokens (None if this is a decode-only batch).
    # (batch_size,)
    # TODO(Angazenn): The following parameters are quite redundant and
    # contain similar information (such as seq_lens and seq_lens_list). We
    # should simplify these parameters once the attention schema in
    # vLLM-Ascend is unified.
    seq_lens: torch.Tensor = None
    seq_lens_list: List[int] = None  # type: ignore
    actual_seq_lengths_q: List[int] = None  # type: ignore
    query_start_loc_list: List[int] = None  # type: ignore

    query_start_loc: torch.Tensor = None
    query_lens: torch.Tensor = None
    # Maximum query length in the batch (None for decoding).
    max_query_len: Optional[int] = None

    # ********************** KV Cache Related Properties ********************* #
    # Block addresses per sequence (Seq id -> list of physical blocks).
    # (batch_size, max_blocks_per_seq)
    block_tables: torch.Tensor = None

    # The indices of the token slots that input tokens will be stored into.
    # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
    # three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0,
    # and 1st slot in block 1, respectively.
    # (num_tokens,)
    slot_mapping: torch.Tensor = None
    # Prefill-specific metadata (pcp).
    prefill: Optional[AscendMetadataForPrefill] = None
    # Decode-specific metadata (dcp).
    decode_meta: Optional[AscendMetadataForDecode] = None


class AscendAttentionMetadataBuilder:
    # Does this backend/builder support ACL Graphs for attention (default: no).
    aclgraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.ALWAYS
    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: ClassVar[int] = 1
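
    # NOTE: __init__ below overwrites `reorder_batch_threshold` with
    # 1 + num_speculative_tokens, so with speculative decoding enabled a
    # "decode" request may carry several query tokens.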

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device
        self.max_num_blocks_per_req = cdiv(
            self.model_config.max_model_len,
            AscendAttentionBackend.get_supported_block_size()[0])

        self.speculative_config = vllm_config.speculative_config
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            assert self.decode_threshold <= 16, (
                "decode_threshold exceeds npu_fused_infer_attention_score's "
                f"TND-layout limit of 16, got {self.decode_threshold}")

        AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold

        scheduler_config = vllm_config.scheduler_config
        self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill

    def reorder_batch(self, input_batch,
                      scheduler_output: "SchedulerOutput") -> bool:
        return False

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        model: Optional[nn.Module] = None,
    ):
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[
            :num_reqs + 1]

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=self.decode_threshold)
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

        block_table = common_attn_metadata.block_table_tensor
        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

        long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
        num_actual_tokens_pcp_padded = (
            long_seq_metadata.num_actual_tokens_pcp_padded
            if long_seq_metadata else None)
        if num_actual_tokens_pcp_padded is None:
            num_actual_tokens_pcp_padded = num_actual_tokens

        slot_mapping = common_attn_metadata.slot_mapping[
            :num_actual_tokens_pcp_padded]
        attn_mask = common_attn_metadata.attn_mask
        attn_state = common_attn_metadata.attn_state

        # Pad seq_lens, the block table and query_start_loc so that they stay
        # consistent with input_ids when the batch itself has been padded
        # (e.g., a full-decode batch padded for graph capture).
        if common_attn_metadata.num_input_tokens > num_actual_tokens:
            padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
            seq_lens = torch.cat([
                seq_lens,
                torch.tensor([padded_num_tokens]).to(seq_lens.device).to(
                    seq_lens.dtype)
            ])
            block_table_padding = torch.zeros(
                (padded_num_tokens, ) + block_table.shape[1:],
                dtype=block_table.dtype,
                device=block_table.device)
            block_table = torch.cat([block_table, block_table_padding], dim=0)
            query_start_loc_cpu = torch.cat([
                query_start_loc_cpu,
                torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
                    query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
            ])

        query_start_loc = query_start_loc_cpu.pin_memory().to(
            self.device, non_blocking=True)

        attn_metadata = AscendMetadata(
            num_actual_tokens=num_actual_tokens,
            num_decode_tokens=num_decode_tokens,
            num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
            block_tables=block_table,
            query_start_loc=query_start_loc,
            query_start_loc_list=query_start_loc_cpu[1:].tolist(),
            query_lens=query_lens,
            seq_lens=seq_lens,
            seq_lens_list=seq_lens.tolist(),
            max_query_len=common_attn_metadata.max_query_len,
            actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            attn_state=attn_state,
            num_prefills=num_prefills,
            num_decodes=num_decodes)
        return attn_metadata
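
    # Shape convention (informal example): for query lengths [3, 1, 1],
    # query_start_loc_cpu is [0, 3, 4, 5]; query_lens is the difference of
    # consecutive entries, and actual_seq_lengths_q / query_start_loc_list
    # are the cumulative tail [3, 4, 5] as plain Python lists.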

    def build_for_graph_capture(
        self,
        common_attn_metadata: AscendCommonAttentionMetadata,
        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
        model: Optional[nn.Module] = None,
    ):
        if attn_state == AscendAttentionState.DecodeOnly:
            attn_metadata = self.build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise NotImplementedError(
                "Currently we only support building dummy metadata for DecodeOnly state"
            )

        attn_metadata.attn_state = attn_state
        return attn_metadata


class AscendAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        **kwargs,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.key_cache = None
        self.value_cache = None
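
    # key_cache / value_cache start as None and are bound lazily to the
    # layer's KV-cache tensors on the first reshape_and_cache() call below.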

    def full_graph_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ) -> Tuple[torch.Tensor, int]:
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            block_size = 128
            block_table = None
            actual_seq_lengths_kv = attn_metadata.query_start_loc_list
        elif attn_metadata.attn_state == \
                AscendAttentionState.PrefillCacheHit:
            batch_size = attn_metadata.query_lens.shape[0]
            block_table = attn_metadata.block_tables[:batch_size, :]
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1)
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1)
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1)
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1)
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        # Normal V1 situation.
        else:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1)
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1)
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list

        num_tokens = attn_metadata.query_start_loc_list[-1]
        graph_params = get_graph_params()
        query_start_loc = attn_metadata.query_start_loc_list
        # Prepare tensors for attention output
        # TODO: Refactor this to step-level instead of layer-level

        # Get workspace from cache or calculate it if not present.
        workspace = graph_params.workspaces.get(num_tokens)
        softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
        if workspace is None:
            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                query=query,
                key=key,
                value=value,
                atten_mask=attn_metadata.attn_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=query_start_loc,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                sparse_mode=3,
                scale=self.scale,
            )
            update_graph_params_workspaces(num_tokens, workspace)

        # Handle graph capturing mode
        stream = torch_npu.npu.current_stream()

        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        graph_params.events[num_tokens].append(event)
        graph_params.attn_params[num_tokens].append(
            (weak_ref_tensors(query), weak_ref_tensors(key),
             weak_ref_tensors(value), weak_ref_tensors(block_table),
             weak_ref_tensors(attn_metadata.attn_mask), block_size,
             actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
             self.num_heads, self.scale, weak_ref_tensors(output),
             weak_ref_tensors(softmax_lse)))

        torch.npu.graph_task_group_begin(stream)
        torch_npu.npu_fused_infer_attention_score.out(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=query_start_loc,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
            workspace=workspace,
            out=[output, softmax_lse],
        )

        output = output.view(num_tokens, self.num_heads, self.head_size)

        handle = torch.npu.graph_task_group_end(stream)
        graph_params.handles[num_tokens].append(handle)
        return output, num_tokens
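
    # Capture/replay sketch: the fused kernel above is recorded inside an ACL
    # graph task group, and weak references to its tensors plus the returned
    # handle and event are stashed in graph_params keyed by num_tokens, so a
    # later replay with the same padded token count can update the kernel's
    # parameters without re-capturing the graph.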

    def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
                         value: torch.Tensor, attn_metadata: AscendMetadata,
                         output: torch.Tensor):
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            block_size = 128
            block_table = None
            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
        elif attn_metadata.attn_state == \
                AscendAttentionState.PrefillCacheHit:
            batch_size = attn_metadata.query_lens.shape[0]
            block_table = attn_metadata.block_tables[:batch_size, :]
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1)
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1)
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        # chunked_prefill.
        else:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1)
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1)
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list

        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        query = query[:num_tokens]
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )

        attn_output = attn_output.view(num_tokens, self.num_heads,
                                       self.head_size)
        output[:num_tokens] = attn_output[:num_tokens]
        return output

    def _forward_decode_only(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.sliding_window is not None and attn_metadata.seq_lens.shape[
                0] == query.size(0):
            batch_size = attn_metadata.seq_lens.shape[0]
            block_size = 128
            query = query.view(batch_size, 1, self.num_heads * self.head_size)
            key = self.key_cache
            value = self.value_cache
            if self.key_cache is not None and self.value_cache is not None:
                block_size = self.key_cache.shape[1]
                key = self.key_cache.flatten(2, 3).contiguous()
                value = self.value_cache.flatten(2, 3).contiguous()

            output, _ = torch_npu.npu_fused_infer_attention_score(
                query,
                key,
                value,
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout="BSH",
                block_size=block_size,
                pre_tokens=self.sliding_window,
                scale=self.scale,
                block_table=attn_metadata.block_tables,
                actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
                actual_seq_lengths_kv=attn_metadata.seq_lens)

            output = output.view(batch_size, self.num_heads, self.head_size)
        else:
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output)
        return output
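
    # The sliding-window branch only fires for batches of single-token decodes
    # (one query token per request); `pre_tokens` then bounds how far back each
    # query may attend. Everything else goes through the paged-attention
    # kernel.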

    def _forward_encode(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
        output = torch_npu.npu_fusion_attention(
            query,
            key,
            value,
            head_num=self.num_heads,
            input_layout="TND",
            scale=self.scale,
            sparse_mode=4,
            atten_mask=attn_metadata.attn_mask,
            pre_tockens=attn_metadata.max_query_len,
            next_tockens=attn_metadata.max_query_len,
            actual_seq_qlen=cum_seq_len,
            actual_seq_kvlen=cum_seq_len,
        )[0]
        return output
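
    # Note: `pre_tockens` / `next_tockens` (sic) are the parameter names that
    # torch_npu.npu_fusion_attention actually exposes; symmetric bounds of
    # max_query_len give each token visibility over its whole sequence, as
    # expected for encoder-only attention.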

    def reshape_and_cache(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
    ):
        if len(kv_cache) > 1:
            if self.key_cache is None:
                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping
            torch_npu._npu_reshape_and_cache(
                key=key[:attn_metadata.num_actual_tokens],
                value=value[:attn_metadata.num_actual_tokens],
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                slot_indices=slots)
        return key, value

    def forward_impl(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        forward_context: ForwardContext = get_forward_context()
        if not forward_context.capturing:
            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
                output = self._forward_decode_only(query, attn_metadata,
                                                   output)
            else:
                output = self._forward_prefill(query, key, value,
                                               attn_metadata, output)
        else:
            attn_output, num_tokens = self.full_graph_attention(
                query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]

        return output
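
    # Dispatch summary: in eager mode DecodeOnly batches take the decode path
    # and everything else the prefill path; during ACL-graph capture all
    # batches go through full_graph_attention so the kernel can be recorded
    # for later replay.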

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for AscendAttentionBackendImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        if self.attn_type not in (AttentionType.DECODER,
                                  AttentionType.ENCODER_ONLY):
            raise NotImplementedError("Encoder/decoder cross-attention "
                                      "is not implemented for "
                                      "AscendAttentionBackendImpl")
        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.fill_(0)
        key, value = self.reshape_and_cache(key, value, kv_cache,
                                            attn_metadata)
        if self.attn_type == AttentionType.ENCODER_ONLY:
            attn_output = self._forward_encode(query, key, value,
                                               attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output
        output = self.forward_impl(query, key, value, kv_cache, attn_metadata,
                                   output)
        return output