[bugfix] restore pr-7029 and fix patch error (#7294)
### What this PR does / why we need it?
This PR restores #7029, which adds W8A8C8 support for dsv3.2/glm5 using
the `lightning_indexer_quant` ops in the pd-mix stage.
The original PR was reverted by #7288 because the patch did not work
with the recompute scheduler.
This PR also fixes the patching issue so that it works correctly with
the recompute scheduler.
### Does this PR introduce _any_ user-facing change?
Yes. To enable Lightning Indexer C8 (LI C8), users need to set the
`enable_sparse_c8` option to `"true"` in `additional_config`.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
@@ -137,6 +137,28 @@
#       Remove this patch if upstream provides an official NPU graph-capture
#       guidance / auto-configuration path for HCCL.
#
# ** 8. File: platform/patch_kv_cache_interface.py**
#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#   1. `vllm.v1.kv_cache_interface.MLAAttentionSpec`
#    Why:
#       The default `MLAAttentionSpec` is mainly built around `kv_lora_rank`
#       and `qk_rope_head_dim`. On NPU, we also use this class to describe DSA
#       models. Unlike the GPU path, where cache management is handled by an
#       additional indexer module, extending this class directly simplifies the
#       corresponding `model_runner` implementation on NPU.
#
#       This patch also adds Sparse C8 support for DSA models on NPU. As part
#       of that support, members such as `page_size_bytes` need to be adapted,
#       so they are overridden here as well to preserve overall readability.
#    How:
#       This patch subclasses the original implementation, overrides selected
#       methods, and adds DSA-specific attributes and helpers with default
#       values where needed.
#    Related PR (if no, explain why):
#       https://github.com/vllm-project/vllm/pull/25896
#    Future Plan:
#       Remove this patch after the upcoming KV cache spec refactor.
#
# * Worker Patch:
# ===============
#
@@ -18,6 +18,7 @@ import os

import vllm_ascend.patch.platform.patch_distributed  # noqa
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops  # noqa
import vllm_ascend.patch.platform.patch_kv_cache_interface  # noqa
import vllm_ascend.patch.platform.patch_mamba_config  # noqa
import vllm_ascend.patch.platform.patch_minimax_m2_config  # noqa
import vllm_ascend.patch.platform.patch_sched_yield  # noqa

138  vllm_ascend/patch/platform/patch_kv_cache_interface.py  Normal file
@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

import torch
import vllm.v1.kv_cache_interface
from typing_extensions import Self
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import MLAAttentionSpec


@dataclass(frozen=True)
class AscendMLAAttentionSpec(MLAAttentionSpec):
    """MLAAttentionSpec extended to describe DSA models, with optional Sparse C8 support.

    When Sparse C8 is enabled, the KV cache tuple changes from
    (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)
    to
    (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).

    The semantic meaning of each KV cache entry is as follows:
    1. kv_cache[0] stores kv_lora.
    2. kv_cache[1] stores k_rope.
    3. kv_cache[2] stores the key tensor from the indexer module.
    4. kv_cache[3] stores the key scale tensor from the indexer module,
       and exists only when Sparse C8 is enabled.

    The main changes are as follows:
    1. The key tensor from the indexer module stored in kv_cache[2] is
       converted from bf16 to int8 to reduce memory usage. It is then
       processed with int8 precision in the lightning_indexer computation
       to improve computational efficiency.
    2. The quantization scale of the key tensor in the indexer module
       must also be stored for the lightning_indexer_quant operator,
       and is therefore saved in kv_cache[3].
    """

    sparse_head_dim: tuple[int, ...] | None = None
    cache_sparse_c8: bool = False
    c8_k_cache_dtype: torch.dtype = torch.int8
    c8_k_scale_cache_dtype: torch.dtype = torch.float16

    @property
    def page_size_bytes(self) -> int:
        if self.cache_sparse_c8:
            assert self.sparse_head_dim is not None
            assert len(self.sparse_head_dim) == 3
            num_heads_per_page = self.block_size * self.num_kv_heads
            # kv_cache[0]: bfloat16, kv_cache[1]: bfloat16
            kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2]
            k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)
            # kv_cache[2]: int8
            index_head_dim = self.sparse_head_dim[-1]
            indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)
            # kv_cache[3]: float16
            # Since the scale is stored per token, head_dim is set to 1.
            index_scale_head_dim = 1
            indexer_k_scale_bytes = (
                num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
            )
            return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes

        return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)

    @property
    def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
        """
        Compute the relative byte share of each KV cache entry.

        Returns:
            A tuple containing the ratios for:
            - kv_cache[0]
            - kv_cache[1]
            - kv_cache[2]
            - kv_cache[3] (None if Sparse C8 is disabled)
        """

        assert self.sparse_head_dim is not None

        def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]:
            assert self.sparse_head_dim is not None
            assert self.cache_sparse_c8 is True

            kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim

            # Express the int8 indexer key head_dim in bf16-equivalent ("virtual") units.
            factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype)
            index_k_head_dim_virtual = index_k_head_dim // factor

            # The per-token fp16 scale occupies one bf16-equivalent unit.
            assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype)
            index_k_scale_head_dim_virtual = 1

            return (
                kv_lora_rank,
                qk_rope_head_dim,
                index_k_head_dim_virtual,
                index_k_scale_head_dim_virtual,
            )

        if self.cache_sparse_c8:
            virtual_dims = get_sparse_head_dim_virtual()
            total_virtual_head_dim = sum(virtual_dims)

            return (
                total_virtual_head_dim / virtual_dims[0],  # kv_cache[0]
                total_virtual_head_dim / virtual_dims[1],  # kv_cache[1]
                total_virtual_head_dim / virtual_dims[2],  # kv_cache[2]
                total_virtual_head_dim / virtual_dims[3],  # kv_cache[3]
            )

        return (
            self.head_size / self.sparse_head_dim[0],  # kv_cache[0]
            self.head_size / self.sparse_head_dim[1],  # kv_cache[1]
            self.head_size / self.sparse_head_dim[2],  # kv_cache[2]
            None,  # kv_cache[3] does not exist
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
        assert len(cache_dtype_str_set) == 1, (
            "All attention layers in the same KV cache group must use the same quantization method."
        )
        return cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            sparse_head_dim=specs[0].sparse_head_dim,
            dtype=specs[0].dtype,
            cache_dtype_str=cache_dtype_str_set.pop(),
            cache_sparse_c8=specs[0].cache_sparse_c8,
        )


vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec
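As a quick sanity check of the `page_size_bytes` formula in the patch above, here is a back-of-the-envelope sketch with illustrative DSA-style shapes (the dimensions below are examples, not values taken from this PR):

```python
# Illustrative shapes only; real values come from the model config.
block_size, num_kv_heads = 128, 1
kv_lora_rank, qk_rope_head_dim, index_head_dim = 512, 64, 128
bf16_bytes, int8_bytes, fp16_bytes = 2, 1, 2  # stand-ins for get_dtype_size(...)

num_heads_per_page = block_size * num_kv_heads
# kv_cache[0] + kv_cache[1]: kv_lora and k_rope, kept in bf16.
k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * bf16_bytes
# kv_cache[2]: indexer key, quantized to int8 under Sparse C8.
indexer_k_bytes = num_heads_per_page * index_head_dim * int8_bytes
# kv_cache[3]: per-token key scale in fp16 (head_dim 1).
indexer_k_scale_bytes = num_heads_per_page * 1 * fp16_bytes

sparse_c8_page = k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes
all_bf16_page = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim + index_head_dim) * bf16_bytes
print(sparse_c8_page, all_bf16_page)  # 164096 vs. 180224 bytes per block for these shapes
```

With these example shapes, quantizing the indexer key to int8 trims roughly 9% off the per-block KV cache footprint, at the cost of one fp16 scale per token.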