xc-llm-ascend/vllm_ascend/attention/attention_v1.py
Mengqing Cao 044d4c3974 [v0.18.0]feat(quant): add C8 INT8 KV cache support for GQA attention models (#7474) (#8007)
backport of #7474

This PR adds C8 (INT8) KV cache quantization support for standard GQA
attention models (e.g., Qwen3-32B W8A8C8). C8 uses static per-channel
quantization scales to store KV cache in INT8, reducing KV cache memory
by ~50% compared to BF16, enabling higher batch concurrency and longer
context lengths on the same hardware.
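
As a rough illustration of the ~50% figure: the per-token KV cache footprint scales linearly with the element size, so moving from 2-byte BF16 to 1-byte INT8 halves it. The sketch below is a back-of-the-envelope calculation only. The layer count, KV head count, and head size are assumed example values rather than numbers taken from this PR, and the small static per-channel scale/offset tensors are ignored.

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_size: int, bytes_per_elem: int) -> int:
    """Bytes needed to cache one token's K and V across all layers."""
    # The factor 2 accounts for storing both K and V.
    return 2 * num_layers * num_kv_heads * head_size * bytes_per_elem


# Assumed example dimensions (hypothetical, for illustration only).
NUM_LAYERS, NUM_KV_HEADS, HEAD_SIZE = 64, 8, 128

bf16_bytes = kv_cache_bytes_per_token(NUM_LAYERS, NUM_KV_HEADS, HEAD_SIZE, bytes_per_elem=2)
int8_bytes = kv_cache_bytes_per_token(NUM_LAYERS, NUM_KV_HEADS, HEAD_SIZE, bytes_per_elem=1)
saving = 1 - int8_bytes / bf16_bytes
print(f"BF16: {bf16_bytes} B/token, INT8: {int8_bytes} B/token, saving: {saving:.0%}")
```

With tensor parallelism the per-rank footprint shrinks further, but the BF16-to-INT8 ratio stays the same.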

**Key changes:**

1. **`attention_v1.py`** — New `AscendC8AttentionBackendImpl` subclass
of `AscendAttentionBackendImpl`:
- `_prepare_c8_scales`: Shards per-channel scales/offsets to the current
TP rank and pre-computes BF16 BNSD-shaped antiquant tensors (one-time
per layer).
- `_quantize_kv_to_int8`: Quantizes BF16 K/V to INT8 before
`reshape_and_cache`, using pre-cached inverse scales.
- `_forward_c8_decode`: FIA V1 BNSD paged attention with native INT8 KV
and `perchannel` antiquant mode.
- `_forward_c8_chunked_prefill`: Splits decode (FIA V1 BNSD paged INT8)
and prefill (FIA V1 TND float) into two kernel calls.
- `_forward_c8_fused_infer_attention`: Handles `PrefillNoCache` and
`PrefillCacheHit` states.

2. **`quantization/methods/kv_c8.py`** — New
`AscendC8KVCacheAttentionMethod` scheme:
- Creates `k/v_cache_scale/offset` parameters via
`_c8_kv_scale_weight_loader`, which handles per-channel scale shapes and
lazy resizing.
- Sets `layer.kv_cache_torch_dtype = torch.int8` so
`get_kv_cache_spec()` returns INT8 dtype automatically.
- Upgrades `layer.impl` to `AscendC8AttentionBackendImpl` via class
surgery.

3. **`quantization/modelslim_config.py`** — C8 branch in
`get_quant_method()` activates when `kv_cache_type == "C8"` in
`quant_model_description.json`.

4. **`patch/worker/patch_qwen3_c8.py`** — Intercepts per-channel C8
scale/offset weights before `AutoWeightsLoader` discards them, routing
them to the parameters created by `AscendC8KVCacheAttentionMethod`.

5. **`tests/ut/quantization/test_kv_c8.py`** — Unit tests covering
`_c8_kv_scale_weight_loader`, `AscendC8KVCacheAttentionMethod`, and
`AscendC8AttentionBackendImpl` scale helpers.
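
The quantize/antiquant round trip described in item 1 can be sketched in plain Python. This is a minimal illustration of static per-channel INT8 quantization, not the code from `attention_v1.py`: the function names are hypothetical, and the offset convention (`q = round(x / scale + offset)`, `x ≈ (q - offset) * scale`) is an assumption that may differ from what the NPU kernels expect.

```python
def quantize_per_channel(values, scales, offsets):
    """Quantize one token's channel vector to INT8 with static per-channel parameters."""
    # Inverse scales would be computed once per layer and cached in practice.
    inv_scales = [1.0 / s for s in scales]
    quantized = []
    for x, inv_s, off in zip(values, inv_scales, offsets):
        q = round(x * inv_s + off)
        quantized.append(max(-128, min(127, q)))  # clamp to the INT8 range
    return quantized


def dequantize_per_channel(qvalues, scales, offsets):
    """Antiquant step: map INT8 values back to floats (assumed convention)."""
    return [(q - off) * s for q, s, off in zip(qvalues, scales, offsets)]


# Hypothetical per-channel parameters for a 3-channel vector.
scales = [0.01, 0.02, 0.05]
offsets = [0, 0, 0]
values = [0.5, -1.0, 100.0]  # the last value is out of range and saturates

q = quantize_per_channel(values, scales, offsets)
recon = dequantize_per_channel(q, scales, offsets)
print(q, recon)
```

In-range values are recovered to within half a quantization step, while out-of-range values saturate at the INT8 limits. This is why the static scales need to be calibrated to cover the expected K/V value range.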

User-facing change: users can now serve Qwen3-32B W8A8C8 quantized models with INT8 KV
cache on Ascend NPU. The model checkpoint must contain a
`quant_model_description.json` with `"kv_cache_type": "C8"` and
per-channel scale/offset tensors in safetensors.

No changes to the serving CLI — the feature activates automatically when
the quantization config is detected.
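
A minimal sketch of the activation check, assuming the description file is a flat JSON object. The helper name `uses_c8_kv_cache` is hypothetical and is not the actual `get_quant_method()` logic; only the `kv_cache_type` key and the `"C8"` value come from the description above.

```python
import json


def uses_c8_kv_cache(description: dict) -> bool:
    """Return True when the quantization description requests a C8 (INT8) KV cache."""
    return description.get("kv_cache_type") == "C8"


# Contents of a hypothetical quant_model_description.json, reduced to the relevant key.
c8_description = json.loads('{"kv_cache_type": "C8"}')
plain_description = json.loads("{}")

print(uses_c8_kv_cache(c8_description), uses_c8_kv_cache(plain_description))
```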

Benchmarked with `vllm serve` (TP=8, `max_num_seqs=256`,
`max_model_len=131072`, `enable_chunked_prefill=true`) + `random_bench`
(input_len=10240, output_len=2048, 960 prompts, max_concurrency=192):

```
============ Serving Benchmark Result ============
Successful requests:                     960
Failed requests:                         0
Maximum request concurrency:             192
Benchmark duration (s):                  1359.81
Total input tokens:                      9830400
Total generated tokens:                  1966080
Request throughput (req/s):              0.71
Output token throughput (tok/s):         1445.85
Peak output token throughput (tok/s):    2304.00
Total token throughput (tok/s):          8675.12
---------------Time to First Token----------------
Mean TTFT (ms):                          24598.51
Median TTFT (ms):                        23167.02
P50 TTFT (ms):                           23167.02
P90 TTFT (ms):                           47717.08
P99 TTFT (ms):                           84402.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          120.76
Median TPOT (ms):                        121.50
P50 TPOT (ms):                           121.50
P90 TPOT (ms):                           127.05
P99 TPOT (ms):                           130.13
---------------Inter-token Latency----------------
Mean ITL (ms):                           120.70
Median ITL (ms):                         90.34
P50 ITL (ms):                            90.34
P90 ITL (ms):                            93.79
P99 ITL (ms):                            101.80
==================================================
```
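
The headline throughput figures can be cross-checked from the token totals and the benchmark duration reported above. The short script below only re-derives them. Small differences in the last digit are expected because the reported duration is rounded to two decimals.

```python
# Values copied from the benchmark report above.
duration_s = 1359.81
total_input_tokens = 9_830_400
total_generated_tokens = 1_966_080
num_requests = 960

request_throughput = num_requests / duration_s
output_throughput = total_generated_tokens / duration_s
total_throughput = (total_input_tokens + total_generated_tokens) / duration_s

# Per-request averages should match the configured input_len and output_len.
avg_input_len = total_input_tokens / num_requests
avg_output_len = total_generated_tokens / num_requests

print(f"req/s: {request_throughput:.2f}")
print(f"output tok/s: {output_throughput:.2f}")
print(f"total tok/s: {total_throughput:.2f}")
print(f"avg input/output length: {avg_input_len:.0f}/{avg_output_len:.0f}")
```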

All attention states verified: `PrefillNoCache`, `PrefillCacheHit`,
`ChunkedPrefill`, `DecodeOnly`.

- vLLM version: v0.17.0
- vLLM main: 8b6325758c

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: LICO67373 <110013619+LICO1314@users.noreply.github.com>
2026-04-08 10:51:58 +08:00

1343 lines
55 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from dataclasses import dataclass
from enum import Enum

import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (  # type: ignore
    AttentionBackend,
    AttentionCGSupport,
    AttentionImpl,
    AttentionLayer,
    AttentionMetadataBuilder,
    AttentionType,
)
from vllm.v1.attention.backends.registry import (  # type: ignore
    AttentionBackendEnum,
    register_backend,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec

from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
from vllm_ascend.attention.utils import (
    AscendCommonAttentionMetadata,
    enable_cp,
    split_decodes_and_prefills,
    using_paged_attention,
)
from vllm_ascend.compilation.acl_graph import (
    get_draft_graph_params,
    get_graph_params,
    update_draft_graph_params_workspaces,
    update_graph_params_workspaces,
)
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import weak_ref_tensors

# default max value of sliding window size
SWA_INT_MAX = 2147483647


@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
class AscendAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        # HACK(Ronald1995): the `initialize_kv_cache` method in vLLM's model runner v2
        # asserts on the attention backend name, so we report FLASH_ATTN to avoid the
        # assertion error. Rectify this once vLLM removes the assertion.
        return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["AscendAttentionBackendImpl"]:
        if enable_cp():
            from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPImpl

            return AscendAttentionCPImpl
        return AscendAttentionBackendImpl

    @staticmethod
    def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
        if enable_cp():
            from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPMetadataBuilder

            return AscendAttentionCPMetadataBuilder
        return AscendAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: list[torch.Tensor],
        dst_kv_cache: list[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]
        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(dst_key_cache.device)
        # Move values to the value cache's own device rather than the key cache's.
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(dst_value_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]
        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int]:
        return [128]


class AscendAttentionState(Enum):
    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3
    SpecDecoding = 4


@dataclass
class AscendMetadata:
    """
    Per-layer attention metadata for Ascend FlashAttention backend.
    Contains attention masks, token counts, sequence lengths and KV cache
    related properties for attention computation.
    """

    # **************************** Basic Properties ************************** #
    attn_mask: torch.Tensor | None = None
    # Current state of this attention run.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
    # Number of tokens excluding padding.
    num_actual_tokens_pcp_padded: int = 0
    num_actual_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0
    num_decodes_flatten: int = 0
    # The sequence length per sequence. Sequence length means the computed
    # tokens + new tokens (None when decoding).
    # (batch_size,)
    # TODO(Angazenn): The following parameters are largely redundant and
    # contain similar information (such as seq_lens and seq_lens_list). We
    # should simplify these parameters once the attention schema in
    # vLLM-Ascend is unified.
    seq_lens: torch.Tensor = None
    seq_lens_cpu: torch.Tensor = None
    seq_lens_list: list[int] = None  # type: ignore
    actual_seq_lengths_q: list[int] = None  # type: ignore
    query_start_loc: torch.Tensor = None
    # Maximum query length in the batch (None for decoding).
    max_query_len: int | None = None
    # ********************** KV Cache Related Properties ********************* #
    # Block addresses per sequence (Seq id -> list of physical block).
    # (batch_size, max_blocks_per_seq)
    block_tables: torch.Tensor = None
    # The indices of the token slots that input tokens will be stored into.
    # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
    # three tokens are stored at offset 3 in block 2, offset 2 in block 0,
    # and offset 1 in block 1, respectively (offsets are zero-based).
    # (num_tokens,)
    slot_mapping: torch.Tensor = None
    # pcp
    prefill: AscendMetadataForPrefill | None = None
    # dcp
    decode_meta: AscendMetadataForDecode | None = None
    causal: bool = True
    # runner_type in model_config.
    model_runner_type: str = ""
    # prefill reshape_and_cache event
    reshape_cache_event: torch.npu.Event = None
    # sliding window attention mask
    swa_mask: torch.Tensor | None = None


class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
    """
    Builder for constructing AscendMetadata from CommonAttentionMetadata.
    Handles attention mask generation and metadata preparation for
    Ascend FlashAttention backend.
    """

    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device
        self.max_num_blocks_per_req = cdiv(
            self.model_config.max_model_len, AscendAttentionBackend.get_supported_kernel_block_sizes()[0]
        )
        self.speculative_config = vllm_config.speculative_config
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            # Implicit string concatenation keeps stray indentation out of the message.
            assert self.decode_threshold <= 16, (
                "decode_threshold exceeded npu_fused_infer_attention_score "
                f"TND layout's limit of 16, got {self.decode_threshold}"
            )
        self.reorder_batch_threshold = self.decode_threshold
        scheduler_config = vllm_config.scheduler_config
        self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
        self.attn_mask_builder = AttentionMaskBuilder(self.device)

    @classmethod
    def get_cudagraph_support(
        cls: type["AscendAttentionMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        # Explicit override in case the underlying builder specialized this getter.
        # @override omitted only because of mypy limitation due to type variable.
        return AttentionCGSupport.ALWAYS

    def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool:
        return False

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_decodes_and_prefills(
            common_attn_metadata, decode_threshold=self.decode_threshold
        )
        block_table = common_attn_metadata.block_table_tensor
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
        # This slot_mapping override doesn't work since vLLM will override it again. We should fix it in vLLM.
        # see: https://github.com/vllm-project/vllm/blob/ce88756b967c2c5006746a424c15dd59a284ed8c/vllm/model_executor/layers/attention/cross_attention.py#L117
        if isinstance(self.kv_cache_spec, CrossAttentionSpec):
            seq_lens = common_attn_metadata.seq_lens
            slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32)
        elif self.speculative_config and self.speculative_config.parallel_drafting:
            seq_lens = common_attn_metadata.seq_lens
        attn_state = common_attn_metadata.attn_state
        # Get attn_mask and swa_mask from singleton AttentionMaskBuilder
        attn_mask = self.attn_mask_builder.get_attention_mask(self.model_config)
        swa_mask = None
        is_swa = hasattr(self.model_config.hf_text_config, "sliding_window")
        if self.model_config is not None and is_swa:
            swa_mask = self.attn_mask_builder.get_swa_mask(
                self.model_config.dtype, self.model_config.hf_text_config.sliding_window
            )
        # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
        query_start_loc = query_start_loc_cpu.pin_memory().to(self.device, non_blocking=True)
        attn_metadata = AscendMetadata(
            num_actual_tokens=num_actual_tokens,
            num_decode_tokens=num_decode_tokens,
            block_tables=block_table,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens,
            seq_lens_list=seq_lens.tolist(),
            max_query_len=common_attn_metadata.max_query_len,
            actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            swa_mask=swa_mask,
            attn_state=attn_state,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            causal=common_attn_metadata.causal,
            model_runner_type=self.model_config.runner_type,
        )
        return attn_metadata

    def build_for_graph_capture(
        self,
        common_attn_metadata: AscendCommonAttentionMetadata,
        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
    ):
        if attn_state in (
            AscendAttentionState.DecodeOnly,
            AscendAttentionState.ChunkedPrefill,
            AscendAttentionState.SpecDecoding,
        ):
            attn_metadata = self.build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise NotImplementedError(
                "Currently we only support building dummy metadata for the DecodeOnly, "
                "ChunkedPrefill and SpecDecoding states"
            )
        attn_metadata.attn_state = attn_state
        return attn_metadata


class AscendAttentionBackendImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        sinks: torch.Tensor = None,
        **kwargs,
    ) -> None:
        self.vllm_config = get_current_vllm_config()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32, device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.key_cache = None
        self.value_cache = None
        self.is_kv_producer = (
            self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
        )
        self.sinks = sinks

    @staticmethod
    def update_graph_params(
        update_stream,
        forward_context,
        num_tokens,
        vllm_config,
        speculative_config=None,
        num_dcp_pcp_tokens=None,
        draft_attn_metadatas=None,
    ):
        if using_paged_attention(num_tokens, vllm_config):
            # Paged Attention update logic
            if _EXTRA_CTX.is_draft_model:
                graph_params = get_draft_graph_params()
            else:
                graph_params = get_graph_params()
            with torch.npu.stream(update_stream):
                for key, param, handle, event in zip(
                    forward_context.attn_metadata,
                    graph_params.attn_params[num_tokens],
                    graph_params.handles[num_tokens],
                    graph_params.events[num_tokens],
                ):
                    (
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        num_heads,
                        scale,
                        block_table,
                        seq_lens,
                        output,
                    ) = param
                    seq_lens = forward_context.attn_metadata[key].seq_lens
                    workspace = torch_npu._npu_paged_attention_get_workspace(
                        query=query,
                        key_cache=key_cache,
                        value_cache=value_cache,
                        num_kv_heads=num_kv_heads,
                        num_heads=num_heads,
                        scale_value=scale,
                        block_table=block_table,
                        context_lens=seq_lens,
                        out=output,
                    )
                    torch.npu.graph_task_update_begin(update_stream, handle)
                    torch_npu._npu_paged_attention(
                        query=query,
                        key_cache=key_cache,
                        value_cache=value_cache,
                        num_kv_heads=num_kv_heads,
                        num_heads=num_heads,
                        scale_value=scale,
                        block_table=block_table,
                        context_lens=seq_lens,
                        out=output,
                        workspace=workspace,
                    )
                    torch.npu.graph_task_update_end(update_stream)
                    event.record(update_stream)
        else:
            # FIA update logic
            if _EXTRA_CTX.is_draft_model:
                graph_params = get_draft_graph_params()
                attn_metadata = draft_attn_metadatas
                attn_keys = list(attn_metadata[0].keys())
            else:
                graph_params = get_graph_params()
                attn_metadata = forward_context.attn_metadata
                attn_keys = list(attn_metadata.keys())
            # For Qwen3-next, since the kv_cache_config has already categorized
            # linear_attn and self_attn, the attn_metadata is first arranged with
            # self_attn followed by linear_attn. Therefore, using zip directly
            # filters out the update operations for linear_attn.
            # TODO: We use a new variable `attn_keys` to ensure the loop count is
            # correct after get by `zip` because of the new structure of the attn_metadata
            # when running with the merged full eagle-graph. Should check it with Qwen3-next.
            num_layers = len(attn_keys)
            if num_layers == 0:
                return
            if _EXTRA_CTX.is_draft_model:
                attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
            attn_count = 0
            with torch.npu.stream(update_stream):
                for key, param, handle, event in zip(
                    attn_keys,
                    graph_params.attn_params[num_tokens],
                    graph_params.handles[num_tokens],
                    graph_params.events[num_tokens],
                ):
                    (
                        query,
                        key_cache,
                        value,
                        block_tables,
                        attn_mask,
                        block_size,
                        seq_lens,
                        query_start_loc,
                        num_kv_heads,
                        num_heads,
                        scale,
                        attn_output,
                        softmax_lse,
                    ) = param
                    if _EXTRA_CTX.is_draft_model:
                        draft_step = attn_count // num_layers
                        seq_lens = attn_metadata[draft_step][key].seq_lens_list
                        actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
                        block_tables = attn_metadata[draft_step][key].block_tables
                        attn_count = attn_count + 1
                    else:
                        seq_lens = attn_metadata[key].seq_lens_list
                        actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
                        block_tables = attn_metadata[key].block_tables
                    torch.npu.graph_task_update_begin(update_stream, handle)
                    torch_npu.npu_fused_infer_attention_score.out(
                        query=query,
                        key=key_cache,
                        value=value,
                        block_table=block_tables,
                        atten_mask=attn_mask,
                        input_layout="TND",
                        block_size=block_size,
                        actual_seq_lengths=actual_seq_lengths_q,
                        actual_seq_lengths_kv=seq_lens,
                        num_key_value_heads=num_kv_heads,
                        num_heads=num_heads,
                        scale=scale,
                        sparse_mode=3,
                        workspace=graph_params.workspaces.get(num_tokens),
                        out=[attn_output, softmax_lse],
                    )
                    torch.npu.graph_task_update_end(update_stream)
                    event.record(update_stream)

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        super().process_weights_after_loading(act_dtype)
        if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
            flashcomm2_oshard_manager.post_process_after_loading()

    def full_graph_fia(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        if _EXTRA_CTX.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
        actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
        # Prepare tensors for attention output
        # TODO: Refactor this to step-level instead of layer-level
        # Get workspace from cache or calculate it if not present.
        workspace = graph_params.workspaces.get(num_tokens)
        softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
        if workspace is None:
            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                query=query,
                key=key,
                value=value,
                atten_mask=attn_metadata.attn_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=actual_seq_lengths_q,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                sparse_mode=3,
                scale=self.scale,
            )
            if _EXTRA_CTX.is_draft_model:
                update_draft_graph_params_workspaces(num_tokens, workspace)
            else:
                update_graph_params_workspaces(num_tokens, workspace)
        # Handle graph capturing mode
        stream = torch_npu.npu.current_stream()
        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        graph_params.events[num_tokens].append(event)
        graph_params.attn_params[num_tokens].append(
            (
                weak_ref_tensors(query),
                weak_ref_tensors(key),
                weak_ref_tensors(value),
                weak_ref_tensors(block_table),
                weak_ref_tensors(attn_metadata.attn_mask),
                block_size,
                actual_seq_lengths_kv,
                actual_seq_lengths_q,
                self.num_kv_heads,
                self.num_heads,
                self.scale,
                weak_ref_tensors(output),
                weak_ref_tensors(softmax_lse),
            )
        )
        torch.npu.graph_task_group_begin(stream)
        torch_npu.npu_fused_infer_attention_score.out(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
            workspace=workspace,
            out=[output, softmax_lse],
        )
        output = output.view(num_tokens, self.num_heads, self.head_size)
        handle = torch.npu.graph_task_group_end(stream)
        graph_params.handles[num_tokens].append(handle)
        return output, num_tokens

    def full_graph_pa(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
    ):
        graph_params = get_graph_params()
        num_tokens = query.shape[0]
        if _EXTRA_CTX.capturing:
            # Get workspace from cache or calculate it if not present.
            workspace = graph_params.workspaces.get(num_tokens)
            if workspace is None:
                workspace = torch_npu._npu_paged_attention_get_workspace(
                    query=query,
                    key_cache=self.key_cache,
                    value_cache=self.value_cache,
                    num_kv_heads=self.num_kv_heads,
                    num_heads=self.num_heads,
                    scale_value=self.scale,
                    block_table=attn_metadata.block_tables,
                    context_lens=attn_metadata.seq_lens,
                    out=output,
                )
                update_graph_params_workspaces(num_tokens, workspace)
            # Handle graph capturing mode
            stream = torch_npu.npu.current_stream()
            event = torch.npu.ExternalEvent()
            event.wait(stream)
            event.reset(stream)
            graph_params.events[num_tokens].append(event)
            graph_params.attn_params[num_tokens].append(
                (
                    weak_ref_tensors(query),
                    weak_ref_tensors(self.key_cache),
                    weak_ref_tensors(self.value_cache),
                    self.num_kv_heads,
                    self.num_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens,
                    weak_ref_tensors(output),
                )
            )
            torch.npu.graph_task_group_begin(stream)
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output,
                workspace=workspace,
            )
            handle = torch.npu.graph_task_group_end(stream)
            graph_params.handles[num_tokens].append(handle)
        return output

    def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata):
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            block_size = 128
            block_table = None
            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
            if self.attn_type == AttentionType.ENCODER_DECODER:
                actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, dim=0).tolist()
        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
            batch_size = attn_metadata.seq_lens.shape[0]
            block_table = attn_metadata.block_tables[:batch_size, :]
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1
            )
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1
            )
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1
            )
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1
            )
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        # chunked prefill.
        else:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1
            )
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1
            )
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        return key, value, block_size, block_table, actual_seq_lengths_kv

    def _forward_fia_slidingwindow(self, query: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor):
        batch_size = attn_metadata.seq_lens.shape[0]
        block_size = 128
        query = query.view(batch_size, 1, self.num_heads * self.head_size)
        key = self.key_cache
        value = self.value_cache
        if self.key_cache is not None and self.value_cache is not None:
            block_size = self.key_cache.shape[1]
            key = self.key_cache.flatten(2, 3).contiguous()
            value = self.value_cache.flatten(2, 3).contiguous()
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query,
            key,
            value,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout="BSH",
            block_size=block_size,
            pre_tokens=self.sliding_window,
            scale=self.scale,
            block_table=attn_metadata.block_tables,
            actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
            actual_seq_lengths_kv=attn_metadata.seq_lens,
        )
        attn_output = attn_output.view(batch_size, self.num_heads, self.head_size)
        output[:batch_size] = attn_output[:batch_size]
        return output

    def forward_fused_infer_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        # We inherit ForwardContext in model runner v2. When model runner v2 is
        # enabled, forward_context has no `capturing` attribute, so just use
        # getattr to avoid an attribute error.
        if _EXTRA_CTX.capturing:
            attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output
        if (
            attn_metadata.attn_state == AscendAttentionState.DecodeOnly
            and self.sliding_window is not None
            and attn_metadata.seq_lens.shape[0] == query.size(0)
            and self.sinks is None
        ):
            return self._forward_fia_slidingwindow(query, attn_metadata, output)
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        query = query[:num_tokens]
        if (
            attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
            and self.attn_type != AttentionType.ENCODER_DECODER
        ):
            key = key[:num_tokens]
            value = value[:num_tokens]
        # Get workspace from cache or calculate it if not present.
        if self.sinks is not None:
            actual_seq_qlen = attn_metadata.actual_seq_lengths_q
            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
                actual_seq_qlen = torch.tensor([1] * len(attn_metadata.seq_lens_list), dtype=torch.int32).cumsum(dim=0)
            if self.sliding_window is not None:
                atten_mask = attn_metadata.swa_mask
                sparse_mode = 4
            else:
                atten_mask = attn_metadata.attn_mask
                sparse_mode = 3
            attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(
                query,
                key,
                value,
                num_query_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout="TND",
                pre_tokens=self.sliding_window if self.sliding_window is not None else SWA_INT_MAX,
                next_tokens=0,
                atten_mask=atten_mask,
                sparse_mode=sparse_mode,
                softmax_scale=self.scale,
                block_table=block_table,
                block_size=block_size,
                actual_seq_qlen=actual_seq_qlen,
                actual_seq_kvlen=actual_seq_lengths_kv,
                learnable_sink=self.sinks,
            )
        else:
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                query=query,
                key=key,
                value=value,
                atten_mask=attn_metadata.attn_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=3,
            )
        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
        output[:num_tokens] = attn_output[:num_tokens]
        return output

    def forward_paged_attention(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if _EXTRA_CTX.capturing:
            return self.full_graph_pa(query, attn_metadata, output)
        torch_npu._npu_paged_attention(
            query=query,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.block_tables,
            context_lens=attn_metadata.seq_lens,
            out=output,
        )
        return output

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        _: torch.Tensor,
    ) -> torch.Tensor:
        # Use the default sparse_mode 0 in the normal scenario, which means no mask is applied.
        return torch_npu.npu_fusion_attention(
            query=query,
            key=key,
            value=value,
            head_num=self.num_heads,
            input_layout="TND",
            scale=self.scale,
            actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
            actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
        )[0]

    def reshape_and_cache(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        if len(kv_cache) > 1:
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event = torch.npu.Event()
            if self.key_cache is None:
                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping
            encoder_decoder = self.attn_type == AttentionType.ENCODER_DECODER
            DeviceOperator.reshape_and_cache(
                key=key[: attn_metadata.num_actual_tokens] if not encoder_decoder else key,
                value=value[: attn_metadata.num_actual_tokens] if not encoder_decoder else value,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                # quick fix to make sure slots is int32 for cross attention case.
                # see: https://github.com/vllm-project/vllm/blob/ce88756b967c2c5006746a424c15dd59a284ed8c/vllm/model_executor/layers/attention/cross_attention.py#L117
                slot_mapping=slots[: attn_metadata.num_actual_tokens] if not encoder_decoder else slots.to(torch.int32),
            )
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event.record()
        return query, key, value, output

    def forward_impl(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        num_tokens = query.shape[0]
        if (
            attn_metadata.attn_state == AscendAttentionState.DecodeOnly
            and using_paged_attention(num_tokens, self.vllm_config)
            and self.sliding_window is None
        ):
            output = self.forward_paged_attention(query, attn_metadata, output)
        else:
            output = self.forward_fused_infer_attention(query, key, value, attn_metadata, output)

        return output

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError("fused output quantization is not yet supported for AscendAttentionBackendImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.fill_(0)
        output_padded = None
        if key is not None and value is not None:
            output_padded = output
            query, key, value, output_padded = self.reshape_and_cache(
                query, key, value, kv_cache, attn_metadata, output
            )
        # pooling model branch
        if attn_metadata.model_runner_type == "pooling" and not attn_metadata.causal:
            attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output
        if output_padded is not None:
            attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output_padded)
        else:
            attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
        output[:num_tokens] = attn_output[:num_tokens]
        return output


class AscendC8AttentionBackendImpl(AscendAttentionBackendImpl):
    """Attention backend implementation for INT8 KV cache (C8/QuaRot) models.

    This subclass handles static per-channel INT8 KV cache quantization.
    It is activated via class surgery in AscendC8KVCacheAttentionMethod.create_weights
    (vllm_ascend/quantization/methods/kv_c8.py)
    so that C8 attention layers automatically use this forward path.
    """

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError("fused output quantization is not yet supported for AscendC8AttentionBackendImpl")

        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.fill_(0)

        float_key, float_value = None, None
        if key is not None and value is not None:
            # Keep the float K/V for the prefill paths. DecodeOnly steps read
            # the INT8 cache directly and do not need them.
            if attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
                float_key, float_value = key, value
            # Quantize to INT8 before writing into the paged KV cache.
            key, value = self._quantize_kv_to_int8(key, value, layer, attn_metadata.num_actual_tokens)
            query, key, value, _ = self.reshape_and_cache(query, key, value, kv_cache, attn_metadata, output)

        # Pooling model branch. Note that key/value here are the INT8 tensors
        # when the quantization step above ran.
        if attn_metadata.model_runner_type == "pooling":
            attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output

        self._prepare_c8_scales(layer, query.device)

        # Dispatch on the attention state.
        if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            return self._forward_c8_decode(query, attn_metadata, output, layer)
        elif attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
            return self._forward_c8_chunked_prefill(query, float_key, float_value, attn_metadata, output, layer)
        else:
            return self._forward_c8_fused_infer_attention(
                query,
                float_key if float_key is not None else key,
                float_value if float_value is not None else value,
                attn_metadata,
                output,
                layer,
            )

    def _prepare_c8_scales(self, layer: AttentionLayer, device: torch.device) -> None:
        """Shard per-channel C8 scales/offsets to this TP rank and pre-compute
        BF16 BNSD antiquant tensors for the FIA V1 decode fast path.

        Runs once per layer; later calls return early.
        """
        if hasattr(layer, "_c8_scales_prepared"):
            return

        def _shard_and_reshape(raw: torch.Tensor) -> torch.Tensor:
            # A single-element value is passed through unchanged.
            if raw.numel() == 1:
                return raw.to(device=device)
            expected = self.num_kv_heads * self.head_size
            if raw.numel() != expected:
                # The tensor covers more KV heads than this rank owns: keep
                # only the slice of heads belonging to this TP rank.
                total_kv_heads = raw.numel() // self.head_size
                tp_rank = get_tensor_model_parallel_rank()
                tp_size = get_tensor_model_parallel_world_size()
                kv_head_start = tp_rank * total_kv_heads // tp_size
                raw = raw.view(total_kv_heads, self.head_size)[
                    kv_head_start : kv_head_start + self.num_kv_heads
                ].contiguous()
            return raw.view(1, self.num_kv_heads, self.head_size).to(device=device)

        # Per-channel tensors are shaped [1, num_kv_heads, head_size].
        layer._c8_k_scale = _shard_and_reshape(layer.k_cache_scale.data)
        layer._c8_k_offset = _shard_and_reshape(layer.k_cache_offset.data)
        layer._c8_v_scale = _shard_and_reshape(layer.v_cache_scale.data)
        layer._c8_v_offset = _shard_and_reshape(layer.v_cache_offset.data)

        # BF16 copies in BNSD shape, passed as the FIA antiquant arguments.
        bnsd = (1, self.num_kv_heads, 1, self.head_size)
        layer._c8_k_aq_scale = layer._c8_k_scale.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_k_aq_offset = layer._c8_k_offset.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_v_aq_scale = layer._c8_v_scale.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_v_aq_offset = layer._c8_v_offset.to(torch.bfloat16).view(bnsd).contiguous()

        # Pre-computed inverse scales and offsets used by _quantize_kv_to_int8.
        layer._c8_k_inv_scale_bf16 = (1.0 / layer._c8_k_scale).to(torch.bfloat16)
        layer._c8_k_offset_bf16 = layer._c8_k_offset.to(torch.bfloat16)
        layer._c8_v_inv_scale_bf16 = (1.0 / layer._c8_v_scale).to(torch.bfloat16)
        layer._c8_v_offset_bf16 = layer._c8_v_offset.to(torch.bfloat16)

        layer._c8_scales_prepared = True

    def _dequant_paged_kv_to_dense(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: list[int],
        target_dtype: torch.dtype,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Gather paged INT8 KV blocks and dequantize to target_dtype."""
        batch_size = block_table.shape[0]
        block_size = key.shape[1]
        H = key.shape[2]  # flattened num_kv_heads * head_size
        max_blocks_per_seq = block_table.shape[1]
        max_tokens_padded = max_blocks_per_seq * block_size

        # Gather every block listed in the block table, so each sequence is
        # padded to max_blocks_per_seq blocks.
        flat_ids = block_table.reshape(-1)
        gathered_k = key[flat_ids].view(batch_size, max_tokens_padded, H)
        gathered_v = value[flat_ids].view(batch_size, max_tokens_padded, H)

        # Drop the padded positions beyond each sequence's length, leaving a
        # packed [total_tokens, H] layout.
        seq_lens_t = torch.tensor(seq_lens, dtype=torch.long, device=key.device)
        positions = torch.arange(max_tokens_padded, dtype=torch.long, device=key.device)
        valid_mask = (positions.unsqueeze(0) < seq_lens_t.unsqueeze(1)).view(-1)
        dense_k = gathered_k.view(-1, H)[valid_mask]
        dense_v = gathered_v.view(-1, H)[valid_mask]

        dense_k = dense_k.view(-1, self.num_kv_heads, self.head_size)
        dense_v = dense_v.view(-1, self.num_kv_heads, self.head_size)

        # Dequantize: x = (q - offset) * scale.
        k_scale = layer._c8_k_scale.to(target_dtype)
        k_offset = layer._c8_k_offset.to(target_dtype)
        v_scale = layer._c8_v_scale.to(target_dtype)
        v_offset = layer._c8_v_offset.to(target_dtype)
        dense_k = (dense_k.to(target_dtype) - k_offset) * k_scale
        dense_v = (dense_v.to(target_dtype) - v_offset) * v_scale
        return dense_k, dense_v

    def _quantize_kv_to_int8(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        layer: AttentionLayer,
        num_actual_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize K/V from float to INT8 using static per-channel C8 scales."""
        self._prepare_c8_scales(layer, key.device)
        # Only the first num_actual_tokens rows are quantized and returned.
        actual_key = key[:num_actual_tokens]
        actual_value = value[:num_actual_tokens]
        # q = clamp(round(x * inv_scale + offset), -128, 127)
        k_int8 = torch.clamp(
            torch.round(actual_key * layer._c8_k_inv_scale_bf16 + layer._c8_k_offset_bf16),
            -128,
            127,
        ).to(torch.int8)
        v_int8 = torch.clamp(
            torch.round(actual_value * layer._c8_v_inv_scale_bf16 + layer._c8_v_offset_bf16),
            -128,
            127,
        ).to(torch.int8)
        return k_int8, v_int8

    def _forward_c8_decode(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        """C8 decode via FIA V1 BNSD with native paged INT8 KV + perchannel antiquant."""
        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
        assert block_size % 32 == 0, f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
        # Flatten heads and head_size so the cache is [num_block, block_size, H].
        key = self.key_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
        value = self.value_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
        batch_size = len(attn_metadata.seq_lens_list)

        # One query token per request: [batch, heads, head_size] -> BNSD with S == 1.
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query[:batch_size].unsqueeze(2),
            key,
            value,
            key_antiquant_scale=layer._c8_k_aq_scale,
            key_antiquant_offset=layer._c8_k_aq_offset,
            value_antiquant_scale=layer._c8_v_aq_scale,
            value_antiquant_offset=layer._c8_v_aq_offset,
            block_table=attn_metadata.block_tables,
            actual_seq_lengths_kv=attn_metadata.seq_lens_list,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout="BNSD",
            scale=self.scale,
            block_size=block_size,
            # Antiquant mode 0 is the per-channel mode.
            key_antiquant_mode=0,
            value_antiquant_mode=0,
            sparse_mode=0,
        )
        attn_output = attn_output.squeeze(2)
        output[:batch_size] = attn_output
        return output

    def _forward_c8_chunked_prefill(
        self,
        query: torch.Tensor,
        float_key: torch.Tensor | None,
        float_value: torch.Tensor | None,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        """C8 ChunkedPrefill: decode via FIA V1 BNSD paged INT8 (zero gather),
        prefill via FIA V1 TND with float KV (new) or gather+dequant (continuing).
        """
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_decodes = attn_metadata.num_decodes
        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
        num_tokens = int(actual_seq_qlen[-1])  # type: ignore[index]

        # Decode part: the first num_decode_tokens query rows, attended
        # directly against the paged INT8 cache.
        if num_decode_tokens > 0:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
            assert block_size % 32 == 0, (
                f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
            )
            kv_k = self.key_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
            kv_v = self.value_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
            attn_out, _ = torch_npu.npu_fused_infer_attention_score(
                query[:num_decode_tokens].unsqueeze(2),
                kv_k,
                kv_v,
                key_antiquant_scale=layer._c8_k_aq_scale,
                key_antiquant_offset=layer._c8_k_aq_offset,
                value_antiquant_scale=layer._c8_v_aq_scale,
                value_antiquant_offset=layer._c8_v_aq_offset,
                block_table=attn_metadata.block_tables[:num_decodes],
                actual_seq_lengths_kv=attn_metadata.seq_lens_list[:num_decodes],
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout="BNSD",
                scale=self.scale,
                block_size=block_size,
                key_antiquant_mode=0,
                value_antiquant_mode=0,
                sparse_mode=0,
            )
            output[:num_decode_tokens] = attn_out.squeeze(2)

        # Prefill part: the remaining query rows.
        if attn_metadata.num_prefills > 0:
            prefill_q = query[num_decode_tokens:num_tokens]
            # Cumulative query lengths, rebased so the prefill slice starts at 0.
            prefill_seq_qlen = [
                actual_seq_qlen[i] - num_decode_tokens for i in range(num_decodes, len(actual_seq_qlen))
            ]

            # A prefill request is "new" when its KV length equals its query
            # length, i.e. it has no earlier tokens in the cache.
            all_new_prefill = True
            for i in range(num_decodes, len(attn_metadata.seq_lens_list)):
                q_start = actual_seq_qlen[i - 1] if i > 0 else 0
                qlen_i = actual_seq_qlen[i] - q_start
                if attn_metadata.seq_lens_list[i] > qlen_i:
                    all_new_prefill = False
                    break

            if all_new_prefill and float_key is not None and float_value is not None:
                # All prefills are new: use the float K/V of this step directly.
                prefill_k = float_key[num_decode_tokens:num_tokens]
                prefill_v = float_value[num_decode_tokens:num_tokens]
                prefill_seq_kvlen = prefill_seq_qlen
            else:
                # At least one prefill continues a cached sequence: gather its
                # INT8 blocks and dequantize them to the query dtype.
                num_block, blk_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
                paged_k = self.key_cache.view(num_block, blk_size, -1)  # type: ignore[attr-defined]
                paged_v = self.value_cache.view(num_block, blk_size, -1)  # type: ignore[attr-defined]
                prefill_bt = attn_metadata.block_tables[num_decodes:]
                prefill_sl = attn_metadata.seq_lens_list[num_decodes:]
                prefill_k, prefill_v = self._dequant_paged_kv_to_dense(
                    paged_k, paged_v, prefill_bt, prefill_sl, query.dtype, layer
                )
                prefill_seq_kvlen = torch.tensor(prefill_sl, dtype=torch.int32).cumsum(dim=0)

            # block_table is None for prefill; FIA ignores block_size in this case.
            # Use cache block_size for consistency rather than a magic number.
            cache_block_size = self.key_cache.shape[1]  # type: ignore[attr-defined]
            attn_out, _ = torch_npu.npu_fused_infer_attention_score(
                query=prefill_q,
                key=prefill_k,
                value=prefill_v,
                atten_mask=attn_metadata.attn_mask,
                block_table=None,
                input_layout="TND",
                block_size=cache_block_size,
                actual_seq_lengths=prefill_seq_qlen,
                actual_seq_lengths_kv=prefill_seq_kvlen,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=3,
            )
            n_prefill = num_tokens - num_decode_tokens
            attn_out = attn_out.view(n_prefill, self.num_heads, self.head_size)
            output[num_decode_tokens:num_tokens] = attn_out[:n_prefill]
        return output

    def _forward_c8_fused_infer_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        """C8 FIA V1 TND for prefill states (PrefillNoCache uses float KV directly,
        PrefillCacheHit gathers and dequantizes paged INT8 KV).
        """
        self._prepare_c8_scales(layer, query.device)
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
        num_tokens = int(actual_seq_qlen[-1])  # type: ignore[index]
        query = query[:num_tokens]
        if (
            attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
            and self.attn_type != AttentionType.ENCODER_DECODER
        ):
            key = key[:num_tokens]
            value = value[:num_tokens]

        # INT8 K/V must be dequantized before the float TND kernel call.
        if key.dtype == torch.int8:
            if block_table is not None:
                # Paged INT8 cache: gather the blocks and dequantize to dense K/V.
                seq_lens = (
                    actual_seq_lengths_kv if isinstance(actual_seq_lengths_kv, list) else actual_seq_lengths_kv.tolist()
                )
                key, value = self._dequant_paged_kv_to_dense(key, value, block_table, seq_lens, query.dtype, layer)
                block_table = None
                # block_table is None after dequant; FIA ignores block_size.
                # Use cache block_size for consistency rather than a magic number.
                block_size = self.key_cache.shape[1]  # type: ignore[attr-defined]
                actual_seq_lengths_kv = torch.tensor(seq_lens, dtype=torch.int32).cumsum(dim=0)
            else:
                # Dense INT8 K/V: dequantize in place with x = (q - offset) * scale.
                qdt = query.dtype
                k_scale = layer._c8_k_scale.to(qdt)
                k_offset = layer._c8_k_offset.to(qdt)
                v_scale = layer._c8_v_scale.to(qdt)
                v_offset = layer._c8_v_offset.to(qdt)
                key = (key.to(qdt) - k_offset) * k_scale
                value = (value.to(qdt) - v_offset) * v_scale

        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=actual_seq_qlen,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )

        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
        output[:num_tokens] = attn_output
        return output
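

# ---------------------------------------------------------------------------
# Illustrative sketch only; nothing above calls it. It mirrors, in plain
# Python and for a single token, the per-channel arithmetic used by
# _quantize_kv_to_int8 and _dequant_paged_kv_to_dense:
#     q  = clamp(round(x / scale + offset), -128, 127)
#     x' = (q - offset) * scale
# Assumptions: the helper names and any sample values are hypothetical, and
# the division by scale stands in for the pre-computed BF16 inverse scale that
# the backend multiplies by. Python's round() and torch.round() both round
# halves to the nearest even integer.
# ---------------------------------------------------------------------------
def _sketch_quantize_per_channel(x, scale, offset):
    """Quantize one token's channels into the INT8 range [-128, 127]."""
    return [max(-128, min(127, round(v / s + o))) for v, s, o in zip(x, scale, offset)]


def _sketch_dequantize_per_channel(q, scale, offset):
    """Invert _sketch_quantize_per_channel, up to rounding and clamping error."""
    return [(v - o) * s for v, s, o in zip(q, scale, offset)]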