[Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)

### What this PR does / why we need it?
This PR adds W8A8C8 support for dsv3.2/glm5, mainly via the
lightning_indexer_quant ops in the PD-mix stage.

Because the code for the PD-disaggregated scenario is still being
refactored and cleaned up, this PR prioritizes the C8 functionality in
the PD-mix scenario.

The next steps are planned in three parts:
① Once the optimized scatter operator lands, we will switch to it to
improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable,
we will carry out more comprehensive validation and make appropriate
adaptations.
③ Because enabling C8 currently introduces several new operators whose
performance still needs improvement, performance may regress in some
scenarios. Therefore, only after all the operators are fully ready can
we ensure that this feature does not cause any performance degradation.
At that point, we will enable this feature by default and remove the
switch in `additional_config`.


### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.16.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2026-03-13 14:47:42 +08:00
committed by GitHub
parent df1ee8070d
commit 7ed9e9de69
24 changed files with 4279 additions and 77 deletions

View File

@@ -134,9 +134,12 @@ class AscendConfig:
bool(additional_config.get("enable_async_exponential", False)) and not vllm_is_batch_invariant()
)
use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk"
)
self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
if self.enable_kv_nz:
use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
if not vllm_config.model_config.is_deepseek_mla or use_sparse:
raise RuntimeError("enable_kv_nz is only supported for mla currently.")
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
@@ -144,6 +147,17 @@ class AscendConfig:
"enable_kv_nz is only supported in pd scenario and can only be used in D node."
)
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
# Disable Sparse C8 for A5
# A5 has not been fully validated for this path and may carry hidden risks.
# TODO(rjg-lyh): Enable A5 support after sufficient validation.
self.enable_sparse_c8 = (
additional_config.get("enable_sparse_c8", False)
and use_sparse
and get_ascend_device_type() != AscendDeviceType.A5
)
def _construct_weight_prefetch_config(self, additional_config):
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
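For reference, a minimal sketch of how this switch would be turned on from user code, assuming the vLLM offline API and an illustrative model path (only the `enable_sparse_c8` key comes from this PR):

```python
# Minimal sketch: enabling Sparse C8 through additional_config.
# The model path is illustrative; the flag defaults to False.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2",  # illustrative
    additional_config={"enable_sparse_c8": True},
)
```

Per the gating above, the flag is silently ignored on non-sparse models and on A5 devices.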

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar
import scipy # type: ignore
import torch
import torch_npu
import vllm.envs as envs_vllm
@@ -355,6 +356,9 @@ class AscendSFAImpl(MLAAttentionImpl):
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
o_proj_full_pool: torch.Tensor | None = None
# qk_hadamard tensor shared when dsa c8 enabled
qk_hadamard: torch.Tensor | None = None
def __init__(
self,
num_heads: int,
@@ -425,6 +429,12 @@ class AscendSFAImpl(MLAAttentionImpl):
self.is_rope_neox_style = False
self.use_torch_npu_lightning_indexer = True
# dsa c8
self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
if self.use_sparse_c8_indexer:
self.c8_k_cache_dtype = torch.int8
self.c8_k_scale_cache_dtype = torch.float16
# Effective in SFA when FlashComm is enabled.
self.enable_dsa_cp = enable_dsa_cp()
@@ -515,6 +525,11 @@ class AscendSFAImpl(MLAAttentionImpl):
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
128**0.5
)
# Processing the input parameters for MLAPO by reordering and transposing
# QKV(and part of Q) weight, applying RoPE-related dimension transformations,
# and handling quantization parameters.
@@ -874,7 +889,15 @@ class AscendSFAImpl(MLAAttentionImpl):
k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128]
return k_li
if self.use_sparse_c8_indexer:
k_li = k_li @ AscendSFAImpl.qk_hadamard
k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype) # [b*s,]
k_li_scale = k_li_scale.unsqueeze(-1) # [b*s,1]
else:
k_li_scale = None
return k_li, k_li_scale
def indexer_select_post_process(
self,
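Conceptually, the C8 indexer path rotates each 128-dim key with a normalized Hadamard matrix to flatten outliers, then applies per-token dynamic int8 quantization. Below is a self-contained CPU sketch of that math; the `dynamic_quant_int8` helper is an assumed stand-in for `torch_npu.npu_dynamic_quant`, not its actual implementation:

```python
import scipy.linalg
import torch

HEAD_DIM = 128
# Normalized Hadamard rotation, as built once into AscendSFAImpl.qk_hadamard.
hadamard = torch.tensor(scipy.linalg.hadamard(HEAD_DIM), dtype=torch.float32) / (HEAD_DIM**0.5)

def dynamic_quant_int8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-token symmetric int8 quantization (assumed npu_dynamic_quant semantics)."""
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.0  # one scale per token
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale.squeeze(-1)

k_li = torch.randn(4, HEAD_DIM)          # [b*s, 128] indexer keys
k_rot = k_li @ hadamard                  # outlier-flattening rotation
k_int8, k_scale = dynamic_quant_int8(k_rot)

# Dequantization recovers the rotated keys up to one quantization step.
k_deq = k_int8.float() * k_scale.unsqueeze(-1)
assert (k_deq - k_rot).abs().max() <= 0.5 * k_scale.max()
```

The int8 tensor then goes into kv_cache[2] and the per-token scale into kv_cache[3].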
@@ -905,10 +928,35 @@ class AscendSFAImpl(MLAAttentionImpl):
q_li_pe = q_li_pe.squeeze(2)
q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128]
if self.use_sparse_c8_indexer:
q_li_shape_ori = q_li.shape
q_li = q_li @ AscendSFAImpl.qk_hadamard
q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)
# DSV3.2 currently has graph compilation issues when using torch_npu.npu_lightning_indexer.
# So two branches are maintained temporarily.
# TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
if self.use_torch_npu_lightning_indexer:
if self.use_sparse_c8_indexer:
assert len(kv_cache) == 4
weights = weights.to(torch.float16)
topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant(
query=q_li.view(q_li_shape_ori),
key=kv_cache[2],
weights=weights,
query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]),
key_dequant_scale=kv_cache[3].squeeze(2), # B S N D -> B S D
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
block_table=attn_metadata.block_table,
query_quant_mode=0,
key_quant_mode=0,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3,
)
elif self.use_torch_npu_lightning_indexer:
topk_indices, _ = torch_npu.npu_lightning_indexer(
query=q_li,
key=kv_cache[2],
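Functionally, the quantized indexer scores each query token against the int8 cached keys, dequantizes with the outer product of the two scale vectors, reduces over indexer heads with `weights`, and keeps the top `sparse_count` key positions. A rough dense reference of what the operator is assumed to compute, ignoring paged block tables and masking modes:

```python
import torch

def lightning_indexer_quant_ref(
    q_int8: torch.Tensor,   # [t, n, d] int8 queries (Hadamard-rotated, quantized)
    k_int8: torch.Tensor,   # [s, d] int8 cached keys
    weights: torch.Tensor,  # [t, n] per-head weights
    q_scale: torch.Tensor,  # [t, n] query dequant scales
    k_scale: torch.Tensor,  # [s] key dequant scales
    sparse_count: int = 2048,
) -> torch.Tensor:
    # Integer-domain matmul, dequantized by the outer product of scales.
    scores = torch.einsum("tnd,sd->tns", q_int8.float(), k_int8.float())
    scores = scores * q_scale.unsqueeze(-1) * k_scale.view(1, 1, -1)
    # Weighted reduction over indexer heads, then top-k over key positions.
    scores = (scores * weights.unsqueeze(-1)).sum(dim=1)  # [t, s]
    k = min(sparse_count, scores.shape[-1])
    return scores.topk(k, dim=-1).indices                 # [t, k]
```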
@@ -1031,7 +1079,7 @@ class AscendSFAImpl(MLAAttentionImpl):
assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
q_c = self.q_a_layernorm(q_c)
k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
wait_for_kv_layer_from_connector(layer_name)
@@ -1044,20 +1092,46 @@ class AscendSFAImpl(MLAAttentionImpl):
if self.enable_dsa_cp:
assert k_pe is not None
assert k_nope is not None
assert k_li is not None
async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
# support all_gather kv async for communication calculation overlap
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k_li.view(-1, k_li.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
if not self.use_sparse_c8_indexer:
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k_li.view(-1, k_li.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
else:
# Due to different dtypes, we have to split the communication into separate passes.
assert k_li_scale is not None
fused_kv_no_split, _ = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
k_li, _ = all_gather_async(
k_li,
get_tp_group(),
async_op=async_op,
)
k_li_scale, kv_ag_handle = all_gather_async(
k_li_scale,
get_tp_group(),
async_op=async_op,
)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
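The separate collectives exist because `torch.cat` type-promotes mixed inputs: fusing the int8 `k_li` into the bf16 buffer would silently upcast it and double its bytes on the wire, defeating the int8 cache. A small sketch of that constraint, with no real communication:

```python
import torch

t = 4
k_pe = torch.randn(t, 64, dtype=torch.bfloat16)
k_nope = torch.randn(t, 512, dtype=torch.bfloat16)
k_li = torch.randint(-128, 128, (t, 128), dtype=torch.int8)

# Same dtype: one fused buffer, one all-gather.
fused = torch.cat([k_pe, k_nope], dim=1)  # [t, 576] bf16

# Mixed dtypes: cat promotes int8 up to bf16, doubling k_li's bytes.
promoted = torch.cat([fused, k_li], dim=1)
assert promoted.dtype == torch.bfloat16
assert k_li.element_size() == 1 and promoted.element_size() == 2
```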
@@ -1077,9 +1151,12 @@ class AscendSFAImpl(MLAAttentionImpl):
if kv_cache is not None:
assert fused_kv_no_split is not None
k_pe, k_nope, k_li = fused_kv_no_split.split(
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
)
if not self.use_sparse_c8_indexer:
k_pe, k_nope, k_li = fused_kv_no_split.split(
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
)
else:
k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1)
k_nope = k_nope.view(k_nope.shape[0], 1, -1)
k_pe = k_pe.view(k_pe.shape[0], 1, -1)
DeviceOperator.reshape_and_cache(
@@ -1098,6 +1175,13 @@ class AscendSFAImpl(MLAAttentionImpl):
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1])
) # b, s, n, d
if self.use_sparse_c8_indexer:
assert len(kv_cache) == 4
torch_npu.npu_scatter_nd_update_(
kv_cache[3].view(-1, k_li_scale.shape[-1]),
slot_mapping.view(-1, 1),
k_li_scale.view(-1, k_li_scale.shape[-1]),
)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
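The scatter above writes each token's fp16 scale row into its flat slot in the paged scale cache. A CPU stand-in for the same indexing pattern, assuming `slot_mapping` holds flat slot ids (block id combined with in-block offset):

```python
import torch

num_slots, tokens = 16, 4
k_scale_cache = torch.zeros(num_slots, 1, dtype=torch.float16)  # kv_cache[3], flattened
k_li_scale = torch.rand(tokens, 1, dtype=torch.float16)         # per-token scales
slot_mapping = torch.tensor([3, 7, 8, 12])                      # flat slot per token

# Equivalent of npu_scatter_nd_update_(cache.view(-1, 1), slots.view(-1, 1), scales)
k_scale_cache[slot_mapping] = k_li_scale
assert torch.equal(k_scale_cache[slot_mapping], k_li_scale)
```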

View File

@@ -137,6 +137,28 @@
# Remove this patch if upstream provides an official NPU graph-capture
# guidance / auto-configuration path for HCCL.
#
# ** 8. File: platform/patch_kv_cache_interface.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.kv_cache_interface.MLAAttentionSpec`
# Why:
# The default `MLAAttentionSpec` is mainly built around `kv_lora_rank`
# and `qk_rope_head_dim`. On NPU, we also use this class to describe DSA
# models. Unlike the GPU path, where cache management is handled by an
# additional indexer module, extending this class directly simplifies the
# corresponding `model_runner` implementation on NPU.
#
# This patch also adds Sparse C8 support for DSA models on NPU. As part
# of that support, members such as `page_size_bytes` need to be adapted,
# so they are overridden here as well to preserve overall readability.
# How:
# This patch subclasses the original implementation, overrides selected
# methods, and adds DSA-specific attributes and helpers with default
# values where needed.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/25896
# Future Plan:
# Remove this patch after the upcoming KV cache spec refactor.
#
# * Worker Patch:
# ===============
#

View File

@@ -18,6 +18,7 @@ import os
import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
import vllm_ascend.patch.platform.patch_sched_yield # noqa

View File

@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
import vllm.v1.kv_cache_interface
from typing_extensions import Self
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import MLAAttentionSpec
@dataclass(frozen=True)
class AscendMLAAttentionSpec(MLAAttentionSpec):
"""MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support.
When Sparse C8 is enabled, the KV cache tuple changes from
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)
to
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).
The semantic meaning of each KV cache entry is as follows:
1. kv_cache[0] stores kv_lora.
2. kv_cache[1] stores k_rope.
3. kv_cache[2] stores the key tensor from the indexer module.
4. kv_cache[3] stores the key scale tensor from the indexer module,
and exists only when Sparse C8 is enabled.
The main changes are as follows:
1. The key tensor from the indexer module stored in kv_cache[2] is
converted from bf16 to int8 to reduce memory usage. It is then
processed with int8 precision in Lightning_indexer computation
to improve computational efficiency.
2. The quantization scale of the key tensor in the indexer module
must also be stored for the Lightning_indexer_quant operator,
and is therefore saved in kv_cache[3].
"""
sparse_head_dim: tuple[int, ...] | None = None
cache_sparse_c8: bool = False
c8_k_cache_dtype: torch.dtype = torch.int8
c8_k_scale_cache_dtype: torch.dtype = torch.float16
@property
def page_size_bytes(self) -> int:
if self.cache_sparse_c8:
assert self.sparse_head_dim is not None
assert len(self.sparse_head_dim) == 3
num_heads_per_page = self.block_size * self.num_kv_heads
# kv_cache[0]: bfloat16, kv_cache[1]: bfloat16
kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2]
k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)
# kv_cache[2]: int8
index_head_dim = self.sparse_head_dim[-1]
indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)
# kv_cache[3]: float16
# since the scale is stored per token, head_dim is set to 1.
index_scale_head_dim = 1
indexer_k_scale_bytes = (
num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
)
return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes
return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)
@property
def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
"""
Compute the split factor of each KV cache entry: the total virtual
head dim divided by the entry's own virtual head dim.
Returns:
A tuple containing the split factors for:
- kv_cache[0]
- kv_cache[1]
- kv_cache[2]
- kv_cache[3] (None if Sparse C8 is disabled)
"""
assert self.sparse_head_dim is not None
def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]:
assert self.sparse_head_dim is not None
assert self.cache_sparse_c8 is True
kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim
factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype)
index_k_head_dim_virtual = index_k_head_dim // factor
assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype)
index_k_scale_head_dim_virtual = 1
return (
kv_lora_rank,
qk_rope_head_dim,
index_k_head_dim_virtual,
index_k_scale_head_dim_virtual,
)
if self.cache_sparse_c8:
virtual_dims = get_sparse_head_dim_virtual()
total_virtual_head_dim = sum(virtual_dims)
return (
total_virtual_head_dim / virtual_dims[0], # kv_cache[0]
total_virtual_head_dim / virtual_dims[1], # kv_cache[1]
total_virtual_head_dim / virtual_dims[2], # kv_cache[2]
total_virtual_head_dim / virtual_dims[3], # kv_cache[3]
)
return (
self.head_size / self.sparse_head_dim[0], # kv_cache[0]
self.head_size / self.sparse_head_dim[1], # kv_cache[1]
self.head_size / self.sparse_head_dim[2], # kv_cache[2]
None, # kv_cache[3] does not exist
)
@classmethod
def merge(cls, specs: list[Self]) -> Self:
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
"All attention layers in the same KV cache group must be MLAAttentionSpec."
)
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
assert len(cache_dtype_str_set) == 1, (
"All attention layers in the same KV cache group must use the same quantization method."
)
return cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
sparse_head_dim=specs[0].sparse_head_dim,
dtype=specs[0].dtype,
cache_dtype_str=cache_dtype_str_set.pop(),
cache_sparse_c8=specs[0].cache_sparse_c8,
)
vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec
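As a worked check of `page_size_bytes`, assume DeepSeek-V3.2-style dimensions (kv_lora_rank=512, qk_rope_head_dim=64, index_head_dim=128) with block_size=128 and one KV head; the numbers are illustrative:

```python
block_size, num_kv_heads = 128, 1
kv_lora_rank, qk_rope_head_dim, index_head_dim = 512, 64, 128  # assumed dims
heads_per_page = block_size * num_kv_heads
BF16, FP16, INT8 = 2, 2, 1  # bytes per element

# Sparse C8 layout: bf16 kv_lora + rope, int8 indexer key, fp16 per-token scale.
c8_bytes = heads_per_page * (
    (kv_lora_rank + qk_rope_head_dim) * BF16 + index_head_dim * INT8 + 1 * FP16
)
# Without C8, the indexer key is also stored in bf16.
bf16_bytes = heads_per_page * (kv_lora_rank + qk_rope_head_dim + index_head_dim) * BF16

print(c8_bytes, bf16_bytes)       # 164096 180224
print(1 - c8_bytes / bf16_bytes)  # ~0.089, i.e. ~9% smaller pages
```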

View File

@@ -84,6 +84,7 @@ from vllm.v1.worker.ubatch_utils import (
)
from vllm.v1.worker.utils import AttentionGroup
# yapf: enable
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention
@@ -96,8 +97,6 @@ from vllm_ascend.compilation.acl_graph import (
set_graph_params,
update_full_graph_params,
)
# yapf: enable
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
@@ -274,7 +273,21 @@ class NPUModelRunner(GPUModelRunner):
self.is_multimodal_model = self.model_config.is_multimodal_model
self.block_size = vllm_config.cache_config.block_size
# Set up Attention
self.use_sparse = hasattr(self.vllm_config.model_config.hf_text_config, "index_topk")
self.use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk"
)
if self.use_sparse:
self.sparse_head_dim = (
self.model_config.hf_text_config.kv_lora_rank,
self.model_config.hf_text_config.qk_rope_head_dim,
self.model_config.hf_text_config.index_head_dim,
)
# dsa c8
self.use_sparse_c8_indexer = self.ascend_config.enable_sparse_c8
if self.use_sparse_c8_indexer:
self.c8_k_cache_dtype = torch.int8
self.c8_k_scale_cache_dtype = torch.float16
self.attn_backend = get_attn_backend(
0,
self.dtype,
@@ -2623,7 +2636,7 @@ class NPUModelRunner(GPUModelRunner):
to their corresponding memory buffer for K cache and V cache.
"""
# init kv cache tensors
kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {}
kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, ...] | None] = {}
# prefill disaggregation need the addr of cache tensor be aligned with 2M
alignment = 2 * 1024 * 1024
layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
@@ -2670,19 +2683,18 @@ class NPUModelRunner(GPUModelRunner):
+ self.model_config.hf_text_config.kv_lora_rank
)
dsa_k_cache_factor = None
dsa_k_cache_size = None
if not self.model_config.use_mla:
# for non-mla model, use FullAttentionSpec
k_tensor_split_factor = 2
v_tensor_split_factor = 2
k_tensor_split_factor = 2.0
v_tensor_split_factor = 2.0
elif self.use_sparse:
# for deepseek v3.2, we split the kv cache according to the corresponding ratio
sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [ # type: ignore
sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio()
]
dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor)
kv_cache_spec = layer_kv_cache_spec[layer_name]
sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio
k_tensor_split_factor = sparse_kv_cache_ratio[0]
v_tensor_split_factor = sparse_kv_cache_ratio[1]
dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2]
dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
else:
# for other deepseek models, use MLAAttentionSpec
k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
@@ -2690,35 +2702,56 @@ class NPUModelRunner(GPUModelRunner):
k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor)
v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor)
dsa_k_tensor_size = None
dsa_k_scale_tensor_size = None
#### for deepseek sparse attention
if self.use_sparse:
dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor)
if self.use_sparse_c8_indexer:
dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor)
# for other attentions, e.g., self_attn, sliding window attn
if self.vllm_config.kv_transfer_config is None:
k_tensor = torch.zeros(k_tensor_size, dtype=torch.int8, device=self.device)
v_tensor = torch.zeros(v_tensor_size, dtype=torch.int8, device=self.device)
#### k cache: for deepseek sparse attention
if dsa_k_cache_factor is not None:
dsa_k_cache_tensor = torch.zeros(dsa_k_cache_size, dtype=torch.int8, device=self.device)
#### for deepseek sparse attention
if dsa_k_tensor_size is not None:
dsa_k_tensor = torch.zeros(dsa_k_tensor_size, dtype=torch.int8, device=self.device)
if dsa_k_scale_tensor_size is not None:
dsa_k_scale_tensor = torch.zeros(
dsa_k_scale_tensor_size, dtype=torch.int8, device=self.device
)
else:
k_tensor = torch.zeros(k_tensor_size + alignment, dtype=torch.int8, device=self.device)
v_tensor = torch.zeros(v_tensor_size + alignment, dtype=torch.int8, device=self.device)
k_tensor = self._align_memory(k_tensor, alignment)[:k_tensor_size]
v_tensor = self._align_memory(v_tensor, alignment)[:v_tensor_size]
#### k cache: for deepseek sparse attention
if dsa_k_cache_factor is not None and dsa_k_cache_size is not None:
dsa_k_cache_tensor = torch.zeros(
dsa_k_cache_size + alignment, dtype=torch.int8, device=self.device
#### for deepseek sparse attention
if dsa_k_tensor_size is not None:
dsa_k_tensor = torch.zeros(
dsa_k_tensor_size + alignment, dtype=torch.int8, device=self.device
)
dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size]
dsa_k_tensor = self._align_memory(dsa_k_tensor, alignment)[:dsa_k_tensor_size]
if dsa_k_scale_tensor_size is not None:
dsa_k_scale_tensor = torch.zeros(
dsa_k_scale_tensor_size + alignment, dtype=torch.int8, device=self.device
)
dsa_k_scale_tensor = self._align_memory(
dsa_k_scale_tensor, alignment
)[:dsa_k_scale_tensor_size]
for layer_name_inner in kv_cache_tensor.shared_by:
# share the attn kvcache across all shared layers
if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
kv_cache_raw_tensors[layer_name_inner] = (
(k_tensor, v_tensor)
if not self.use_sparse
else (k_tensor, v_tensor, dsa_k_cache_tensor)
)
if self.use_sparse:
if self.use_sparse_c8_indexer:
kv_cache_raw_tensors[layer_name_inner] = (
k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor
)
else:
kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor, dsa_k_tensor)
else:
kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor)
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
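The split factors returned by `sparse_kv_cache_ratio` carve a single per-layer byte budget into the four buffers. A small sketch with the same assumed dimensions as above (C8 enabled, so int8 halves the indexer entry and the scale contributes one virtual dim):

```python
kv_lora_rank, qk_rope_head_dim, index_head_dim = 512, 64, 128  # assumed dims
virtual = (kv_lora_rank, qk_rope_head_dim, index_head_dim // 2, 1)
total = sum(virtual)  # 641 "virtual" bf16 dims

tensor_size = 641 * 4096  # assumed per-layer byte budget, chosen to divide evenly
sizes = [int(tensor_size // (total / v)) for v in virtual]
print(sizes)                      # [2097152, 262144, 262144, 4096]
assert sum(sizes) == tensor_size  # the four buffers tile the whole allocation
```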
@@ -2760,13 +2793,23 @@ class NPUModelRunner(GPUModelRunner):
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
# encounter OOM issue
if isinstance(kv_cache_spec, AttentionSpec):
raw_dsa_k_tensor = None
if self.use_sparse:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name
]
assert raw_dsa_k_tensor is not None
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
if self.use_sparse_c8_indexer:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name]
assert raw_dsa_k_tensor is not None
assert raw_dsa_k_scale_tensor is not None
sum_page_size_bytes = (
raw_k_tensor.numel()
+ raw_v_tensor.numel()
+ raw_dsa_k_tensor.numel()
+ raw_dsa_k_scale_tensor.numel()
)
else:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name]
assert raw_dsa_k_tensor is not None
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
elif self.use_hybrid_blocks and self.hybrid_with_attn_and_mamba:
# Currently, we ensure that the same kvcache format is used even if there
# is no shared layer, such as the full attention mtp layer of qwen3.5, etc.
@@ -2813,7 +2856,7 @@ class NPUModelRunner(GPUModelRunner):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
)
dtype = kv_cache_spec.dtype
if not self.model_config.use_mla:
k_shape = kv_cache_shape[1:]
v_shape = k_shape
@@ -2832,19 +2875,37 @@ class NPUModelRunner(GPUModelRunner):
num_kv_heads,
self.model_config.hf_text_config.qk_rope_head_dim,
]
k_cache = raw_k_tensor.view(dtype).view(k_shape)
v_cache = raw_v_tensor.view(dtype).view(v_shape)
k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape)
v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape)
if self.use_sparse and raw_dsa_k_tensor is not None:
index_head_dim = self._get_sparse_kv_cache_ratio()[-1]
if self.use_sparse:
dsa_k_cache_shape = (
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
index_head_dim,
self.model_config.hf_text_config.index_head_dim,
)
dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
if self.use_sparse_c8_indexer:
# dsa_k
dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape)
# dsa_k_scale
dsa_k_scale_cache_shape = (
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
1,
)
assert raw_dsa_k_scale_tensor is not None
dsa_k_scale_cache = (
raw_dsa_k_scale_tensor
.view(self.c8_k_scale_cache_dtype)
.view(dsa_k_scale_cache_shape)
)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache)
else:
# dsa_k
dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
else:
kv_caches[layer_name] = (k_cache, v_cache)
elif isinstance(kv_cache_spec, MambaSpec):
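Each raw allocation above is an untyped int8 byte buffer; `Tensor.view(dtype)` later reinterprets the bytes before shaping them into the paged layout. A minimal sketch with illustrative shapes:

```python
import torch

num_blocks, block_size, num_kv_heads = 2, 4, 1
n_scales = num_blocks * block_size * num_kv_heads

# Raw byte buffer for the fp16 scale cache: 2 bytes per entry.
raw = torch.zeros(n_scales * 2, dtype=torch.int8)

# Reinterpret the bytes as fp16, then shape into (blocks, block, heads, 1).
dsa_k_scale_cache = raw.view(torch.float16).view(num_blocks, block_size, num_kv_heads, 1)
assert dsa_k_scale_cache.dtype == torch.float16
assert dsa_k_scale_cache.shape == (2, 4, 1, 1)
```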
@@ -3098,18 +3159,31 @@ class NPUModelRunner(GPUModelRunner):
elif isinstance(attn_module, MLAAttention):
if self.use_sparse:
# TODO(cmq): This is a hacky way to fix the deepseek kvcache when
# using DSA. Fixing the spec in vLLM is the final solution.
sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
kv_cache_spec[layer_name] = MLAAttentionSpec(
# `MLAAttentionSpec` is temporarily patched to `AscendMLAAttentionSpec`.
# Re-importing it at runtime therefore resolves to the patched class.
# Alias it here to make this behavior explicit.
from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
# TODO(rjg-lyh): when kv_cache_spec's refactor is ready,
# implement it by creating a new kv_cache_spec class
kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
block_size=self.block_size,
num_kv_heads=1,
head_size=sparse_sum_head_size,
head_size=sum(self.sparse_head_dim),
sparse_head_dim=self.sparse_head_dim,
dtype=self.kv_cache_dtype,
cache_dtype_str=self.vllm_config.cache_config.cache_dtype,
cache_sparse_c8=self.use_sparse_c8_indexer,
)
elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
kv_cache_spec[layer_name] = spec
assert isinstance(spec, MLAAttentionSpec)
from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
cache_dtype_str=spec.cache_dtype_str,
)
elif isinstance(attn_module, MambaBase):
mamba_layers[layer_name] = attn_module
@@ -3129,16 +3203,6 @@ class NPUModelRunner(GPUModelRunner):
return kv_cache_spec
def _get_sparse_kv_cache_ratio(self) -> list[int]:
# TODO: If C8 is supported, we need to consider the number of bytes occupied by different dtypes
# when calculating the ratio, for example:
# [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...]
return [
self.model_config.hf_text_config.kv_lora_rank,
self.model_config.hf_text_config.qk_rope_head_dim,
self.model_config.hf_text_config.index_head_dim,
]
def _check_and_update_cudagraph_mode(
self,
attention_backends: list[set[type[AttentionBackend]]],