Revert "[Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)" (#7288)

### What this PR does / why we need it?
This reverts commit 7ed9e9de69, which
introduced an issue where the patch does not work with the recompute
scheduler enabled.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
Mengqing Cao
2026-03-15 20:19:09 +08:00
committed by GitHub
parent 29f195a91c
commit 0c299f79b9
24 changed files with 79 additions and 4281 deletions

View File

@@ -134,12 +134,9 @@ class AscendConfig:
bool(additional_config.get("enable_async_exponential", False)) and not vllm_is_batch_invariant()
)
use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk"
)
self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
if self.enable_kv_nz:
use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
if not vllm_config.model_config.is_deepseek_mla or use_sparse:
raise RuntimeError("enable_kv_nz is only supported for mla currently.")
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
@@ -147,17 +144,6 @@ class AscendConfig:
"enable_kv_nz is only supported in pd scenario and can only be used in D node."
)
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
# Disable Sparse C8 for A5
# A5 has not been fully validated for this path and may carry hidden risks.
# TODO(rjg-lyh): Enable A5 support after sufficient validation.
self.enable_sparse_c8 = (
additional_config.get("enable_sparse_c8", False)
and use_sparse
and get_ascend_device_type() != AscendDeviceType.A5
)
def _construct_weight_prefetch_config(self, additional_config):
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
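
For reference, the gating removed here reduces to a small predicate; a minimal sketch, assuming a standalone helper and a stub AscendDeviceType (the real enum lives in vllm_ascend.utils):

from enum import Enum


class AscendDeviceType(Enum):  # stub for vllm_ascend.utils.AscendDeviceType
    A2 = "a2"
    A3 = "a3"
    A5 = "a5"


def resolve_enable_sparse_c8(additional_config: dict, use_sparse: bool,
                             device_type: AscendDeviceType) -> bool:
    # Sparse C8 only applies to sparse (DSA) models and is forced off on A5
    # until that device path is validated.
    return (
        bool(additional_config.get("enable_sparse_c8", False))
        and use_sparse
        and device_type != AscendDeviceType.A5
    )


assert resolve_enable_sparse_c8({"enable_sparse_c8": True}, True, AscendDeviceType.A3)
assert not resolve_enable_sparse_c8({"enable_sparse_c8": True}, True, AscendDeviceType.A5)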

View File

@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar
import scipy # type: ignore
import torch
import torch_npu
import vllm.envs as envs_vllm
@@ -356,9 +355,6 @@ class AscendSFAImpl(MLAAttentionImpl):
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
o_proj_full_pool: torch.Tensor | None = None
# qk_hadamard tensor shared when dsa c8 enabled
qk_hadamard: torch.Tensor | None = None
def __init__(
self,
num_heads: int,
@@ -429,12 +425,6 @@ class AscendSFAImpl(MLAAttentionImpl):
self.is_rope_neox_style = False
self.use_torch_npu_lightning_indexer = True
# dsa c8
self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
if self.use_sparse_c8_indexer:
self.c8_k_cache_dtype = torch.int8
self.c8_k_scale_cache_dtype = torch.float16
# Effective in SFA when FlashComm is enabled.
self.enable_dsa_cp = enable_dsa_cp()
@@ -525,11 +515,6 @@ class AscendSFAImpl(MLAAttentionImpl):
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
128**0.5
)
# Processing the input parameters for MLAPO by reordering and transposing
# QKV(and part of Q) weight, applying RoPE-related dimension transformations,
# and handling quantization parameters.
@@ -889,15 +874,7 @@ class AscendSFAImpl(MLAAttentionImpl):
k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128]
if self.use_sparse_c8_indexer:
k_li = k_li @ AscendSFAImpl.qk_hadamard
k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype) # [b*s,]
k_li_scale = k_li_scale.unsqueeze(-1) # [b*s,1]
else:
k_li_scale = None
return k_li, k_li_scale
return k_li
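
The deleted branch applies a Hadamard rotation before per-token int8 quantization, spreading activation outliers across the 128-dim head before scaling. A CPU sketch of the same idea in plain torch, assuming npu_dynamic_quant uses a symmetric per-row absmax scale (the NPU op's exact convention may differ):

import scipy.linalg
import torch

def hadamard_int8_quant(x: torch.Tensor, head_dim: int = 128):
    # Same normalized Hadamard matrix the reverted code cached on the class
    # (hadamard(128) / 128**0.5), built on CPU here for clarity.
    h = torch.tensor(scipy.linalg.hadamard(head_dim), dtype=x.dtype) / head_dim**0.5
    x = (x @ h).view(-1, head_dim)                       # rotate to spread outliers
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.0   # per-token scale
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale.to(torch.float16)                    # fp16 scale, as in the scale cache

k = torch.randn(4, 128, dtype=torch.bfloat16)
k_q, k_scale = hadamard_int8_quant(k)
print(k_q.shape, k_scale.shape)  # torch.Size([4, 128]) torch.Size([4, 1])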
def indexer_select_post_process(
self,
@@ -928,35 +905,10 @@ class AscendSFAImpl(MLAAttentionImpl):
q_li_pe = q_li_pe.squeeze(2)
q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128]
if self.use_sparse_c8_indexer:
q_li_shape_ori = q_li.shape
q_li = q_li @ AscendSFAImpl.qk_hadamard
q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)
# DSV3.2 currently has graph compilation issues when using torch_npu.npu_lightning_indexer.
# So two branches are maintained temporarily.
# TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
if self.use_sparse_c8_indexer:
assert len(kv_cache) == 4
weights = weights.to(torch.float16)
topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant(
query=q_li.view(q_li_shape_ori),
key=kv_cache[2],
weights=weights,
query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]),
key_dequant_scale=kv_cache[3].squeeze(2), # B S N D -> B S D
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
block_table=attn_metadata.block_table,
query_quant_mode=0,
key_quant_mode=0,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3,
)
elif self.use_torch_npu_lightning_indexer:
if self.use_torch_npu_lightning_indexer:
topk_indices, _ = torch_npu.npu_lightning_indexer(
query=q_li,
key=kv_cache[2],
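
Conceptually, both indexer ops score each query against all cached indexer keys and keep the top sparse_count token indices. A dense toy equivalent, where the per-head ReLU-and-weight reduction is an assumption about the score, not the kernel's exact math:

import torch

def toy_topk_indexer(q, k, weights, k_top):
    # q: [T, H, D] per-head indexer queries; k: [S, D] cached indexer keys;
    # weights: [T, H] per-head mixing weights.
    scores = torch.einsum("thd,sd->ths", q, k).relu()      # assumed activation
    scores = (scores * weights.unsqueeze(-1)).sum(dim=1)   # weighted head sum -> [T, S]
    return scores.topk(min(k_top, k.shape[0]), dim=-1).indices

q = torch.randn(2, 64, 128)
k = torch.randn(300, 128)
w = torch.randn(2, 64)
print(toy_topk_indexer(q, k, w, 2048).shape)  # torch.Size([2, 300])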
@@ -1079,7 +1031,7 @@ class AscendSFAImpl(MLAAttentionImpl):
assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
q_c = self.q_a_layernorm(q_c)
k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
wait_for_kv_layer_from_connector(layer_name)
@@ -1092,46 +1044,20 @@ class AscendSFAImpl(MLAAttentionImpl):
if self.enable_dsa_cp:
assert k_pe is not None
assert k_nope is not None
assert k_li is not None
async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
# support async kv all_gather for communication-computation overlap
if not self.use_sparse_c8_indexer:
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k_li.view(-1, k_li.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
else:
# due to different dtypes, we have to split the communication into separate passes
assert k_li_scale is not None
fused_kv_no_split, _ = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
k_li, _ = all_gather_async(
k_li,
get_tp_group(),
async_op=async_op,
)
k_li_scale, kv_ag_handle = all_gather_async(
k_li_scale,
get_tp_group(),
async_op=async_op,
)
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k_li.view(-1, k_li.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
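
The deleted branch existed because tensors fused into one all-gather must share a dtype: bf16 k_pe/k_nope concatenate into a single collective, while int8 k_li and its scale each need their own call. A standalone sketch with a 1-rank gloo group, replacing all_gather_async with plain torch.distributed:

import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                        world_size=1, rank=0)
world = dist.get_world_size()

k_pe, k_nope = torch.randn(4, 64), torch.randn(4, 512)
k_li = torch.randint(-128, 127, (4, 128), dtype=torch.int8)
k_li_scale = torch.rand(4, 1)  # fp16 on NPU; fp32 here for simplicity

fused = torch.cat([k_pe, k_nope], dim=1)  # same dtype -> one collective
gathered = [torch.empty_like(fused) for _ in range(world)]
dist.all_gather(gathered, fused)

gathered_li = [torch.empty_like(k_li) for _ in range(world)]  # int8: separate call
dist.all_gather(gathered_li, k_li)
gathered_scale = [torch.empty_like(k_li_scale) for _ in range(world)]
dist.all_gather(gathered_scale, k_li_scale)
dist.destroy_process_group()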
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
@@ -1151,12 +1077,9 @@ class AscendSFAImpl(MLAAttentionImpl):
if kv_cache is not None:
assert fused_kv_no_split is not None
if not self.use_sparse_c8_indexer:
k_pe, k_nope, k_li = fused_kv_no_split.split(
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
)
else:
k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1)
k_pe, k_nope, k_li = fused_kv_no_split.split(
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
)
k_nope = k_nope.view(k_nope.shape[0], 1, -1)
k_pe = k_pe.view(k_pe.shape[0], 1, -1)
DeviceOperator.reshape_and_cache(
@@ -1175,13 +1098,6 @@ class AscendSFAImpl(MLAAttentionImpl):
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1])
) # b, s, n, d
if self.use_sparse_c8_indexer:
assert len(kv_cache) == 4
torch_npu.npu_scatter_nd_update_(
kv_cache[3].view(-1, k_li_scale.shape[-1]),
slot_mapping.view(-1, 1),
k_li_scale.view(-1, k_li_scale.shape[-1]),
)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
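
For intuition, the npu_scatter_nd_update_ calls above write one row per token into a flat view of the paged cache, keyed by slot_mapping. A toy torch analogue using index_copy_, with made-up shapes:

import torch

num_blocks, block_size, d = 3, 4, 8
cache = torch.zeros(num_blocks, block_size, 1, d)  # b, s, n, d layout
k_li = torch.randn(5, d)                           # 5 new tokens
slot_mapping = torch.tensor([0, 1, 5, 6, 10])      # flat slot per token

cache.view(-1, d).index_copy_(0, slot_mapping, k_li)
assert torch.equal(cache.view(-1, d)[5], k_li[2])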

View File

@@ -137,28 +137,6 @@
# Remove this patch if upstream provides an official NPU graph-capture
# guidance / auto-configuration path for HCCL.
#
# **8. File: platform/patch_kv_cache_interface.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.kv_cache_interface.MLAAttentionSpec`
# Why:
# The default `MLAAttentionSpec` is mainly built around `kv_lora_rank`
# and `qk_rope_head_dim`. On NPU, we also use this class to describe DSA
# models. Unlike the GPU path, where cache management is handled by an
# additional indexer module, extending this class directly simplifies the
# corresponding `model_runner` implementation on NPU.
#
# This patch also adds Sparse C8 support for DSA models on NPU. As part
# of that support, members such as `page_size_bytes` need to be adapted,
# so they are overridden here as well to preserve overall readability.
# How:
# This patch subclasses the original implementation, overrides selected
# methods, and adds DSA-specific attributes and helpers with default
# values where needed.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/25896
# Future Plan:
# Remove this patch after the upcoming KV cache spec refactor.
#
# * Worker Patch:
# ===============
#
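
The "How" above is the usual subclass-and-reassign patch pattern; a generic sketch with a toy module standing in for vllm.v1.kv_cache_interface (names here are made up for illustration):

import types

upstream = types.ModuleType("upstream")

class Spec:
    def page_size_bytes(self) -> int:
        return 4096

upstream.Spec = Spec

class PatchedSpec(Spec):
    extra_bytes: int = 128                 # added DSA-specific attribute
    def page_size_bytes(self) -> int:      # selectively overridden method
        return super().page_size_bytes() + self.extra_bytes

upstream.Spec = PatchedSpec  # later lookups through the module see the subclass
print(upstream.Spec().page_size_bytes())  # 4224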

View File

@@ -18,7 +18,6 @@ import os
import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
import vllm_ascend.patch.platform.patch_sched_yield # noqa

View File

@@ -1,138 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
import vllm.v1.kv_cache_interface
from typing_extensions import Self
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import MLAAttentionSpec
@dataclass(frozen=True)
class AscendMLAAttentionSpec(MLAAttentionSpec):
"""MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support.
When Sparse C8 is enabled, the KV cache tuple changes from
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)
to
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).
The semantic meaning of each KV cache entry is as follows:
1. kv_cache[0] stores kv_lora.
2. kv_cache[1] stores k_rope.
3. kv_cache[2] stores the key tensor from the indexer module.
4. kv_cache[3] stores the key scale tensor from the indexer module,
and exists only when Sparse C8 is enabled.
The main changes are as follows:
1. The key tensor from the indexer module stored in kv_cache[2] is
converted from bf16 to int8 to reduce memory usage. It is then
processed with int8 precision in Lightning_indexer computation
to improve computational efficiency.
2. The quantization scale of the key tensor in the indexer module
must also be stored for the Lightning_indexer_quant operator,
and is therefore saved in kv_cache[3].
"""
sparse_head_dim: tuple[int, ...] | None = None
cache_sparse_c8: bool = False
c8_k_cache_dtype: torch.dtype = torch.int8
c8_k_scale_cache_dtype: torch.dtype = torch.float16
@property
def page_size_bytes(self) -> int:
if self.cache_sparse_c8:
assert self.sparse_head_dim is not None
assert len(self.sparse_head_dim) == 3
num_heads_per_page = self.block_size * self.num_kv_heads
# kv_cache[0]: bfloat16, kv_cache[1]: bfloat16
kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2]
k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)
# kv_cache[2]: int8
index_head_dim = self.sparse_head_dim[-1]
indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)
# kv_cache[3]: float16
# since the scale is stored per token, head_dim is set to 1.
index_scale_head_dim = 1
indexer_k_scale_bytes = (
num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
)
return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes
return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)
@property
def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
"""
Compute the relative byte share of each KV cache entry.
Returns:
A tuple containing the ratios for:
- kv_cache[0]
- kv_cache[1]
- kv_cache[2]
- kv_cache[3] (None if Sparse C8 is disabled)
"""
assert self.sparse_head_dim is not None
def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]:
assert self.sparse_head_dim is not None
assert self.cache_sparse_c8 is True
kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim
factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype)
index_k_head_dim_virtual = index_k_head_dim // factor
assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype)
index_k_scale_head_dim_virtual = 1
return (
kv_lora_rank,
qk_rope_head_dim,
index_k_head_dim_virtual,
index_k_scale_head_dim_virtual,
)
if self.cache_sparse_c8:
virtual_dims = get_sparse_head_dim_virtual()
total_virtual_head_dim = sum(virtual_dims)
return (
total_virtual_head_dim / virtual_dims[0], # kv_cache[0]
total_virtual_head_dim / virtual_dims[1], # kv_cache[1]
total_virtual_head_dim / virtual_dims[2], # kv_cache[2]
total_virtual_head_dim / virtual_dims[3], # kv_cache[3]
)
return (
self.head_size / self.sparse_head_dim[0], # kv_cache[0]
self.head_size / self.sparse_head_dim[1], # kv_cache[1]
self.head_size / self.sparse_head_dim[2], # kv_cache[2]
None, # kv_cache[3] does not exist
)
@classmethod
def merge(cls, specs: list[Self]) -> Self:
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
"All attention layers in the same KV cache group must be MLAAttentionSpec."
)
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
assert len(cache_dtype_str_set) == 1, (
"All attention layers in the same KV cache group must use the same quantization method."
)
return cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
sparse_head_dim=specs[0].sparse_head_dim,
dtype=specs[0].dtype,
cache_dtype_str=cache_dtype_str_set.pop(),
cache_sparse_c8=specs[0].cache_sparse_c8,
)
vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec
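
As a worked example of page_size_bytes, assume DeepSeek-V3.2-style dimensions (kv_lora_rank=512, qk_rope_head_dim=64, index_head_dim=128; illustrative values only) with block_size=128, num_kv_heads=1, and a bf16 cache dtype:

block_size, num_kv_heads = 128, 1
kv_lora_rank, qk_rope_head_dim, index_head_dim = 512, 64, 128
heads_per_page = block_size * num_kv_heads

bf16, int8, fp16 = 2, 1, 2  # bytes per element
k_pe_nope_bytes = heads_per_page * (kv_lora_rank + qk_rope_head_dim) * bf16
indexer_k_bytes = heads_per_page * index_head_dim * int8
indexer_k_scale_bytes = heads_per_page * 1 * fp16  # one scale per token

print(k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes)  # 164096
# The bf16 indexer cache would cost heads_per_page * 128 * 2 = 32768 bytes;
# int8 plus the fp16 scale costs 16384 + 256 = 16640, roughly halving that entry.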

View File

@@ -88,7 +88,6 @@ from vllm.v1.worker.ubatch_utils import (
)
from vllm.v1.worker.utils import AttentionGroup
# yapf: enable
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention
@@ -101,6 +100,8 @@ from vllm_ascend.compilation.acl_graph import (
set_graph_params,
update_full_graph_params,
)
# yapf: enable
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
@@ -277,21 +278,7 @@ class NPUModelRunner(GPUModelRunner):
self.is_multimodal_model = self.model_config.is_multimodal_model
self.block_size = vllm_config.cache_config.block_size
# Set up Attention
self.use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk"
)
if self.use_sparse:
self.sparse_head_dim = (
self.model_config.hf_text_config.kv_lora_rank,
self.model_config.hf_text_config.qk_rope_head_dim,
self.model_config.hf_text_config.index_head_dim,
)
# dsa c8
self.use_sparse_c8_indexer = self.ascend_config.enable_sparse_c8
if self.use_sparse_c8_indexer:
self.c8_k_cache_dtype = torch.int8
self.c8_k_scale_cache_dtype = torch.float16
self.use_sparse = hasattr(self.vllm_config.model_config.hf_text_config, "index_topk")
self.attn_backend = get_attn_backend(
0,
self.dtype,
@@ -2642,7 +2629,7 @@ class NPUModelRunner(GPUModelRunner):
to their corresponding memory buffer for K cache and V cache.
"""
# init kv cache tensors
kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None | None] = {}
kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {}
# prefill disaggregation need the addr of cache tensor be aligned with 2M
alignment = 2 * 1024 * 1024
layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
@@ -2689,18 +2676,19 @@ class NPUModelRunner(GPUModelRunner):
+ self.model_config.hf_text_config.kv_lora_rank
)
dsa_k_cache_factor = None
dsa_k_cache_size = None
if not self.model_config.use_mla:
# for non-mla model, use FullAttentionSpec
k_tensor_split_factor = 2.0
v_tensor_split_factor = 2.0
k_tensor_split_factor = 2
v_tensor_split_factor = 2
elif self.use_sparse:
# for deepseek v3.2, we split the kv cache according to the corresponding ratio
kv_cache_spec = layer_kv_cache_spec[layer_name]
sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio
k_tensor_split_factor = sparse_kv_cache_ratio[0]
v_tensor_split_factor = sparse_kv_cache_ratio[1]
dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2]
dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [ # type: ignore
sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio()
]
dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor)
else:
# for other deepseek models, use MLAAttentionSpec
k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
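
The re-added math above splits the flat per-layer allocation in proportion to the head dims returned by _get_sparse_kv_cache_ratio. A worked example with the same illustrative DeepSeek-V3.2-style dims and a hypothetical 704 MiB tensor:

ratios = [512, 64, 128]          # kv_lora_rank, qk_rope_head_dim, index_head_dim
total = sum(ratios)              # 704, i.e. sparse_sum_head_size
tensor_size = 704 * 1024 * 1024  # stand-in for kv_cache_tensor.size

factors = [total / r for r in ratios]         # [1.375, 11.0, 5.5]
sizes = [int(tensor_size // f) for f in factors]
print(sizes)  # 512 MiB, 64 MiB and 128 MiB slices
assert sum(sizes) == tensor_size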
@@ -2708,56 +2696,35 @@ class NPUModelRunner(GPUModelRunner):
k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor)
v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor)
dsa_k_tensor_size = None
dsa_k_scale_tensor_size = None
#### for deepseek sparse attention
if self.use_sparse:
dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor)
if self.use_sparse_c8_indexer:
dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor)
# for other attentions, e.g., self_attn, sliding window attn
if self.vllm_config.kv_transfer_config is None:
k_tensor = torch.zeros(k_tensor_size, dtype=torch.int8, device=self.device)
v_tensor = torch.zeros(v_tensor_size, dtype=torch.int8, device=self.device)
#### for deepseek sparse attention
if dsa_k_tensor_size is not None:
dsa_k_tensor = torch.zeros(dsa_k_tensor_size, dtype=torch.int8, device=self.device)
if dsa_k_scale_tensor_size is not None:
dsa_k_scale_tensor = torch.zeros(
dsa_k_scale_tensor_size, dtype=torch.int8, device=self.device
)
#### k cache: for deepseek sparse attention
if dsa_k_cache_factor is not None:
dsa_k_cache_tensor = torch.zeros(dsa_k_cache_size, dtype=torch.int8, device=self.device)
else:
k_tensor = torch.zeros(k_tensor_size + alignment, dtype=torch.int8, device=self.device)
v_tensor = torch.zeros(v_tensor_size + alignment, dtype=torch.int8, device=self.device)
k_tensor = self._align_memory(k_tensor, alignment)[:k_tensor_size]
v_tensor = self._align_memory(v_tensor, alignment)[:v_tensor_size]
#### for deepseek sparse attention
if dsa_k_tensor_size is not None:
dsa_k_tensor = torch.zeros(
dsa_k_tensor_size + alignment, dtype=torch.int8, device=self.device
#### k cache: for deepseek sparse attention
if dsa_k_cache_factor is not None and dsa_k_cache_size is not None:
dsa_k_cache_tensor = torch.zeros(
dsa_k_cache_size + alignment, dtype=torch.int8, device=self.device
)
dsa_k_tensor = self._align_memory(dsa_k_tensor, alignment)[:dsa_k_tensor_size]
if dsa_k_scale_tensor_size is not None:
dsa_k_scale_tensor = torch.zeros(
dsa_k_scale_tensor_size + alignment, dtype=torch.int8, device=self.device
)
dsa_k_scale_tensor = self._align_memory(
dsa_k_scale_tensor, alignment
)[:dsa_k_scale_tensor_size]
dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size]
for layer_name_inner in kv_cache_tensor.shared_by:
# share the attn kv cache across all shared layers
if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
if self.use_sparse:
if self.use_sparse_c8_indexer:
kv_cache_raw_tensors[layer_name_inner] = (
k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor
)
else:
kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor, dsa_k_tensor)
else:
kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor)
kv_cache_raw_tensors[layer_name_inner] = (
(k_tensor, v_tensor)
if not self.use_sparse
else (k_tensor, v_tensor, dsa_k_cache_tensor)
)
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
@@ -2799,23 +2766,13 @@ class NPUModelRunner(GPUModelRunner):
# TODO: remove this after the OOM issue is located and fixed; otherwise, some models may
# encounter OOM issues
if isinstance(kv_cache_spec, AttentionSpec):
raw_dsa_k_tensor = None
if self.use_sparse:
if self.use_sparse_c8_indexer:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name]
assert raw_dsa_k_tensor is not None
assert raw_dsa_k_scale_tensor is not None
sum_page_size_bytes = (
raw_k_tensor.numel()
+ raw_v_tensor.numel()
+ raw_dsa_k_tensor.numel()
+ raw_dsa_k_scale_tensor.numel()
)
else:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name]
assert raw_dsa_k_tensor is not None
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name
]
assert raw_dsa_k_tensor is not None
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
elif self.use_hybrid_blocks and self.hybrid_with_attn_and_mamba:
# Currently, we ensure that the same kvcache format is used even if there
# is no shared layer, such as the full attention mtp layer of qwen3.5, etc.
@@ -2862,7 +2819,7 @@ class NPUModelRunner(GPUModelRunner):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
)
dtype = kv_cache_spec.dtype
if not self.model_config.use_mla:
k_shape = kv_cache_shape[1:]
v_shape = k_shape
@@ -2881,37 +2838,19 @@ class NPUModelRunner(GPUModelRunner):
num_kv_heads,
self.model_config.hf_text_config.qk_rope_head_dim,
]
k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape)
v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape)
k_cache = raw_k_tensor.view(dtype).view(k_shape)
v_cache = raw_v_tensor.view(dtype).view(v_shape)
if self.use_sparse:
if self.use_sparse and raw_dsa_k_tensor is not None:
index_head_dim = self._get_sparse_kv_cache_ratio()[-1]
dsa_k_cache_shape = (
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
self.model_config.hf_text_config.index_head_dim,
index_head_dim,
)
if self.use_sparse_c8_indexer:
# dsa_k
dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape)
# dsa_k_scale
dsa_k_scale_cache_shape = (
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
1,
)
assert raw_dsa_k_scale_tensor is not None
dsa_k_scale_cache = (
raw_dsa_k_scale_tensor
.view(self.c8_k_scale_cache_dtype)
.view(dsa_k_scale_cache_shape)
)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache)
else:
# dsa_k
dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
else:
kv_caches[layer_name] = (k_cache, v_cache)
elif isinstance(kv_cache_spec, MambaSpec):
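
The raw_*.view(dtype).view(shape) chain above reinterprets flat int8 byte buffers as typed caches, so one aligned allocation backs every cache entry. A minimal sketch:

import torch

num_blocks, block_size, num_kv_heads, head_dim = 2, 4, 1, 8
raw = torch.zeros(num_blocks * block_size * num_kv_heads * head_dim * 2,
                  dtype=torch.int8)  # *2: two bytes per bf16 element

cache = raw.view(torch.bfloat16).view(num_blocks, block_size, num_kv_heads, head_dim)
print(cache.shape, cache.dtype)  # torch.Size([2, 4, 1, 8]) torch.bfloat16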
@@ -3007,7 +2946,7 @@ class NPUModelRunner(GPUModelRunner):
# of mamba block. In this case, BlockTable.block_size will never equal
# to kernel_block_sizes[0]
self.kernel_block_sizes.append([0])
max_num_blocks = []
max_model_len = max(self.max_model_len, self.max_encoder_len)
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
@@ -3021,7 +2960,7 @@ class NPUModelRunner(GPUModelRunner):
max_num_blocks_per_req = max(max_num_blocks_per_req, mamba_blocks_per_req)
max_num_blocks.append(max_num_blocks_per_req)
if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
@@ -3181,31 +3120,18 @@ class NPUModelRunner(GPUModelRunner):
elif isinstance(attn_module, MLAAttention):
if self.use_sparse:
# `MLAAttentionSpec` is temporarily patched to `AscendMLAAttentionSpec`.
# Re-importing it at runtime will therefore resolve to the patched class.
# Rename it here to make this behavior explicit.
from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
# TODO(rjg-lyh): when kv_cache_spec's refactor is ready,
# implement it by creating a new kv_cache_spec class
kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
# TODO(cmq): This is a hacky way to fix the deepseek kv cache when
# using DSA. Fixing the spec in vLLM is the final solution.
sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=self.block_size,
num_kv_heads=1,
head_size=sum(self.sparse_head_dim),
sparse_head_dim=self.sparse_head_dim,
head_size=sparse_sum_head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=self.vllm_config.cache_config.cache_dtype,
cache_sparse_c8=self.use_sparse_c8_indexer,
)
elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
assert isinstance(spec, MLAAttentionSpec)
from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
cache_dtype_str=spec.cache_dtype_str,
)
kv_cache_spec[layer_name] = spec
elif isinstance(attn_module, MambaBase):
mamba_layers[layer_name] = attn_module
@@ -3223,6 +3149,16 @@ class NPUModelRunner(GPUModelRunner):
return kv_cache_spec
def _get_sparse_kv_cache_ratio(self) -> list[int]:
# TODO: If C8 is supported, we need to consider the number of bytes occupied by different dtypes
# when calculating the ratio, for example:
# [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...]
return [
self.model_config.hf_text_config.kv_lora_rank,
self.model_config.hf_text_config.qk_rope_head_dim,
self.model_config.hf_text_config.index_head_dim,
]
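
A hedged illustration of the TODO above: if C8 came back, each entry of the ratio would be weighted by its cache dtype's byte width (dims are the illustrative DeepSeek-V3.2-style values, dtypes follow the kv_cache layout documented in the deleted AscendMLAAttentionSpec):

import torch

kv_lora_rank, qk_rope_head_dim, index_head_dim = 512, 64, 128  # assumed dims

byte_weighted_ratio = [
    kv_lora_rank * torch.bfloat16.itemsize,      # kv_cache[0]: bf16
    qk_rope_head_dim * torch.bfloat16.itemsize,  # kv_cache[1]: bf16
    index_head_dim * torch.int8.itemsize,        # kv_cache[2]: int8 under C8
    1 * torch.float16.itemsize,                  # kv_cache[3]: per-token fp16 scale
]
print(byte_weighted_ratio)  # [1024, 128, 128, 2]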
def _check_and_update_cudagraph_mode(
self,
attention_backends: list[set[type[AttentionBackend]]],