[Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)
### What this PR does / why we need it?
This PR supports W8A8C8 in dsv3.2/glm5 with the lightning_indexer_quant
ops, mainly in the PD-mix stage.
Because the code for the current PD-disaggregated scenario is still
under refactoring and cleanup, this PR prioritizes ensuring the C8
functionality in the pd-mix scenario.
The next steps are planned in three parts:
① Once the optimized scatter operator is updated, we will replace the
original operator to improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable,
we will carry out more comprehensive validation and make appropriate
adaptations.
③ Because enabling C8 currently introduces several new operators whose
performance still needs improvement, performance may regress in some
scenarios. Therefore, only after all the operators are fully ready can
we ensure that this feature does not cause any performance degradation.
At that point, we will enable this feature by default and remove the
switch in `additional_config`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -134,9 +134,12 @@ class AscendConfig:
|
||||
bool(additional_config.get("enable_async_exponential", False)) and not vllm_is_batch_invariant()
|
||||
)
|
||||
|
||||
use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
|
||||
vllm_config.model_config.hf_text_config, "index_topk"
|
||||
)
|
||||
|
||||
self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
|
||||
if self.enable_kv_nz:
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
|
||||
if not vllm_config.model_config.is_deepseek_mla or use_sparse:
|
||||
raise RuntimeError("enable_kv_nz is only supported for mla currently.")
|
||||
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
@@ -144,6 +147,17 @@ class AscendConfig:
|
||||
"enable_kv_nz is only supported in pd scenario and can only be used in D node."
|
||||
)
|
||||
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
# Disable Sparse C8 for A5
|
||||
# A5 has not been fully validated for this path and may carry hidden risks.
|
||||
# TODO(rjg-lyh): Enable A5 support after sufficient validation.
|
||||
self.enable_sparse_c8 = (
|
||||
additional_config.get("enable_sparse_c8", False)
|
||||
and use_sparse
|
||||
and get_ascend_device_type() != AscendDeviceType.A5
|
||||
)
|
||||
|
||||
def _construct_weight_prefetch_config(self, additional_config):
|
||||
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
|
||||
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import scipy # type: ignore
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
@@ -355,6 +356,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
|
||||
o_proj_full_pool: torch.Tensor | None = None
|
||||
|
||||
# qk_hadamard tensor shared when dsa c8 enabled
|
||||
qk_hadamard: torch.Tensor | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
@@ -425,6 +429,12 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.is_rope_neox_style = False
|
||||
self.use_torch_npu_lightning_indexer = True
|
||||
|
||||
# dsa c8
|
||||
self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
|
||||
if self.use_sparse_c8_indexer:
|
||||
self.c8_k_cache_dtype = torch.int8
|
||||
self.c8_k_scale_cache_dtype = torch.float16
|
||||
|
||||
# Effective in SFA when FlashComm is enabled.
|
||||
self.enable_dsa_cp = enable_dsa_cp()
|
||||
|
||||
@@ -515,6 +525,11 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# if mlapo, W_UK_T can't trans nz
|
||||
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
|
||||
|
||||
if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
|
||||
AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
|
||||
128**0.5
|
||||
)
|
||||
|
||||
# Processing the input parameters for MLAPO by reordering and transposing
|
||||
# QKV(and part of Q) weight, applying RoPE-related dimension transformations,
|
||||
# and handling quantization parameters.
|
||||
@@ -874,7 +889,15 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128]
|
||||
|
||||
return k_li
|
||||
if self.use_sparse_c8_indexer:
|
||||
k_li = k_li @ AscendSFAImpl.qk_hadamard
|
||||
k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
|
||||
k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype) # [b*s,]
|
||||
k_li_scale = k_li_scale.unsqueeze(-1) # [b*s,1]
|
||||
else:
|
||||
k_li_scale = None
|
||||
|
||||
return k_li, k_li_scale
|
||||
|
||||
def indexer_select_post_process(
|
||||
self,
|
||||
@@ -905,10 +928,35 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
q_li_pe = q_li_pe.squeeze(2)
|
||||
q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128]
|
||||
|
||||
if self.use_sparse_c8_indexer:
|
||||
q_li_shape_ori = q_li.shape
|
||||
q_li = q_li @ AscendSFAImpl.qk_hadamard
|
||||
q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
|
||||
q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)
|
||||
|
||||
# DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer.
|
||||
# So two branches are maintained temporarily.
|
||||
# TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
|
||||
if self.use_torch_npu_lightning_indexer:
|
||||
if self.use_sparse_c8_indexer:
|
||||
assert len(kv_cache) == 4
|
||||
weights = weights.to(torch.float16)
|
||||
topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant(
|
||||
query=q_li.view(q_li_shape_ori),
|
||||
key=kv_cache[2],
|
||||
weights=weights,
|
||||
query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]),
|
||||
key_dequant_scale=kv_cache[3].squeeze(2), # B S N D -> B S D
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
block_table=attn_metadata.block_table,
|
||||
query_quant_mode=0,
|
||||
key_quant_mode=0,
|
||||
layout_query="TND",
|
||||
layout_key="PA_BSND",
|
||||
sparse_count=2048,
|
||||
sparse_mode=3,
|
||||
)
|
||||
elif self.use_torch_npu_lightning_indexer:
|
||||
topk_indices, _ = torch_npu.npu_lightning_indexer(
|
||||
query=q_li,
|
||||
key=kv_cache[2],
|
||||
@@ -1031,7 +1079,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
|
||||
k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
|
||||
k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
|
||||
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
|
||||
@@ -1044,20 +1092,46 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
if self.enable_dsa_cp:
|
||||
assert k_pe is not None
|
||||
assert k_nope is not None
|
||||
assert k_li is not None
|
||||
async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
|
||||
# support all_gather kv async for communication calculation overlap
|
||||
fused_kv_no_split, kv_ag_handle = all_gather_async(
|
||||
torch.cat(
|
||||
[
|
||||
k_pe.view(-1, k_pe.shape[-1]),
|
||||
k_nope.view(-1, k_nope.shape[-1]),
|
||||
k_li.view(-1, k_li.shape[-1]),
|
||||
],
|
||||
dim=1,
|
||||
),
|
||||
get_tp_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
if not self.use_sparse_c8_indexer:
|
||||
fused_kv_no_split, kv_ag_handle = all_gather_async(
|
||||
torch.cat(
|
||||
[
|
||||
k_pe.view(-1, k_pe.shape[-1]),
|
||||
k_nope.view(-1, k_nope.shape[-1]),
|
||||
k_li.view(-1, k_li.shape[-1]),
|
||||
],
|
||||
dim=1,
|
||||
),
|
||||
get_tp_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
else:
|
||||
# due to different dtypes, we have to split commu pass
|
||||
assert k_li_scale is not None
|
||||
fused_kv_no_split, _ = all_gather_async(
|
||||
torch.cat(
|
||||
[
|
||||
k_pe.view(-1, k_pe.shape[-1]),
|
||||
k_nope.view(-1, k_nope.shape[-1]),
|
||||
],
|
||||
dim=1,
|
||||
),
|
||||
get_tp_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
k_li, _ = all_gather_async(
|
||||
k_li,
|
||||
get_tp_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
k_li_scale, kv_ag_handle = all_gather_async(
|
||||
k_li_scale,
|
||||
get_tp_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
|
||||
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
||||
q_pe = self.rope_single(q_pe, cos, sin)
|
||||
@@ -1077,9 +1151,12 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
if kv_cache is not None:
|
||||
assert fused_kv_no_split is not None
|
||||
k_pe, k_nope, k_li = fused_kv_no_split.split(
|
||||
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
|
||||
)
|
||||
if not self.use_sparse_c8_indexer:
|
||||
k_pe, k_nope, k_li = fused_kv_no_split.split(
|
||||
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
|
||||
)
|
||||
else:
|
||||
k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1)
|
||||
k_nope = k_nope.view(k_nope.shape[0], 1, -1)
|
||||
k_pe = k_pe.view(k_pe.shape[0], 1, -1)
|
||||
DeviceOperator.reshape_and_cache(
|
||||
@@ -1098,6 +1175,13 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1])
|
||||
) # b, s, n, d
|
||||
if self.use_sparse_c8_indexer:
|
||||
assert len(kv_cache) == 4
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[3].view(-1, k_li_scale.shape[-1]),
|
||||
slot_mapping.view(-1, 1),
|
||||
k_li_scale.view(-1, k_li_scale.shape[-1]),
|
||||
)
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
|
||||
|
||||
@@ -137,6 +137,28 @@
|
||||
# Remove this patch if upstream provides an official NPU graph-capture
|
||||
# guidance / auto-configuration path for HCCL.
|
||||
#
|
||||
# ** 8. File: platform/patch_kv_cache_interface.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.v1.kv_cache_interface.MLAAttentionSpec`
|
||||
# Why:
|
||||
# The default `MLAAttentionSpec` is mainly built around `kv_lora_rank`
|
||||
# and `qk_rope_head_dim`. On NPU, we also use this class to describe DSA
|
||||
# models. Unlike the GPU path, where cache management is handled by an
|
||||
# additional indexer module, extending this class directly simplifies the
|
||||
# corresponding `model_runner` implementation on NPU.
|
||||
#
|
||||
# This patch also adds Sparse C8 support for DSA models on NPU. As part
|
||||
# of that support, members such as `page_size_bytes` need to be adapted,
|
||||
# so they are overridden here as well to preserve overall readability.
|
||||
# How:
|
||||
# This patch subclasses the original implementation, overrides selected
|
||||
# methods, and adds DSA-specific attributes and helpers with default
|
||||
# values where needed.
|
||||
# Related PR (if no, explain why):
|
||||
# https://github.com/vllm-project/vllm/pull/25896
|
||||
# Future Plan:
|
||||
# Remove this patch after the upcoming KV cache spec refactor.
|
||||
#
|
||||
# * Worker Patch:
|
||||
# ===============
|
||||
#
|
||||
|
||||
@@ -18,6 +18,7 @@ import os
|
||||
|
||||
import vllm_ascend.patch.platform.patch_distributed # noqa
|
||||
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
|
||||
import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa
|
||||
import vllm_ascend.patch.platform.patch_mamba_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_sched_yield # noqa
|
||||
|
||||
138
vllm_ascend/patch/platform/patch_kv_cache_interface.py
Normal file
138
vllm_ascend/patch/platform/patch_kv_cache_interface.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import vllm.v1.kv_cache_interface
|
||||
from typing_extensions import Self
|
||||
from vllm.utils.torch_utils import get_dtype_size
|
||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AscendMLAAttentionSpec(MLAAttentionSpec):
|
||||
"""MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support.
|
||||
|
||||
When Sparse C8 is enabled, the KV cache tuple changes from
|
||||
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)
|
||||
to
|
||||
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).
|
||||
|
||||
The semantic meaning of each KV cache entry is as follows:
|
||||
1. kv_cache[0] stores kv_lora.
|
||||
2. kv_cache[1] stores k_rope.
|
||||
3. kv_cache[2] stores the key tensor from the indexer module.
|
||||
4. kv_cache[3] stores the key scale tensor from the indexer module,
|
||||
and exists only when Sparse C8 is enabled.
|
||||
|
||||
The main changes are as follows:
|
||||
1. The key tensor from the indexer module stored in kv_cache[2] is
|
||||
converted from bf16 to int8 to reduce memory usage. It is then
|
||||
processed with int8 precision in Lightning_indexer computation
|
||||
to improve computational efficiency.
|
||||
2. The quantization scale of the key tensor in the indexer module
|
||||
must also be stored for the Lightning_indexer_quant operator,
|
||||
and is therefore saved in kv_cache[3].
|
||||
"""
|
||||
|
||||
sparse_head_dim: tuple[int, ...] | None = None
|
||||
cache_sparse_c8: bool = False
|
||||
c8_k_cache_dtype: torch.dtype = torch.int8
|
||||
c8_k_scale_cache_dtype: torch.dtype = torch.float16
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
if self.cache_sparse_c8:
|
||||
assert self.sparse_head_dim is not None
|
||||
assert len(self.sparse_head_dim) == 3
|
||||
num_heads_per_page = self.block_size * self.num_kv_heads
|
||||
# kv_cache[0]: bfloat16, kv_cache[1]: bfloat16
|
||||
kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2]
|
||||
k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)
|
||||
# kv_cache[2]: int8
|
||||
index_head_dim = self.sparse_head_dim[-1]
|
||||
indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)
|
||||
# kv_cache[3]: float16
|
||||
# since the scale is stored per token, head_dim is set to 1.
|
||||
index_scale_head_dim = 1
|
||||
indexer_k_scale_bytes = (
|
||||
num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
|
||||
)
|
||||
return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes
|
||||
|
||||
return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)
|
||||
|
||||
@property
|
||||
def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
|
||||
"""
|
||||
Compute the relative byte share of each KV cache entry.
|
||||
|
||||
Returns:
|
||||
A tuple containing the ratios for:
|
||||
- kv_cache[0]
|
||||
- kv_cache[1]
|
||||
- kv_cache[2]
|
||||
- kv_cache[3] (None if Sparse C8 is disabled)
|
||||
"""
|
||||
|
||||
assert self.sparse_head_dim is not None
|
||||
|
||||
def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]:
|
||||
assert self.sparse_head_dim is not None
|
||||
assert self.cache_sparse_c8 is True
|
||||
|
||||
kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim
|
||||
|
||||
factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype)
|
||||
index_k_head_dim_virtual = index_k_head_dim // factor
|
||||
|
||||
assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype)
|
||||
index_k_scale_head_dim_virtual = 1
|
||||
|
||||
return (
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
index_k_head_dim_virtual,
|
||||
index_k_scale_head_dim_virtual,
|
||||
)
|
||||
|
||||
if self.cache_sparse_c8:
|
||||
virtual_dims = get_sparse_head_dim_virtual()
|
||||
total_virtual_head_dim = sum(virtual_dims)
|
||||
|
||||
return (
|
||||
total_virtual_head_dim / virtual_dims[0], # kv_cache[0]
|
||||
total_virtual_head_dim / virtual_dims[1], # kv_cache[1]
|
||||
total_virtual_head_dim / virtual_dims[2], # kv_cache[2]
|
||||
total_virtual_head_dim / virtual_dims[3], # kv_cache[3]
|
||||
)
|
||||
|
||||
return (
|
||||
self.head_size / self.sparse_head_dim[0], # kv_cache[0]
|
||||
self.head_size / self.sparse_head_dim[1], # kv_cache[1]
|
||||
self.head_size / self.sparse_head_dim[2], # kv_cache[2]
|
||||
None, # kv_cache[3] does not exist
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be MLAAttentionSpec."
|
||||
)
|
||||
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
|
||||
assert len(cache_dtype_str_set) == 1, (
|
||||
"All attention layers in the same KV cache group must use the same quantization method."
|
||||
)
|
||||
return cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
sparse_head_dim=specs[0].sparse_head_dim,
|
||||
dtype=specs[0].dtype,
|
||||
cache_dtype_str=cache_dtype_str_set.pop(),
|
||||
cache_sparse_c8=specs[0].cache_sparse_c8,
|
||||
)
|
||||
|
||||
|
||||
vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec
|
||||
@@ -84,6 +84,7 @@ from vllm.v1.worker.ubatch_utils import (
|
||||
)
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
# yapf: enable
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention
|
||||
@@ -96,8 +97,6 @@ from vllm_ascend.compilation.acl_graph import (
|
||||
set_graph_params,
|
||||
update_full_graph_params,
|
||||
)
|
||||
|
||||
# yapf: enable
|
||||
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
|
||||
from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
|
||||
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
||||
@@ -274,7 +273,21 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
# Set up Attention
|
||||
self.use_sparse = hasattr(self.vllm_config.model_config.hf_text_config, "index_topk")
|
||||
self.use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
|
||||
vllm_config.model_config.hf_text_config, "index_topk"
|
||||
)
|
||||
if self.use_sparse:
|
||||
self.sparse_head_dim = (
|
||||
self.model_config.hf_text_config.kv_lora_rank,
|
||||
self.model_config.hf_text_config.qk_rope_head_dim,
|
||||
self.model_config.hf_text_config.index_head_dim,
|
||||
)
|
||||
# dsa c8
|
||||
self.use_sparse_c8_indexer = self.ascend_config.enable_sparse_c8
|
||||
if self.use_sparse_c8_indexer:
|
||||
self.c8_k_cache_dtype = torch.int8
|
||||
self.c8_k_scale_cache_dtype = torch.float16
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
0,
|
||||
self.dtype,
|
||||
@@ -2623,7 +2636,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
to their corresponding memory buffer for K cache and V cache.
|
||||
"""
|
||||
# init kv cache tensors
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {}
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None | None] = {}
|
||||
# prefill disaggregation need the addr of cache tensor be aligned with 2M
|
||||
alignment = 2 * 1024 * 1024
|
||||
layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
@@ -2670,19 +2683,18 @@ class NPUModelRunner(GPUModelRunner):
|
||||
+ self.model_config.hf_text_config.kv_lora_rank
|
||||
)
|
||||
|
||||
dsa_k_cache_factor = None
|
||||
dsa_k_cache_size = None
|
||||
if not self.model_config.use_mla:
|
||||
# for non-mla model, use FullAttentionSpec
|
||||
k_tensor_split_factor = 2
|
||||
v_tensor_split_factor = 2
|
||||
k_tensor_split_factor = 2.0
|
||||
v_tensor_split_factor = 2.0
|
||||
elif self.use_sparse:
|
||||
# for deepseek v3.2, we split the kv cache according to the corresponding ratio
|
||||
sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
|
||||
k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [ # type: ignore
|
||||
sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio()
|
||||
]
|
||||
dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor)
|
||||
kv_cache_spec = layer_kv_cache_spec[layer_name]
|
||||
sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio
|
||||
k_tensor_split_factor = sparse_kv_cache_ratio[0]
|
||||
v_tensor_split_factor = sparse_kv_cache_ratio[1]
|
||||
dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2]
|
||||
dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
|
||||
else:
|
||||
# for other deepseek models, use MLAAttentionSpec
|
||||
k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
|
||||
@@ -2690,35 +2702,56 @@ class NPUModelRunner(GPUModelRunner):
|
||||
|
||||
k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor)
|
||||
v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor)
|
||||
dsa_k_tensor_size = None
|
||||
dsa_k_scale_tensor_size = None
|
||||
#### for deepseek sparse attention
|
||||
if self.use_sparse:
|
||||
dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor)
|
||||
if self.use_sparse_c8_indexer:
|
||||
dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor)
|
||||
|
||||
# for other attentions, e.g., self_attn, sliding window attn
|
||||
if self.vllm_config.kv_transfer_config is None:
|
||||
k_tensor = torch.zeros(k_tensor_size, dtype=torch.int8, device=self.device)
|
||||
v_tensor = torch.zeros(v_tensor_size, dtype=torch.int8, device=self.device)
|
||||
#### k cache: for deepseek sparse attention
|
||||
if dsa_k_cache_factor is not None:
|
||||
dsa_k_cache_tensor = torch.zeros(dsa_k_cache_size, dtype=torch.int8, device=self.device)
|
||||
#### for deepseek sparse attention
|
||||
if dsa_k_tensor_size is not None:
|
||||
dsa_k_tensor = torch.zeros(dsa_k_tensor_size, dtype=torch.int8, device=self.device)
|
||||
if dsa_k_scale_tensor_size is not None:
|
||||
dsa_k_scale_tensor = torch.zeros(
|
||||
dsa_k_scale_tensor_size, dtype=torch.int8, device=self.device
|
||||
)
|
||||
else:
|
||||
k_tensor = torch.zeros(k_tensor_size + alignment, dtype=torch.int8, device=self.device)
|
||||
v_tensor = torch.zeros(v_tensor_size + alignment, dtype=torch.int8, device=self.device)
|
||||
k_tensor = self._align_memory(k_tensor, alignment)[:k_tensor_size]
|
||||
v_tensor = self._align_memory(v_tensor, alignment)[:v_tensor_size]
|
||||
#### k cache: for deepseek sparse attention
|
||||
if dsa_k_cache_factor is not None and dsa_k_cache_size is not None:
|
||||
dsa_k_cache_tensor = torch.zeros(
|
||||
dsa_k_cache_size + alignment, dtype=torch.int8, device=self.device
|
||||
#### for deepseek sparse attention
|
||||
if dsa_k_tensor_size is not None:
|
||||
dsa_k_tensor = torch.zeros(
|
||||
dsa_k_tensor_size + alignment, dtype=torch.int8, device=self.device
|
||||
)
|
||||
dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size]
|
||||
dsa_k_tensor = self._align_memory(dsa_k_tensor, alignment)[:dsa_k_tensor_size]
|
||||
if dsa_k_scale_tensor_size is not None:
|
||||
dsa_k_scale_tensor = torch.zeros(
|
||||
dsa_k_scale_tensor_size + alignment, dtype=torch.int8, device=self.device
|
||||
)
|
||||
dsa_k_scale_tensor = self._align_memory(
|
||||
dsa_k_scale_tensor, alignment
|
||||
)[:dsa_k_scale_tensor_size]
|
||||
|
||||
for layer_name_inner in kv_cache_tensor.shared_by:
|
||||
# shared the attn kvcache for all shared layers
|
||||
if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
|
||||
kv_cache_raw_tensors[layer_name_inner] = (
|
||||
(k_tensor, v_tensor)
|
||||
if not self.use_sparse
|
||||
else (k_tensor, v_tensor, dsa_k_cache_tensor)
|
||||
)
|
||||
|
||||
if self.use_sparse:
|
||||
if self.use_sparse_c8_indexer:
|
||||
kv_cache_raw_tensors[layer_name_inner] = (
|
||||
k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor
|
||||
)
|
||||
else:
|
||||
kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor, dsa_k_tensor)
|
||||
else:
|
||||
kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor)
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
@@ -2760,13 +2793,23 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
||||
# encounter OOM issue
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
raw_dsa_k_tensor = None
|
||||
if self.use_sparse:
|
||||
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
|
||||
layer_name
|
||||
]
|
||||
assert raw_dsa_k_tensor is not None
|
||||
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
|
||||
if self.use_sparse_c8_indexer:
|
||||
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore
|
||||
layer_name]
|
||||
assert raw_dsa_k_tensor is not None
|
||||
assert raw_dsa_k_scale_tensor is not None
|
||||
sum_page_size_bytes = (
|
||||
raw_k_tensor.numel()
|
||||
+ raw_v_tensor.numel()
|
||||
+ raw_dsa_k_tensor.numel()
|
||||
+ raw_dsa_k_scale_tensor.numel()
|
||||
)
|
||||
else:
|
||||
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
|
||||
layer_name]
|
||||
assert raw_dsa_k_tensor is not None
|
||||
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
|
||||
elif self.use_hybrid_blocks and self.hybrid_with_attn_and_mamba:
|
||||
# Currently, we ensure that the same kvcache format is used even if there
|
||||
# is no shared layer, such as the full attention mtp layer of qwen3.5, etc.
|
||||
@@ -2813,7 +2856,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
|
||||
)
|
||||
dtype = kv_cache_spec.dtype
|
||||
|
||||
if not self.model_config.use_mla:
|
||||
k_shape = kv_cache_shape[1:]
|
||||
v_shape = k_shape
|
||||
@@ -2832,19 +2875,37 @@ class NPUModelRunner(GPUModelRunner):
|
||||
num_kv_heads,
|
||||
self.model_config.hf_text_config.qk_rope_head_dim,
|
||||
]
|
||||
k_cache = raw_k_tensor.view(dtype).view(k_shape)
|
||||
v_cache = raw_v_tensor.view(dtype).view(v_shape)
|
||||
k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape)
|
||||
v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape)
|
||||
|
||||
if self.use_sparse and raw_dsa_k_tensor is not None:
|
||||
index_head_dim = self._get_sparse_kv_cache_ratio()[-1]
|
||||
if self.use_sparse:
|
||||
dsa_k_cache_shape = (
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
index_head_dim,
|
||||
self.model_config.hf_text_config.index_head_dim,
|
||||
)
|
||||
dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape)
|
||||
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
|
||||
if self.use_sparse_c8_indexer:
|
||||
# dsa_k
|
||||
dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape)
|
||||
# dsa_k_scale
|
||||
dsa_k_scale_cache_shape = (
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
1,
|
||||
)
|
||||
assert raw_dsa_k_scale_tensor is not None
|
||||
dsa_k_scale_cache = (
|
||||
raw_dsa_k_scale_tensor
|
||||
.view(self.c8_k_scale_cache_dtype)
|
||||
.view(dsa_k_scale_cache_shape)
|
||||
)
|
||||
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache)
|
||||
else:
|
||||
# dsa_k
|
||||
dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape)
|
||||
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
|
||||
else:
|
||||
kv_caches[layer_name] = (k_cache, v_cache)
|
||||
elif isinstance(kv_cache_spec, MambaSpec):
|
||||
@@ -3098,18 +3159,31 @@ class NPUModelRunner(GPUModelRunner):
|
||||
|
||||
elif isinstance(attn_module, MLAAttention):
|
||||
if self.use_sparse:
|
||||
# TODO(cmq): This is a hack way to fix deepseek kvcache when
|
||||
# using DSA. Fix the spec in vLLM is the final way.
|
||||
sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
|
||||
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
||||
# `MLAAttentionSpec` is temporarily patched to `AscendMLAAttentionSpec`.
|
||||
# Re-importing it at runtime will therefore resolve to the patched class.
|
||||
# Rename it here to make this behavior explicit.
|
||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
|
||||
# TODO(rjg-lyh): when kv_cache_spec's refactor is ready,
|
||||
# implement it by creating a new kv_cache_spec class
|
||||
kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
|
||||
block_size=self.block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=sparse_sum_head_size,
|
||||
head_size=sum(self.sparse_head_dim),
|
||||
sparse_head_dim=self.sparse_head_dim,
|
||||
dtype=self.kv_cache_dtype,
|
||||
cache_dtype_str=self.vllm_config.cache_config.cache_dtype,
|
||||
cache_sparse_c8=self.use_sparse_c8_indexer,
|
||||
)
|
||||
elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
|
||||
kv_cache_spec[layer_name] = spec
|
||||
assert isinstance(spec, MLAAttentionSpec)
|
||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
|
||||
kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
|
||||
block_size=spec.block_size,
|
||||
num_kv_heads=spec.num_kv_heads,
|
||||
head_size=spec.head_size,
|
||||
dtype=spec.dtype,
|
||||
cache_dtype_str=spec.cache_dtype_str,
|
||||
)
|
||||
|
||||
elif isinstance(attn_module, MambaBase):
|
||||
mamba_layers[layer_name] = attn_module
|
||||
@@ -3129,16 +3203,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
def _get_sparse_kv_cache_ratio(self) -> list[int]:
|
||||
# TODO:If C8 is supported, we need to consider the number of bytes occupied by different dtypes
|
||||
# when calculating the ratio,for example:
|
||||
# [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...]
|
||||
return [
|
||||
self.model_config.hf_text_config.kv_lora_rank,
|
||||
self.model_config.hf_text_config.qk_rope_head_dim,
|
||||
self.model_config.hf_text_config.index_head_dim,
|
||||
]
|
||||
|
||||
def _check_and_update_cudagraph_mode(
|
||||
self,
|
||||
attention_backends: list[set[type[AttentionBackend]]],
|
||||
|
||||
Reference in New Issue
Block a user