#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from dataclasses import dataclass
from enum import Enum

import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (  # type: ignore
    AttentionBackend,
    AttentionCGSupport,
    AttentionImpl,
    AttentionLayer,
    AttentionMetadataBuilder,
    AttentionType,
)
from vllm.v1.attention.backends.registry import (  # type: ignore
    AttentionBackendEnum,
    register_backend,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec

from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
from vllm_ascend.attention.utils import (
    AscendCommonAttentionMetadata,
    enable_cp,
    split_decodes_and_prefills,
    using_paged_attention,
)
from vllm_ascend.compilation.acl_graph import (
    get_draft_graph_params,
    get_graph_params,
    update_draft_graph_params_workspaces,
    update_graph_params_workspaces,
)
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import weak_ref_tensors

# default max value of sliding window size
SWA_INT_MAX = 2147483647
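# Note: 2147483647 is INT32_MAX (2**31 - 1). It is passed as `pre_tokens` to the
# fused attention kernels below when no sliding-window limit applies, so the
# kernel may attend to the entire preceding context.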


@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
class AscendAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        # HACK(Ronald1995): the vllm `initialize_kv_cache` method in model runner v2
        # asserts on the attention backend name, so we report FLASH_ATTN there to
        # avoid the assertion error. Rectify this once vllm drops the assertion.
        return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["AscendAttentionBackendImpl"]:
        if enable_cp():
            from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPImpl

            return AscendAttentionCPImpl
        return AscendAttentionBackendImpl

    @staticmethod
    def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
        if enable_cp():
            from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPMetadataBuilder

            return AscendAttentionCPMetadataBuilder
        return AscendAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads, head_size)
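
    # Illustration of the layout above (not executed): with num_blocks=1024,
    # block_size=128, num_kv_heads=8 and head_size=128 the allocated cache tensor
    # has shape (2, 1024, 128, 8, 128); index 0 of the leading dimension is the key
    # cache and index 1 the value cache. swap_blocks/copy_blocks below index into
    # exactly this layout.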

    @staticmethod
    def swap_blocks(
        src_kv_cache: list[torch.Tensor],
        dst_kv_cache: list[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]

        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(dst_key_cache.device)
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(dst_key_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]

        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int]:
        return [128]
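
    # 128 is the only kernel block size advertised by this backend; the metadata
    # builder below also uses it (via get_supported_kernel_block_sizes()[0]) to
    # derive max_num_blocks_per_req from max_model_len.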


class AscendAttentionState(Enum):
    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3
    SpecDecoding = 4
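
    # Roughly: PrefillNoCache = prefill with nothing in the KV cache yet,
    # PrefillCacheHit = prefill that reuses already-cached prefix blocks,
    # DecodeOnly = every request contributes exactly one new token,
    # ChunkedPrefill = mixed or partial prefill batches (the default state in
    # AscendMetadata), SpecDecoding = decode with speculative draft tokens.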


@dataclass
class AscendMetadata:
    """
    Per-layer attention metadata for the Ascend FlashAttention backend.

    Contains attention masks, token counts, sequence lengths and KV-cache
    related properties for attention computation.
    """

    # **************************** Basic Properties ************************** #
    attn_mask: torch.Tensor | None = None
    # Current state of this attention run.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    # Number of tokens excluding padding.
    num_actual_tokens_pcp_padded: int = 0
    num_actual_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0
    num_decodes_flatten: int = 0

    # The sequence length per sequence. Sequence length means the computed
    # tokens + new tokens (None if it is a decode).
    # (batch_size,)
    # TODO(Angazenn): The following parameters are quite redundant and
    # contain similar information (such as seq_lens and seq_lens_list). We
    # should simplify these parameters once the attention schema in vLLM-Ascend
    # is unified.
    seq_lens: torch.Tensor = None
    seq_lens_cpu: torch.Tensor = None
    seq_lens_list: list[int] = None  # type: ignore
    actual_seq_lengths_q: list[int] = None  # type: ignore

    query_start_loc: torch.Tensor = None
    # Maximum query length in the batch (None for decoding).
    max_query_len: int | None = None

    # ********************** KV Cache Related Properties ********************* #
    # Block addresses per sequence (Seq id -> list of physical block).
    # (batch_size, max_blocks_per_seq)
    block_tables: torch.Tensor = None

    # The indices of the token slots that input tokens will be stored into.
    # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
    # three tokens are stored at slot index 3 of block 2, slot index 2 of
    # block 0, and slot index 1 of block 1 (0-based), respectively.
    # (num_tokens,)
    slot_mapping: torch.Tensor = None
    # pcp
    prefill: AscendMetadataForPrefill | None = None
    # dcp
    decode_meta: AscendMetadataForDecode | None = None

    causal: bool = True

    # runner_type in model_config.
    model_runner_type: str = ""
    # prefill reshape_and_cache event
    reshape_cache_event: torch.npu.Event = None

    # sliding window attention mask
    swa_mask: torch.Tensor | None = None
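
    # Relationship between the query-length fields (see build() below):
    # actual_seq_lengths_q is query_start_loc_cpu[1:], i.e. the cumulative query
    # lengths. For per-request query lengths [3, 1, 1], query_start_loc is
    # [0, 3, 4, 5] and actual_seq_lengths_q is [3, 4, 5].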


class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
    """
    Builder for constructing AscendMetadata from CommonAttentionMetadata.

    Handles attention mask generation and metadata preparation for the
    Ascend FlashAttention backend.
    """

    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device
        self.max_num_blocks_per_req = cdiv(
            self.model_config.max_model_len, AscendAttentionBackend.get_supported_kernel_block_sizes()[0]
        )

        self.speculative_config = vllm_config.speculative_config
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            assert self.decode_threshold <= 16, (
                f"decode_threshold exceeded npu_fused_infer_attention_score "
                f"TND layout's limit of 16, got {self.decode_threshold}"
            )

        self.reorder_batch_threshold = self.decode_threshold
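
        # Example: with num_speculative_tokens=3 the decode_threshold becomes 4, so a
        # request scheduling up to 4 new tokens per step is still classified as a
        # decode by split_decodes_and_prefills() in build(); the TND layout of
        # npu_fused_infer_attention_score caps this threshold at 16.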

        scheduler_config = vllm_config.scheduler_config
        self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
        self.attn_mask_builder = AttentionMaskBuilder(self.device)

    @classmethod
    def get_cudagraph_support(
        cls: type["AscendAttentionMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        # Explicit override in case the underlying builder specialized this getter.
        # @override omitted only because of a mypy limitation due to the type variable.
        return AttentionCGSupport.ALWAYS

    def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool:
        return False

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_decodes_and_prefills(
            common_attn_metadata, decode_threshold=self.decode_threshold
        )

        block_table = common_attn_metadata.block_table_tensor
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
        # This slot_mapping override doesn't work since vllm will override it again.
        # We should fix it in vllm.
        # see: https://github.com/vllm-project/vllm/blob/ce88756b967c2c5006746a424c15dd59a284ed8c/vllm/model_executor/layers/attention/cross_attention.py#L117
        if isinstance(self.kv_cache_spec, CrossAttentionSpec):
            seq_lens = common_attn_metadata.seq_lens
            slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32)
        elif self.speculative_config and self.speculative_config.parallel_drafting:
            seq_lens = common_attn_metadata.seq_lens

        attn_state = common_attn_metadata.attn_state

        # Get attn_mask and swa_mask from the singleton AttentionMaskBuilder
        attn_mask = self.attn_mask_builder.get_attention_mask(self.model_config)

        swa_mask = None
        is_swa = hasattr(self.model_config.hf_text_config, "sliding_window")
        if self.model_config is not None and is_swa:
            swa_mask = self.attn_mask_builder.get_swa_mask(
                self.model_config.dtype, self.model_config.hf_text_config.sliding_window
            )

        # TODO: Yet another unnecessary H2D copy while we already have a query_start_loc on device
        query_start_loc = query_start_loc_cpu.pin_memory().to(self.device, non_blocking=True)

        attn_metadata = AscendMetadata(
            num_actual_tokens=num_actual_tokens,
            num_decode_tokens=num_decode_tokens,
            block_tables=block_table,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens,
            seq_lens_list=seq_lens.tolist(),
            max_query_len=common_attn_metadata.max_query_len,
            actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            swa_mask=swa_mask,
            attn_state=attn_state,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            causal=common_attn_metadata.causal,
            model_runner_type=self.model_config.runner_type,
        )
        return attn_metadata

    def build_for_graph_capture(
        self,
        common_attn_metadata: AscendCommonAttentionMetadata,
        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
    ):
        if attn_state in (
            AscendAttentionState.DecodeOnly,
            AscendAttentionState.ChunkedPrefill,
            AscendAttentionState.SpecDecoding,
        ):
            attn_metadata = self.build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise NotImplementedError(
                "Currently we only support building dummy metadata for the "
                "DecodeOnly, ChunkedPrefill and SpecDecoding states"
            )

        attn_metadata.attn_state = attn_state
        return attn_metadata
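
    # build_for_graph_capture() produces placeholder ("dummy") metadata for ACL graph
    # capture; the values only need to be structurally valid, since
    # update_graph_params() in AscendAttentionBackendImpl re-issues the captured
    # kernels with the real seq_lens/block_tables before every replay.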


class AscendAttentionBackendImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        sinks: torch.Tensor = None,
        **kwargs,
    ) -> None:
        self.vllm_config = get_current_vllm_config()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32, device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.key_cache = None
        self.value_cache = None
        self.is_kv_producer = (
            self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
        )
        self.sinks = sinks

    @staticmethod
    def update_graph_params(
        update_stream,
        forward_context,
        num_tokens,
        vllm_config,
        speculative_config=None,
        num_dcp_pcp_tokens=None,
        draft_attn_metadatas=None,
    ):
        if using_paged_attention(num_tokens, vllm_config):
            # Paged Attention update logic
            if _EXTRA_CTX.is_draft_model:
                graph_params = get_draft_graph_params()
            else:
                graph_params = get_graph_params()
            with torch.npu.stream(update_stream):
                for key, param, handle, event in zip(
                    forward_context.attn_metadata,
                    graph_params.attn_params[num_tokens],
                    graph_params.handles[num_tokens],
                    graph_params.events[num_tokens],
                ):
                    (
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        num_heads,
                        scale,
                        block_table,
                        seq_lens,
                        output,
                    ) = param
                    seq_lens = forward_context.attn_metadata[key].seq_lens

                    workspace = torch_npu._npu_paged_attention_get_workspace(
                        query=query,
                        key_cache=key_cache,
                        value_cache=value_cache,
                        num_kv_heads=num_kv_heads,
                        num_heads=num_heads,
                        scale_value=scale,
                        block_table=block_table,
                        context_lens=seq_lens,
                        out=output,
                    )
                    torch.npu.graph_task_update_begin(update_stream, handle)
                    torch_npu._npu_paged_attention(
                        query=query,
                        key_cache=key_cache,
                        value_cache=value_cache,
                        num_kv_heads=num_kv_heads,
                        num_heads=num_heads,
                        scale_value=scale,
                        block_table=block_table,
                        context_lens=seq_lens,
                        out=output,
                        workspace=workspace,
                    )
                    torch.npu.graph_task_update_end(update_stream)
                    event.record(update_stream)
        else:
            # FIA update logic
            if _EXTRA_CTX.is_draft_model:
                graph_params = get_draft_graph_params()
                attn_metadata = draft_attn_metadatas
                attn_keys = list(attn_metadata[0].keys())
            else:
                graph_params = get_graph_params()
                attn_metadata = forward_context.attn_metadata
                attn_keys = list(attn_metadata.keys())
            # For Qwen3-next, since the kv_cache_config has already categorized
            # linear_attn and self_attn, the attn_metadata is arranged with
            # self_attn first, followed by linear_attn. Therefore, using zip directly
            # filters out the update operations for linear_attn.
            # TODO: We use a new variable `attn_keys` to keep the loop count correct
            # after zipping, because of the new structure of the attn_metadata when
            # running with the merged full eagle-graph. Should check it with Qwen3-next.
            num_layers = len(attn_keys)
            if num_layers == 0:
                return
            if _EXTRA_CTX.is_draft_model:
                attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
            attn_count = 0
            with torch.npu.stream(update_stream):
                for key, param, handle, event in zip(
                    attn_keys,
                    graph_params.attn_params[num_tokens],
                    graph_params.handles[num_tokens],
                    graph_params.events[num_tokens],
                ):
                    (
                        query,
                        key_cache,
                        value,
                        block_tables,
                        attn_mask,
                        block_size,
                        seq_lens,
                        query_start_loc,
                        num_kv_heads,
                        num_heads,
                        scale,
                        attn_output,
                        softmax_lse,
                    ) = param

                    if _EXTRA_CTX.is_draft_model:
                        draft_step = attn_count // num_layers
                        seq_lens = attn_metadata[draft_step][key].seq_lens_list
                        actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
                        block_tables = attn_metadata[draft_step][key].block_tables
                        attn_count = attn_count + 1
                    else:
                        seq_lens = attn_metadata[key].seq_lens_list
                        actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
                        block_tables = attn_metadata[key].block_tables

                    torch.npu.graph_task_update_begin(update_stream, handle)
                    torch_npu.npu_fused_infer_attention_score.out(
                        query=query,
                        key=key_cache,
                        value=value,
                        block_table=block_tables,
                        atten_mask=attn_mask,
                        input_layout="TND",
                        block_size=block_size,
                        actual_seq_lengths=actual_seq_lengths_q,
                        actual_seq_lengths_kv=seq_lens,
                        num_key_value_heads=num_kv_heads,
                        num_heads=num_heads,
                        scale=scale,
                        sparse_mode=3,
                        workspace=graph_params.workspaces.get(num_tokens),
                        out=[attn_output, softmax_lse],
                    )
                    torch.npu.graph_task_update_end(update_stream)

                    event.record(update_stream)
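
    # Note on update_graph_params(): a captured ACL graph cannot read new Python-side
    # metadata at replay time, so each recorded kernel is re-issued on the side
    # `update_stream` between graph_task_update_begin/end with the current
    # seq_lens/block_tables, and an event is recorded that the captured graph waits on
    # (see the ExternalEvent setup in full_graph_fia / full_graph_pa below).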

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        super().process_weights_after_loading(act_dtype)
        if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
            flashcomm2_oshard_manager.post_process_after_loading()

    def full_graph_fia(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)

        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        if _EXTRA_CTX.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
        actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
        # Prepare tensors for attention output
        # TODO: Refactor this to step-level instead of layer-level

        # Get workspace from cache or calculate it if not present.
        workspace = graph_params.workspaces.get(num_tokens)
        softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
        if workspace is None:
            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                query=query,
                key=key,
                value=value,
                atten_mask=attn_metadata.attn_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=actual_seq_lengths_q,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                sparse_mode=3,
                scale=self.scale,
            )
            if _EXTRA_CTX.is_draft_model:
                update_draft_graph_params_workspaces(num_tokens, workspace)
            else:
                update_graph_params_workspaces(num_tokens, workspace)

        # Handle graph capturing mode
        stream = torch_npu.npu.current_stream()

        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        graph_params.events[num_tokens].append(event)
        graph_params.attn_params[num_tokens].append(
            (
                weak_ref_tensors(query),
                weak_ref_tensors(key),
                weak_ref_tensors(value),
                weak_ref_tensors(block_table),
                weak_ref_tensors(attn_metadata.attn_mask),
                block_size,
                actual_seq_lengths_kv,
                actual_seq_lengths_q,
                self.num_kv_heads,
                self.num_heads,
                self.scale,
                weak_ref_tensors(output),
                weak_ref_tensors(softmax_lse),
            )
        )

        torch.npu.graph_task_group_begin(stream)
        torch_npu.npu_fused_infer_attention_score.out(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
            workspace=workspace,
            out=[output, softmax_lse],
        )

        output = output.view(num_tokens, self.num_heads, self.head_size)

        handle = torch.npu.graph_task_group_end(stream)
        graph_params.handles[num_tokens].append(handle)
        return output, num_tokens
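
    # full_graph_fia (above) and full_graph_pa (below) run only while an ACL graph is
    # being captured: besides launching the kernel inside graph_task_group_begin/end,
    # they stash weak tensor references, the task-group handle and an ExternalEvent
    # per (num_tokens, layer) entry so that update_graph_params() can patch the
    # captured launch with fresh metadata before every replay.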

    def full_graph_pa(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
    ):
        graph_params = get_graph_params()
        num_tokens = query.shape[0]
        if _EXTRA_CTX.capturing:
            # Get workspace from cache or calculate it if not present.
            workspace = graph_params.workspaces.get(num_tokens)
            if workspace is None:
                workspace = torch_npu._npu_paged_attention_get_workspace(
                    query=query,
                    key_cache=self.key_cache,
                    value_cache=self.value_cache,
                    num_kv_heads=self.num_kv_heads,
                    num_heads=self.num_heads,
                    scale_value=self.scale,
                    block_table=attn_metadata.block_tables,
                    context_lens=attn_metadata.seq_lens,
                    out=output,
                )
                update_graph_params_workspaces(num_tokens, workspace)

            # Handle graph capturing mode
            stream = torch_npu.npu.current_stream()

            event = torch.npu.ExternalEvent()
            event.wait(stream)
            event.reset(stream)
            graph_params.events[num_tokens].append(event)
            graph_params.attn_params[num_tokens].append(
                (
                    weak_ref_tensors(query),
                    weak_ref_tensors(self.key_cache),
                    weak_ref_tensors(self.value_cache),
                    self.num_kv_heads,
                    self.num_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens,
                    weak_ref_tensors(output),
                )
            )

            torch.npu.graph_task_group_begin(stream)
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output,
                workspace=workspace,
            )
            handle = torch.npu.graph_task_group_end(stream)
            graph_params.handles[num_tokens].append(handle)
        return output

    def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata):
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            block_size = 128
            block_table = None
            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
            if self.attn_type == AttentionType.ENCODER_DECODER:
                actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, dim=0).tolist()
        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
            batch_size = attn_metadata.seq_lens.shape[0]
            block_table = attn_metadata.block_tables[:batch_size, :]
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(num_block, block_size, -1)  # type: ignore
            value = self.value_cache.view(num_block, block_size, -1)  # type: ignore
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(num_block, block_size, -1)  # type: ignore
            value = self.value_cache.view(num_block, block_size, -1)  # type: ignore
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        # chunked prefill.
        else:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(num_block, block_size, -1)  # type: ignore
            value = self.value_cache.view(num_block, block_size, -1)  # type: ignore
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        return key, value, block_size, block_table, actual_seq_lengths_kv
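
    # Summary of the cases above: PrefillNoCache keeps the freshly projected K/V and
    # passes no block table, while every other state reads K/V from the paged caches
    # viewed as (num_blocks, block_size, num_kv_heads * head_size) and supplies the
    # block table plus the per-request KV lengths expected by
    # npu_fused_infer_attention_score.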

    def _forward_fia_slidingwindow(self, query: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor):
        batch_size = attn_metadata.seq_lens.shape[0]
        block_size = 128
        query = query.view(batch_size, 1, self.num_heads * self.head_size)
        key = self.key_cache
        value = self.value_cache
        if self.key_cache is not None and self.value_cache is not None:
            block_size = self.key_cache.shape[1]
            key = self.key_cache.flatten(2, 3).contiguous()
            value = self.value_cache.flatten(2, 3).contiguous()

        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query,
            key,
            value,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout="BSH",
            block_size=block_size,
            pre_tokens=self.sliding_window,
            scale=self.scale,
            block_table=attn_metadata.block_tables,
            actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
            actual_seq_lengths_kv=attn_metadata.seq_lens,
        )

        attn_output = attn_output.view(batch_size, self.num_heads, self.head_size)
        output[:batch_size] = attn_output[:batch_size]
        return output

    def forward_fused_infer_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        # We inherit ForwardContext in model runner v2; when model runner v2 is
        # enabled there is no `capturing` attribute on the forward context, so the
        # capture flag is read from _EXTRA_CTX instead to avoid an attribute error.
        if _EXTRA_CTX.capturing:
            attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output
        if (
            attn_metadata.attn_state == AscendAttentionState.DecodeOnly
            and self.sliding_window is not None
            and attn_metadata.seq_lens.shape[0] == query.size(0)
            and self.sinks is None
        ):
            return self._forward_fia_slidingwindow(query, attn_metadata, output)

        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        query = query[:num_tokens]
        if (
            attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
            and self.attn_type != AttentionType.ENCODER_DECODER
        ):
            key = key[:num_tokens]
            value = value[:num_tokens]
        if self.sinks is not None:
            actual_seq_qlen = attn_metadata.actual_seq_lengths_q
            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
                actual_seq_qlen = torch.tensor([1] * len(attn_metadata.seq_lens_list), dtype=torch.int32).cumsum(dim=0)
            if self.sliding_window is not None:
                atten_mask = attn_metadata.swa_mask
                sparse_mode = 4
            else:
                atten_mask = attn_metadata.attn_mask
                sparse_mode = 3
            attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(
                query,
                key,
                value,
                num_query_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout="TND",
                pre_tokens=self.sliding_window if self.sliding_window is not None else SWA_INT_MAX,
                next_tokens=0,
                atten_mask=atten_mask,
                sparse_mode=sparse_mode,
                softmax_scale=self.scale,
                block_table=block_table,
                block_size=block_size,
                actual_seq_qlen=actual_seq_qlen,
                actual_seq_kvlen=actual_seq_lengths_kv,
                learnable_sink=self.sinks,
            )
        else:
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                query=query,
                key=key,
                value=value,
                atten_mask=attn_metadata.attn_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=3,
            )

        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
        output[:num_tokens] = attn_output[:num_tokens]
        return output

    def forward_paged_attention(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if _EXTRA_CTX.capturing:
            return self.full_graph_pa(query, attn_metadata, output)
        torch_npu._npu_paged_attention(
            query=query,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.block_tables,
            context_lens=attn_metadata.seq_lens,
            out=output,
        )
        return output

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        _: torch.Tensor,
    ) -> torch.Tensor:
        # Use the default sparse_mode 0 in the normal scenario, which means no mask is applied.
        return torch_npu.npu_fusion_attention(
            query=query,
            key=key,
            value=value,
            head_num=self.num_heads,
            input_layout="TND",
            scale=self.scale,
            actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
            actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
        )[0]

    def reshape_and_cache(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        if len(kv_cache) > 1:
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event = torch.npu.Event()
            if self.key_cache is None:
                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping
            encoder_decoder = self.attn_type == AttentionType.ENCODER_DECODER
            DeviceOperator.reshape_and_cache(
                key=key[: attn_metadata.num_actual_tokens] if not encoder_decoder else key,
                value=value[: attn_metadata.num_actual_tokens] if not encoder_decoder else value,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                # Quick fix to make sure slots is int32 for the cross-attention case.
                # see: https://github.com/vllm-project/vllm/blob/ce88756b967c2c5006746a424c15dd59a284ed8c/vllm/model_executor/layers/attention/cross_attention.py#L117
                slot_mapping=slots[: attn_metadata.num_actual_tokens] if not encoder_decoder else slots.to(torch.int32),
            )
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event.record()
        return query, key, value, output
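
    # reshape_and_cache writes the freshly projected K/V for the first
    # num_actual_tokens tokens into the paged key/value caches at the slots given by
    # attn_metadata.slot_mapping (the full tensors are used for encoder-decoder
    # cross-attention). When this rank is a KV producer, an NPU event is recorded so
    # that the KV-transfer path can wait for the cache write to complete.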
|
2025-12-06 09:33:28 +08:00
|
|
|
|
|
|
|
|
def forward_impl(
|
|
|
|
|
self,
|
|
|
|
|
query: torch.Tensor,
|
|
|
|
|
key: torch.Tensor,
|
|
|
|
|
value: torch.Tensor,
|
2026-01-19 08:59:46 +08:00
|
|
|
kv_cache: tuple[torch.Tensor],
|
2025-12-06 09:33:28 +08:00
|
|
|
attn_metadata: AscendMetadata,
|
|
|
|
|
output: torch.Tensor,
|
|
|
|
|
):
|
2025-12-17 23:14:02 +08:00
|
|
|
num_tokens = query.shape[0]
|
2026-01-19 08:59:46 +08:00
|
|
|
if (
|
|
|
|
|
attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
|
|
|
|
and using_paged_attention(num_tokens, self.vllm_config)
|
|
|
|
|
and self.sliding_window is None
|
|
|
|
|
):
|
2025-12-17 23:14:02 +08:00
|
|
|
output = self.forward_paged_attention(query, attn_metadata, output)
|
2025-12-06 09:33:28 +08:00
|
|
|
else:
|
2026-01-19 08:59:46 +08:00
|
|
|
output = self.forward_fused_infer_attention(query, key, value, attn_metadata, output)
|
2025-12-06 09:33:28 +08:00
|
|
|
|
|
|
|
|
return output
|
|
|
|
|
|
2025-10-25 08:58:35 +08:00
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
layer: AttentionLayer,
|
|
|
|
|
query: torch.Tensor,
|
|
|
|
|
key: torch.Tensor,
|
|
|
|
|
value: torch.Tensor,
|
2026-01-19 08:59:46 +08:00
|
|
|
kv_cache: tuple[torch.Tensor],
|
2025-10-25 08:58:35 +08:00
|
|
|
attn_metadata: AscendMetadata,
|
2026-01-19 08:59:46 +08:00
|
|
|
output: torch.Tensor | None = None,
|
|
|
|
|
output_scale: torch.Tensor | None = None,
|
|
|
|
|
output_block_scale: torch.Tensor | None = None,
|
2025-10-25 08:58:35 +08:00
|
|
|
) -> torch.Tensor:
|
|
|
|
|
"""Forward pass with Ascend attention.
|
|
|
|
|
Args:
|
|
|
|
|
query: shape = [num_tokens, num_heads, head_size]
|
|
|
|
|
key: shape = [num_tokens, num_kv_heads, head_size]
|
|
|
|
|
value: shape = [num_tokens, num_kv_heads, head_size]
|
|
|
|
|
kv_cache: shape =
|
|
|
|
|
[2, num_blocks, block_size, num_kv_heads, head_size]
|
|
|
|
|
attn_metadata: Metadata for attention.
|
|
|
|
|
Returns:
|
|
|
|
|
shape = [num_tokens, num_heads * head_size]
|
|
|
|
|
"""
|
|
|
|
|
assert output is not None, "Output tensor must be provided."
|
|
|
|
|
|
|
|
|
|
if output_scale is not None or output_block_scale is not None:
|
2026-01-19 08:59:46 +08:00
|
|
|
raise NotImplementedError("fused output quantization is not yet supported for AscendAttentionBackendImpl")
|
2025-10-25 08:58:35 +08:00
|
|
|
|
|
|
|
|
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
2025-12-02 09:13:26 +08:00
|
|
|
num_tokens = query.shape[0]
|
|
|
|
|
if attn_metadata is None:
|
|
|
|
|
return output.fill_(0)
|
2026-02-28 21:44:08 +08:00
|
|
|
output_padded = None
|
2026-01-11 11:38:45 +08:00
|
|
|
if key is not None and value is not None:
|
2026-02-28 21:44:08 +08:00
|
|
|
output_padded = output
|
|
|
|
|
query, key, value, output_padded = self.reshape_and_cache(
|
|
|
|
|
query, key, value, kv_cache, attn_metadata, output
|
|
|
|
|
)
|
2025-12-10 11:37:57 +08:00
|
|
|

        # pooling model branch
        if attn_metadata.model_runner_type == "pooling" and not attn_metadata.causal:
            attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output

        if output_padded is not None:
            attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output_padded)
        else:
            attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
        output[:num_tokens] = attn_output[:num_tokens]
        return output


class AscendC8AttentionBackendImpl(AscendAttentionBackendImpl):
    """Attention backend implementation for INT8 KV cache (C8/QuaRot) models.

    This subclass handles static per-channel INT8 KV cache quantization.
    It is activated via class surgery in AscendC8KVCacheAttentionMethod.create_weights
    (vllm_ascend/quantization/methods/kv_c8.py)
    so that C8 attention layers automatically use this forward path.
    """
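
    # Per-channel C8 convention, as implied by the quantize/dequantize code below:
    #   int8 = clamp(round(x / scale + offset), -128, 127)
    #   x    ≈ (int8 - offset) * scale
    # with one scale/offset per (kv_head, head_dim) channel.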

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError("fused output quantization is not yet supported for AscendC8AttentionBackendImpl")

        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.fill_(0)
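
        # Prefill-style states keep a float copy of K/V (float_key/float_value) so
        # the prefill kernels can consume unquantized tensors, while the INT8 copies
        # are what actually get written into the paged KV cache.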
        float_key, float_value = None, None
        if key is not None and value is not None:
            if attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
                float_key, float_value = key, value
            key, value = self._quantize_kv_to_int8(key, value, layer, attn_metadata.num_actual_tokens)
            query, key, value, _ = self.reshape_and_cache(query, key, value, kv_cache, attn_metadata, output)

        if attn_metadata.model_runner_type == "pooling":
            attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output

        self._prepare_c8_scales(layer, query.device)
        if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            return self._forward_c8_decode(query, attn_metadata, output, layer)
        elif attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
            return self._forward_c8_chunked_prefill(query, float_key, float_value, attn_metadata, output, layer)
        else:
            return self._forward_c8_fused_infer_attention(
                query,
                float_key if float_key is not None else key,
                float_value if float_value is not None else value,
                attn_metadata,
                output,
                layer,
            )

    def _prepare_c8_scales(self, layer: AttentionLayer, device: torch.device) -> None:
        """Shard per-channel C8 scales/offsets to this TP rank and pre-compute
        BF16 BNSD antiquant tensors for FIA V1 decode fast path.
        """
        if hasattr(layer, "_c8_scales_prepared"):
            return
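
        # The loaded scale/offset tensors may be scalars, already-sharded
        # [num_kv_heads * head_size] vectors, or full [total_kv_heads * head_size]
        # vectors; in the last case this rank slices out its contiguous run of
        # num_kv_heads heads before reshaping to [1, num_kv_heads, head_size].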
        def _shard_and_reshape(raw: torch.Tensor) -> torch.Tensor:
            if raw.numel() == 1:
                return raw.to(device=device)
            expected = self.num_kv_heads * self.head_size
            if raw.numel() != expected:
                total_kv_heads = raw.numel() // self.head_size
                tp_rank = get_tensor_model_parallel_rank()
                tp_size = get_tensor_model_parallel_world_size()
                kv_head_start = tp_rank * total_kv_heads // tp_size
                raw = raw.view(total_kv_heads, self.head_size)[
                    kv_head_start : kv_head_start + self.num_kv_heads
                ].contiguous()
            return raw.view(1, self.num_kv_heads, self.head_size).to(device=device)

        layer._c8_k_scale = _shard_and_reshape(layer.k_cache_scale.data)
        layer._c8_k_offset = _shard_and_reshape(layer.k_cache_offset.data)
        layer._c8_v_scale = _shard_and_reshape(layer.v_cache_scale.data)
        layer._c8_v_offset = _shard_and_reshape(layer.v_cache_offset.data)
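
        # Pre-compute the BF16 antiquant tensors once per layer: BNSD-shaped
        # scale/offset for the FIA decode kernel, plus the inverse scale (1/scale)
        # used on the quantization side in _quantize_kv_to_int8.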
        bnsd = (1, self.num_kv_heads, 1, self.head_size)
        layer._c8_k_aq_scale = layer._c8_k_scale.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_k_aq_offset = layer._c8_k_offset.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_v_aq_scale = layer._c8_v_scale.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_v_aq_offset = layer._c8_v_offset.to(torch.bfloat16).view(bnsd).contiguous()

        layer._c8_k_inv_scale_bf16 = (1.0 / layer._c8_k_scale).to(torch.bfloat16)
        layer._c8_k_offset_bf16 = layer._c8_k_offset.to(torch.bfloat16)
        layer._c8_v_inv_scale_bf16 = (1.0 / layer._c8_v_scale).to(torch.bfloat16)
        layer._c8_v_offset_bf16 = layer._c8_v_offset.to(torch.bfloat16)

        layer._c8_scales_prepared = True

    def _dequant_paged_kv_to_dense(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: list,
        target_dtype: torch.dtype,
        layer,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Gather paged INT8 KV blocks and dequantize to target_dtype."""
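        # key/value are expected to be the paged caches viewed as
        # [num_blocks, block_size, H] with H = num_kv_heads * head_size. The gather
        # below pulls every block listed in block_table, then a per-sequence length
        # mask drops the padded tail tokens before the per-channel dequantization.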
        batch_size = block_table.shape[0]
        block_size = key.shape[1]
        H = key.shape[2]
        max_blocks_per_seq = block_table.shape[1]
        max_tokens_padded = max_blocks_per_seq * block_size

        flat_ids = block_table.reshape(-1)
        gathered_k = key[flat_ids].view(batch_size, max_tokens_padded, H)
        gathered_v = value[flat_ids].view(batch_size, max_tokens_padded, H)

        seq_lens_t = torch.tensor(seq_lens, dtype=torch.long, device=key.device)
        positions = torch.arange(max_tokens_padded, dtype=torch.long, device=key.device)
        valid_mask = (positions.unsqueeze(0) < seq_lens_t.unsqueeze(1)).view(-1)

        dense_k = gathered_k.view(-1, H)[valid_mask]
        dense_v = gathered_v.view(-1, H)[valid_mask]

        dense_k = dense_k.view(-1, self.num_kv_heads, self.head_size)
        dense_v = dense_v.view(-1, self.num_kv_heads, self.head_size)
        k_scale = layer._c8_k_scale.to(target_dtype)
        k_offset = layer._c8_k_offset.to(target_dtype)
        v_scale = layer._c8_v_scale.to(target_dtype)
        v_offset = layer._c8_v_offset.to(target_dtype)
        dense_k = (dense_k.to(target_dtype) - k_offset) * k_scale
        dense_v = (dense_v.to(target_dtype) - v_offset) * v_scale
        return dense_k, dense_v

    def _quantize_kv_to_int8(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        layer: AttentionLayer,
        num_actual_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize K/V from float to INT8 using static per-channel C8 scales."""
        self._prepare_c8_scales(layer, key.device)
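
        # Static per-channel quantization:
        #   int8 = clamp(round(x * (1 / scale) + offset), -128, 127)
        # Only the first num_actual_tokens rows are quantized; any padded tail is ignored.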
        actual_key = key[:num_actual_tokens]
        actual_value = value[:num_actual_tokens]

        k_int8 = torch.clamp(
            torch.round(actual_key * layer._c8_k_inv_scale_bf16 + layer._c8_k_offset_bf16),
            -128,
            127,
        ).to(torch.int8)
        v_int8 = torch.clamp(
            torch.round(actual_value * layer._c8_v_inv_scale_bf16 + layer._c8_v_offset_bf16),
            -128,
            127,
        ).to(torch.int8)
        return k_int8, v_int8

    def _forward_c8_decode(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        """C8 decode via FIA V1 BNSD with native paged INT8 KV + perchannel antiquant."""
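        # Decode handles one query token per sequence, so the query is viewed as
        # BNSD with S=1 and the INT8 paged caches are passed to the fused kernel
        # directly; the per-channel antiquant scale/offset let the kernel dequantize
        # on the fly.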
        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
        assert block_size % 32 == 0, f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
        key = self.key_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
        value = self.value_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
        batch_size = len(attn_metadata.seq_lens_list)

        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query[:batch_size].unsqueeze(2),
            key,
            value,
            key_antiquant_scale=layer._c8_k_aq_scale,
            key_antiquant_offset=layer._c8_k_aq_offset,
            value_antiquant_scale=layer._c8_v_aq_scale,
            value_antiquant_offset=layer._c8_v_aq_offset,
            block_table=attn_metadata.block_tables,
            actual_seq_lengths_kv=attn_metadata.seq_lens_list,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout="BNSD",
            scale=self.scale,
            block_size=block_size,
            key_antiquant_mode=0,
            value_antiquant_mode=0,
            sparse_mode=0,
        )
        attn_output = attn_output.squeeze(2)
        output[:batch_size] = attn_output
        return output

    def _forward_c8_chunked_prefill(
        self,
        query: torch.Tensor,
        float_key: torch.Tensor | None,
        float_value: torch.Tensor | None,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        """C8 ChunkedPrefill: decode via FIA V1 BNSD paged INT8 (zero gather),
        prefill via FIA V1 TND with float KV (new) or gather+dequant (continuing).
        """
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_decodes = attn_metadata.num_decodes
        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
        num_tokens = int(actual_seq_qlen[-1])  # type: ignore[index]
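
        # Chunked-prefill batches order decode tokens first, then prefill tokens.
        # Decode tokens reuse the INT8 paged-KV FIA path (no gather); the prefill
        # tokens are handled separately below.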
        if num_decode_tokens > 0:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
            assert block_size % 32 == 0, (
                f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
            )
            kv_k = self.key_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
            kv_v = self.value_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]

            attn_out, _ = torch_npu.npu_fused_infer_attention_score(
                query[:num_decode_tokens].unsqueeze(2),
                kv_k,
                kv_v,
                key_antiquant_scale=layer._c8_k_aq_scale,
                key_antiquant_offset=layer._c8_k_aq_offset,
                value_antiquant_scale=layer._c8_v_aq_scale,
                value_antiquant_offset=layer._c8_v_aq_offset,
                block_table=attn_metadata.block_tables[:num_decodes],
                actual_seq_lengths_kv=attn_metadata.seq_lens_list[:num_decodes],
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout="BNSD",
                scale=self.scale,
                block_size=block_size,
                key_antiquant_mode=0,
                value_antiquant_mode=0,
                sparse_mode=0,
            )
            output[:num_decode_tokens] = attn_out.squeeze(2)
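
        # Prefill part: if every prefill request is brand new (no cached prefix),
        # the float K/V that arrived with this step is used directly; otherwise the
        # cached INT8 prefix has to be gathered and dequantized first.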
        if attn_metadata.num_prefills > 0:
            prefill_q = query[num_decode_tokens:num_tokens]

            prefill_seq_qlen = [
                actual_seq_qlen[i] - num_decode_tokens for i in range(num_decodes, len(actual_seq_qlen))
            ]

            all_new_prefill = True
            for i in range(num_decodes, len(attn_metadata.seq_lens_list)):
                q_start = actual_seq_qlen[i - 1] if i > 0 else 0
                qlen_i = actual_seq_qlen[i] - q_start
                if attn_metadata.seq_lens_list[i] > qlen_i:
                    all_new_prefill = False
                    break

            if all_new_prefill and float_key is not None and float_value is not None:
                prefill_k = float_key[num_decode_tokens:num_tokens]
                prefill_v = float_value[num_decode_tokens:num_tokens]
                prefill_seq_kvlen = prefill_seq_qlen
            else:
                num_block, blk_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
                paged_k = self.key_cache.view(num_block, blk_size, -1)  # type: ignore[attr-defined]
                paged_v = self.value_cache.view(num_block, blk_size, -1)  # type: ignore[attr-defined]
                prefill_bt = attn_metadata.block_tables[num_decodes:]
                prefill_sl = attn_metadata.seq_lens_list[num_decodes:]
                prefill_k, prefill_v = self._dequant_paged_kv_to_dense(
                    paged_k, paged_v, prefill_bt, prefill_sl, query.dtype, layer
                )
                prefill_seq_kvlen = torch.tensor(prefill_sl, dtype=torch.int32).cumsum(dim=0)

            # block_table is None for prefill; FIA ignores block_size in this case.
            # Use cache block_size for consistency rather than a magic number.
            cache_block_size = self.key_cache.shape[1]  # type: ignore[attr-defined]
            attn_out, _ = torch_npu.npu_fused_infer_attention_score(
                query=prefill_q,
                key=prefill_k,
                value=prefill_v,
                atten_mask=attn_metadata.attn_mask,
                block_table=None,
                input_layout="TND",
                block_size=cache_block_size,
                actual_seq_lengths=prefill_seq_qlen,
                actual_seq_lengths_kv=prefill_seq_kvlen,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=3,
            )
            n_prefill = num_tokens - num_decode_tokens
            attn_out = attn_out.view(n_prefill, self.num_heads, self.head_size)
            output[num_decode_tokens:num_tokens] = attn_out[:n_prefill]

        return output

    def _forward_c8_fused_infer_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ):
        """C8 FIA V1 TND for prefill states (PrefillNoCache uses float KV directly,
        PrefillCacheHit gathers + dequants paged INT8 KV).
        """
        self._prepare_c8_scales(layer, query.device)
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)

        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
        num_tokens = int(actual_seq_qlen[-1])  # type: ignore[index]
        query = query[:num_tokens]

        if (
            attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
            and self.attn_type != AttentionType.ENCODER_DECODER
        ):
            key = key[:num_tokens]
            value = value[:num_tokens]
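
        # If _get_fia_params handed back INT8 tensors (i.e. the paged cache), they
        # must be dequantized before the TND kernel: either gather+dequant through
        # the block table, or a plain per-channel dequant when the KV is already dense.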
        if key.dtype == torch.int8:
            if block_table is not None:
                seq_lens = (
                    actual_seq_lengths_kv if isinstance(actual_seq_lengths_kv, list) else actual_seq_lengths_kv.tolist()
                )
                key, value = self._dequant_paged_kv_to_dense(key, value, block_table, seq_lens, query.dtype, layer)
                block_table = None
                # block_table is None after dequant; FIA ignores block_size.
                # Use cache block_size for consistency rather than a magic number.
                block_size = self.key_cache.shape[1]  # type: ignore[attr-defined]
                actual_seq_lengths_kv = torch.tensor(seq_lens, dtype=torch.int32).cumsum(dim=0)
            else:
                qdt = query.dtype
                k_scale = layer._c8_k_scale.to(qdt)
                k_offset = layer._c8_k_offset.to(qdt)
                v_scale = layer._c8_v_scale.to(qdt)
                v_offset = layer._c8_v_offset.to(qdt)
                key = (key.to(qdt) - k_offset) * k_scale
                value = (value.to(qdt) - v_offset) * v_scale

        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=actual_seq_qlen,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )
        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
        output[:num_tokens] = attn_output
        return output