### What this PR does / why we need it?
This PR restores #7029, which adds W8A8C8 support for DeepSeek-V3.2 / GLM-5
using the `lightning_indexer_quant` op in the PD-mix stage.
The original PR was reverted by #7288 because its patch did not work with
the recompute scheduler. This PR also fixes that patching issue so that it
works correctly with the recompute scheduler.
### Does this PR introduce _any_ user-facing change?
Yes. To enable Lightning Indexer C8 (LI C8), set the `enable_sparse_c8` option
to `"true"` in `additional_config`.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
139 lines · 5.7 KiB · Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
import vllm.v1.kv_cache_interface
|
|
from typing_extensions import Self
|
|
from vllm.utils.torch_utils import get_dtype_size
|
|
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
|
|
|
|
|
@dataclass(frozen=True)
class AscendMLAAttentionSpec(MLAAttentionSpec):
    """MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support.

    When Sparse C8 is enabled, the KV cache tuple changes from

        (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)

    to

        (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).

    ("bfloat16" above stands for ``dtype``; int8 and float16 are the defaults of
    ``c8_k_cache_dtype`` and ``c8_k_scale_cache_dtype``.)

    The semantic meaning of each KV cache entry is as follows:
        1. kv_cache[0] stores kv_lora.
        2. kv_cache[1] stores k_rope.
        3. kv_cache[2] stores the key tensor from the indexer module.
        4. kv_cache[3] stores the key scale tensor from the indexer module,
           and exists only when Sparse C8 is enabled.

    The main changes are as follows:
        1. The key tensor from the indexer module stored in kv_cache[2] is
           converted from bf16 to int8 to reduce memory usage. It is then
           processed with int8 precision in the Lightning Indexer computation
           to improve computational efficiency.
        2. The quantization scale of the key tensor in the indexer module
           must also be stored for the ``lightning_indexer_quant`` operator,
           and is therefore saved in kv_cache[3].

    Attributes:
        sparse_head_dim: Per-token head dims ``(kv_lora_rank, qk_rope_head_dim,
            index_head_dim)`` of kv_cache[0..2]. Required by
            ``sparse_kv_cache_ratio`` and, when Sparse C8 is enabled, by
            ``page_size_bytes``.
        cache_sparse_c8: Whether the Sparse C8 layout (with kv_cache[3]) is used.
        c8_k_cache_dtype: Element dtype of kv_cache[2] under Sparse C8.
        c8_k_scale_cache_dtype: Element dtype of kv_cache[3] under Sparse C8.
    """

    sparse_head_dim: tuple[int, ...] | None = None
    cache_sparse_c8: bool = False
    c8_k_cache_dtype: torch.dtype = torch.int8
    c8_k_scale_cache_dtype: torch.dtype = torch.float16

    @property
    def page_size_bytes(self) -> int:
        """Return the number of bytes one KV cache block occupies across all entries.

        Without Sparse C8 every entry uses ``dtype``, so the size is
        ``block_size * num_kv_heads * head_size * sizeof(dtype)``. With Sparse C8
        the entries have different dtypes and are sized individually.
        """
        if not self.cache_sparse_c8:
            return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)

        assert self.sparse_head_dim is not None
        assert len(self.sparse_head_dim) == 3
        kv_lora_rank, qk_rope_head_dim, index_head_dim = self.sparse_head_dim
        num_heads_per_page = self.block_size * self.num_kv_heads

        # kv_cache[0] (kv_lora) and kv_cache[1] (k_rope) keep the model dtype.
        k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)
        # kv_cache[2]: quantized indexer key.
        indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)
        # kv_cache[3]: the scale is stored per token, so its head_dim is 1.
        index_scale_head_dim = 1
        indexer_k_scale_bytes = (
            num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
        )
        return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes

    @property
    def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
        """Compute the relative byte share of each KV cache entry.

        Each value is ``total / entry``: the per-token size of the whole page
        divided by the per-token size of that entry (the reciprocal of the
        entry's fraction of the page).

        Without Sparse C8 the total is ``head_size`` (assumed to equal
        ``sum(sparse_head_dim)`` — NOTE(review): not checked here). With
        Sparse C8 every entry is measured in ``dtype``-sized "virtual"
        elements so that the ratios follow the byte layout of
        ``page_size_bytes``.

        Returns:
            A tuple containing the ratios for:
                - kv_cache[0]
                - kv_cache[1]
                - kv_cache[2]
                - kv_cache[3] (None if Sparse C8 is disabled)
        """
        assert self.sparse_head_dim is not None

        if not self.cache_sparse_c8:
            return (
                self.head_size / self.sparse_head_dim[0],  # kv_cache[0]
                self.head_size / self.sparse_head_dim[1],  # kv_cache[1]
                self.head_size / self.sparse_head_dim[2],  # kv_cache[2]
                None,  # kv_cache[3] does not exist
            )

        kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim
        dtype_size = get_dtype_size(self.dtype)

        # How many c8 key elements fit into one dtype-sized virtual element.
        factor = dtype_size // get_dtype_size(self.c8_k_cache_dtype)
        # Without exact divisibility the truncating division below would make the
        # ratios disagree with page_size_bytes.
        assert index_k_head_dim % factor == 0, (
            f"index head dim {index_k_head_dim} must be divisible by {factor} "
            "to be expressed in dtype-sized elements."
        )
        index_k_head_dim_virtual = index_k_head_dim // factor

        # The per-token scale occupies exactly one dtype-sized element only when
        # both dtypes have the same width.
        assert dtype_size == get_dtype_size(self.c8_k_scale_cache_dtype)
        index_k_scale_head_dim_virtual = 1

        total_virtual_head_dim = (
            kv_lora_rank + qk_rope_head_dim + index_k_head_dim_virtual + index_k_scale_head_dim_virtual
        )
        return (
            total_virtual_head_dim / kv_lora_rank,  # kv_cache[0]
            total_virtual_head_dim / qk_rope_head_dim,  # kv_cache[1]
            total_virtual_head_dim / index_k_head_dim_virtual,  # kv_cache[2]
            total_virtual_head_dim / index_k_scale_head_dim_virtual,  # kv_cache[3]
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge the specs of all layers in one KV cache group into a single spec.

        All specs must use the same quantization method (``cache_dtype_str``).
        Shape/dtype fields and the sparse-layout fields (``sparse_head_dim``,
        ``cache_sparse_c8`` and both C8 dtypes) are taken from ``specs[0]``.
        The sparse-layout fields are assumed identical across the group; they
        are not cross-checked here.

        Specs that are plain ``MLAAttentionSpec`` instances (which lack the
        sparse fields) fall back to this class's defaults for those fields.

        Raises:
            AssertionError: if a spec is not an MLAAttentionSpec, or the specs
                disagree on ``cache_dtype_str`` (including an empty ``specs``).
        """
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
        assert len(cache_dtype_str_set) == 1, (
            "All attention layers in the same KV cache group must use the same quantization method."
        )

        first = specs[0]

        def sparse_field(name: str):
            # Dataclass defaults live on the class, so getattr(cls, name) is the default.
            return getattr(first, name, getattr(cls, name))

        return cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            sparse_head_dim=sparse_field("sparse_head_dim"),
            dtype=first.dtype,
            cache_dtype_str=cache_dtype_str_set.pop(),
            cache_sparse_c8=sparse_field("cache_sparse_c8"),
            # Previously dropped, silently resetting custom C8 dtypes to defaults.
            c8_k_cache_dtype=sparse_field("c8_k_cache_dtype"),
            c8_k_scale_cache_dtype=sparse_field("c8_k_scale_cache_dtype"),
        )
|
|
|
|
|
|
# Swap vLLM's MLAAttentionSpec for the Ascend variant. Code that resolves the
# class through the ``vllm.v1.kv_cache_interface`` module attribute at call time
# will build AscendMLAAttentionSpec instances. Modules that already bound the
# name via ``from vllm.v1.kv_cache_interface import MLAAttentionSpec`` before
# this line runs keep a reference to the original class.
setattr(vllm.v1.kv_cache_interface, "MLAAttentionSpec", AscendMLAAttentionSpec)
|