xc-llm-ascend/vllm_ascend/patch/platform/patch_kv_cache_interface.py
rjg-lyh 7ed9e9de69 [Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)
### What this PR does / why we need it?
This PR mainly adds W8A8C8 support for dsv3.2/glm5, using the
lightning_indexer_quant ops, in the PD-mix stage.

Because the code for the PD-disaggregated scenario is still under
refactoring and cleanup, this PR prioritizes C8 functionality in the
PD-mix scenario.

The next steps are planned as follows:
① Once the optimized scatter operator is updated, we will replace the
original operator to improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable,
we will carry out more comprehensive validation and make appropriate
adaptations.
③ Because enabling C8 currently introduces several new operators whose
performance still needs improvement, performance may regress in some
scenarios. Therefore, only after all the operators are fully ready can
we ensure that this feature does not cause any performance degradation.
At that point, we will enable this feature by default and remove the
switch in `additional_config`.
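
For reference, the sketch below shows roughly how such a switch would be passed
through `additional_config` when launching vLLM on Ascend; the key name
`enable_sparse_c8` is a hypothetical placeholder for the actual gate added by
this PR, and the model path is illustrative.

```python
from vllm import LLM

# Hypothetical sketch: turning the Sparse C8 path on via additional_config.
# "enable_sparse_c8" is a placeholder key, not the confirmed option name;
# the model path is illustrative and should point to a W8A8-quantized checkpoint.
llm = LLM(
    model="/path/to/DeepSeek-V3.2-W8A8",
    quantization="ascend",  # typically required for Ascend W8A8 checkpoints
    additional_config={"enable_sparse_c8": True},
)
```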


### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.16.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
2026-03-13 14:47:42 +08:00


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch
import vllm.v1.kv_cache_interface
from typing_extensions import Self
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import MLAAttentionSpec


@dataclass(frozen=True)
class AscendMLAAttentionSpec(MLAAttentionSpec):
    """MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support.

    When Sparse C8 is enabled, the KV cache tuple changes from
        (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)
    to
        (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).

    The semantic meaning of each KV cache entry is as follows:
    1. kv_cache[0] stores kv_lora.
    2. kv_cache[1] stores k_rope.
    3. kv_cache[2] stores the key tensor from the indexer module.
    4. kv_cache[3] stores the key scale tensor from the indexer module,
       and exists only when Sparse C8 is enabled.

    The main changes are as follows:
    1. The key tensor from the indexer module stored in kv_cache[2] is
       converted from bf16 to int8 to reduce memory usage. It is then
       processed with int8 precision in Lightning_indexer computation
       to improve computational efficiency.
    2. The quantization scale of the key tensor in the indexer module
       must also be stored for the Lightning_indexer_quant operator,
       and is therefore saved in kv_cache[3].
    """

    sparse_head_dim: tuple[int, ...] | None = None
    cache_sparse_c8: bool = False
    c8_k_cache_dtype: torch.dtype = torch.int8
    c8_k_scale_cache_dtype: torch.dtype = torch.float16

    @property
    def page_size_bytes(self) -> int:
        if self.cache_sparse_c8:
            assert self.sparse_head_dim is not None
            assert len(self.sparse_head_dim) == 3
            num_heads_per_page = self.block_size * self.num_kv_heads
            # kv_cache[0]: bfloat16, kv_cache[1]: bfloat16
            kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2]
            k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)
            # kv_cache[2]: int8
            index_head_dim = self.sparse_head_dim[-1]
            indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)
            # kv_cache[3]: float16
            # since the scale is stored per token, head_dim is set to 1.
            index_scale_head_dim = 1
            indexer_k_scale_bytes = (
                num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
            )
            return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes
        return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)

    @property
    def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
        """
        Compute the relative byte share of each KV cache entry.

        Returns:
            A tuple containing the ratios for:
            - kv_cache[0]
            - kv_cache[1]
            - kv_cache[2]
            - kv_cache[3] (None if Sparse C8 is disabled)
        """
        assert self.sparse_head_dim is not None

        def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]:
            assert self.sparse_head_dim is not None
            assert self.cache_sparse_c8 is True
            kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim
            factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype)
            index_k_head_dim_virtual = index_k_head_dim // factor
            assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype)
            index_k_scale_head_dim_virtual = 1
            return (
                kv_lora_rank,
                qk_rope_head_dim,
                index_k_head_dim_virtual,
                index_k_scale_head_dim_virtual,
            )

        if self.cache_sparse_c8:
            virtual_dims = get_sparse_head_dim_virtual()
            total_virtual_head_dim = sum(virtual_dims)
            return (
                total_virtual_head_dim / virtual_dims[0],  # kv_cache[0]
                total_virtual_head_dim / virtual_dims[1],  # kv_cache[1]
                total_virtual_head_dim / virtual_dims[2],  # kv_cache[2]
                total_virtual_head_dim / virtual_dims[3],  # kv_cache[3]
            )
        return (
            self.head_size / self.sparse_head_dim[0],  # kv_cache[0]
            self.head_size / self.sparse_head_dim[1],  # kv_cache[1]
            self.head_size / self.sparse_head_dim[2],  # kv_cache[2]
            None,  # kv_cache[3] does not exist
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
        assert len(cache_dtype_str_set) == 1, (
            "All attention layers in the same KV cache group must use the same quantization method."
        )
        return cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            sparse_head_dim=specs[0].sparse_head_dim,
            dtype=specs[0].dtype,
            cache_dtype_str=cache_dtype_str_set.pop(),
            cache_sparse_c8=specs[0].cache_sparse_c8,
        )


vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec
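
For intuition on the numbers involved, the following self-contained sketch re-derives the `page_size_bytes` arithmetic above with illustrative DeepSeek-V3.2-style dimensions (kv_lora_rank=512, qk_rope_head_dim=64, indexer head dim=128, block_size=128, one KV head); these values are assumptions for illustration, not read from any real model config.

```python
# Illustrative re-derivation of the page-size math in page_size_bytes.
# All dimensions below are assumptions chosen to resemble a DeepSeek-V3.2-style
# MLA + indexer layout; they are not taken from a real config.
BLOCK_SIZE = 128
NUM_KV_HEADS = 1
KV_LORA_RANK = 512        # kv_cache[0], bfloat16 (2 bytes/elem)
QK_ROPE_HEAD_DIM = 64     # kv_cache[1], bfloat16 (2 bytes/elem)
INDEX_HEAD_DIM = 128      # kv_cache[2], int8 when C8 is enabled (1 byte/elem)

tokens_per_page = BLOCK_SIZE * NUM_KV_HEADS

# C8 enabled: int8 indexer key plus one float16 scale per token.
k_pe_nope = tokens_per_page * (KV_LORA_RANK + QK_ROPE_HEAD_DIM) * 2   # 147456
indexer_k_int8 = tokens_per_page * INDEX_HEAD_DIM * 1                 # 16384
indexer_k_scale = tokens_per_page * 1 * 2                             # 256
page_c8 = k_pe_nope + indexer_k_int8 + indexer_k_scale                # 164096

# C8 disabled: the indexer key stays bfloat16 and no scale is cached.
page_bf16 = k_pe_nope + tokens_per_page * INDEX_HEAD_DIM * 2          # 180224

print(page_c8, page_bf16, 1 - page_c8 / page_bf16)  # roughly 9% smaller per page
```

With the same assumed dimensions, `sparse_kv_cache_ratio` treats the int8 entry as a 64-wide bf16-equivalent slice (128 // 2) and the per-token scale as width 1, so the ratios are computed over a virtual total head dim of 512 + 64 + 64 + 1 = 641.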