[bugfix] restore pr-7029 and fix patch error (#7294)
### What this PR does / why we need it?
This PR restores #7029, which adds W8A8C8 support for dsv3.2/glm5 using
the `lightning_indexer_quant` ops in the pd-mix stage.
The original PR was reverted by #7288 because the patch did not work
with the recompute scheduler.
This PR also fixes the patching issue so that it works correctly with
the recompute scheduler.
### Does this PR introduce _any_ user-facing change?
Yes. To enable LI C8, users need to set the `enable_sparse_c8` option to
`"true"` in `additional_config`.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
138
vllm_ascend/patch/platform/patch_kv_cache_interface.py
Normal file
138
vllm_ascend/patch/platform/patch_kv_cache_interface.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import vllm.v1.kv_cache_interface
|
||||
from typing_extensions import Self
|
||||
from vllm.utils.torch_utils import get_dtype_size
|
||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AscendMLAAttentionSpec(MLAAttentionSpec):
|
||||
"""MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support.
|
||||
|
||||
When Sparse C8 is enabled, the KV cache tuple changes from
|
||||
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)
|
||||
to
|
||||
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).
|
||||
|
||||
The semantic meaning of each KV cache entry is as follows:
|
||||
1. kv_cache[0] stores kv_lora.
|
||||
2. kv_cache[1] stores k_rope.
|
||||
3. kv_cache[2] stores the key tensor from the indexer module.
|
||||
4. kv_cache[3] stores the key scale tensor from the indexer module,
|
||||
and exists only when Sparse C8 is enabled.
|
||||
|
||||
The main changes are as follows:
|
||||
1. The key tensor from the indexer module stored in kv_cache[2] is
|
||||
converted from bf16 to int8 to reduce memory usage. It is then
|
||||
processed with int8 precision in Lightning_indexer computation
|
||||
to improve computational efficiency.
|
||||
2. The quantization scale of the key tensor in the indexer module
|
||||
must also be stored for the Lightning_indexer_quant operator,
|
||||
and is therefore saved in kv_cache[3].
|
||||
"""
|
||||
|
||||
sparse_head_dim: tuple[int, ...] | None = None
|
||||
cache_sparse_c8: bool = False
|
||||
c8_k_cache_dtype: torch.dtype = torch.int8
|
||||
c8_k_scale_cache_dtype: torch.dtype = torch.float16
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
if self.cache_sparse_c8:
|
||||
assert self.sparse_head_dim is not None
|
||||
assert len(self.sparse_head_dim) == 3
|
||||
num_heads_per_page = self.block_size * self.num_kv_heads
|
||||
# kv_cache[0]: bfloat16, kv_cache[1]: bfloat16
|
||||
kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2]
|
||||
k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)
|
||||
# kv_cache[2]: int8
|
||||
index_head_dim = self.sparse_head_dim[-1]
|
||||
indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)
|
||||
# kv_cache[3]: float16
|
||||
# since the scale is stored per token, head_dim is set to 1.
|
||||
index_scale_head_dim = 1
|
||||
indexer_k_scale_bytes = (
|
||||
num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
|
||||
)
|
||||
return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes
|
||||
|
||||
return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)
|
||||
|
||||
@property
|
||||
def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
|
||||
"""
|
||||
Compute the relative byte share of each KV cache entry.
|
||||
|
||||
Returns:
|
||||
A tuple containing the ratios for:
|
||||
- kv_cache[0]
|
||||
- kv_cache[1]
|
||||
- kv_cache[2]
|
||||
- kv_cache[3] (None if Sparse C8 is disabled)
|
||||
"""
|
||||
|
||||
assert self.sparse_head_dim is not None
|
||||
|
||||
def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]:
|
||||
assert self.sparse_head_dim is not None
|
||||
assert self.cache_sparse_c8 is True
|
||||
|
||||
kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim
|
||||
|
||||
factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype)
|
||||
index_k_head_dim_virtual = index_k_head_dim // factor
|
||||
|
||||
assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype)
|
||||
index_k_scale_head_dim_virtual = 1
|
||||
|
||||
return (
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
index_k_head_dim_virtual,
|
||||
index_k_scale_head_dim_virtual,
|
||||
)
|
||||
|
||||
if self.cache_sparse_c8:
|
||||
virtual_dims = get_sparse_head_dim_virtual()
|
||||
total_virtual_head_dim = sum(virtual_dims)
|
||||
|
||||
return (
|
||||
total_virtual_head_dim / virtual_dims[0], # kv_cache[0]
|
||||
total_virtual_head_dim / virtual_dims[1], # kv_cache[1]
|
||||
total_virtual_head_dim / virtual_dims[2], # kv_cache[2]
|
||||
total_virtual_head_dim / virtual_dims[3], # kv_cache[3]
|
||||
)
|
||||
|
||||
return (
|
||||
self.head_size / self.sparse_head_dim[0], # kv_cache[0]
|
||||
self.head_size / self.sparse_head_dim[1], # kv_cache[1]
|
||||
self.head_size / self.sparse_head_dim[2], # kv_cache[2]
|
||||
None, # kv_cache[3] does not exist
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be MLAAttentionSpec."
|
||||
)
|
||||
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
|
||||
assert len(cache_dtype_str_set) == 1, (
|
||||
"All attention layers in the same KV cache group must use the same quantization method."
|
||||
)
|
||||
return cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
sparse_head_dim=specs[0].sparse_head_dim,
|
||||
dtype=specs[0].dtype,
|
||||
cache_dtype_str=cache_dtype_str_set.pop(),
|
||||
cache_sparse_c8=specs[0].cache_sparse_c8,
|
||||
)
|
||||
|
||||
|
||||
vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec
|
||||
Reference in New Issue
Block a user