### What this PR does / why we need it? This reverts commit7ed9e9de69, which introduces an issue that the patch doesn't work with recompute scheduler enabled. - vLLM version: v0.17.0 - vLLM main:4034c3d32e--------- Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -18,7 +18,6 @@ import os
|
||||
|
||||
import vllm_ascend.patch.platform.patch_distributed # noqa
|
||||
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
|
||||
import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa
|
||||
import vllm_ascend.patch.platform.patch_mamba_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_sched_yield # noqa
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import vllm.v1.kv_cache_interface
|
||||
from typing_extensions import Self
|
||||
from vllm.utils.torch_utils import get_dtype_size
|
||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AscendMLAAttentionSpec(MLAAttentionSpec):
|
||||
"""MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support.
|
||||
|
||||
When Sparse C8 is enabled, the KV cache tuple changes from
|
||||
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)
|
||||
to
|
||||
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).
|
||||
|
||||
The semantic meaning of each KV cache entry is as follows:
|
||||
1. kv_cache[0] stores kv_lora.
|
||||
2. kv_cache[1] stores k_rope.
|
||||
3. kv_cache[2] stores the key tensor from the indexer module.
|
||||
4. kv_cache[3] stores the key scale tensor from the indexer module,
|
||||
and exists only when Sparse C8 is enabled.
|
||||
|
||||
The main changes are as follows:
|
||||
1. The key tensor from the indexer module stored in kv_cache[2] is
|
||||
converted from bf16 to int8 to reduce memory usage. It is then
|
||||
processed with int8 precision in Lightning_indexer computation
|
||||
to improve computational efficiency.
|
||||
2. The quantization scale of the key tensor in the indexer module
|
||||
must also be stored for the Lightning_indexer_quant operator,
|
||||
and is therefore saved in kv_cache[3].
|
||||
"""
|
||||
|
||||
sparse_head_dim: tuple[int, ...] | None = None
|
||||
cache_sparse_c8: bool = False
|
||||
c8_k_cache_dtype: torch.dtype = torch.int8
|
||||
c8_k_scale_cache_dtype: torch.dtype = torch.float16
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
if self.cache_sparse_c8:
|
||||
assert self.sparse_head_dim is not None
|
||||
assert len(self.sparse_head_dim) == 3
|
||||
num_heads_per_page = self.block_size * self.num_kv_heads
|
||||
# kv_cache[0]: bfloat16, kv_cache[1]: bfloat16
|
||||
kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2]
|
||||
k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)
|
||||
# kv_cache[2]: int8
|
||||
index_head_dim = self.sparse_head_dim[-1]
|
||||
indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)
|
||||
# kv_cache[3]: float16
|
||||
# since the scale is stored per token, head_dim is set to 1.
|
||||
index_scale_head_dim = 1
|
||||
indexer_k_scale_bytes = (
|
||||
num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
|
||||
)
|
||||
return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes
|
||||
|
||||
return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)
|
||||
|
||||
@property
|
||||
def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
|
||||
"""
|
||||
Compute the relative byte share of each KV cache entry.
|
||||
|
||||
Returns:
|
||||
A tuple containing the ratios for:
|
||||
- kv_cache[0]
|
||||
- kv_cache[1]
|
||||
- kv_cache[2]
|
||||
- kv_cache[3] (None if Sparse C8 is disabled)
|
||||
"""
|
||||
|
||||
assert self.sparse_head_dim is not None
|
||||
|
||||
def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]:
|
||||
assert self.sparse_head_dim is not None
|
||||
assert self.cache_sparse_c8 is True
|
||||
|
||||
kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim
|
||||
|
||||
factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype)
|
||||
index_k_head_dim_virtual = index_k_head_dim // factor
|
||||
|
||||
assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype)
|
||||
index_k_scale_head_dim_virtual = 1
|
||||
|
||||
return (
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
index_k_head_dim_virtual,
|
||||
index_k_scale_head_dim_virtual,
|
||||
)
|
||||
|
||||
if self.cache_sparse_c8:
|
||||
virtual_dims = get_sparse_head_dim_virtual()
|
||||
total_virtual_head_dim = sum(virtual_dims)
|
||||
|
||||
return (
|
||||
total_virtual_head_dim / virtual_dims[0], # kv_cache[0]
|
||||
total_virtual_head_dim / virtual_dims[1], # kv_cache[1]
|
||||
total_virtual_head_dim / virtual_dims[2], # kv_cache[2]
|
||||
total_virtual_head_dim / virtual_dims[3], # kv_cache[3]
|
||||
)
|
||||
|
||||
return (
|
||||
self.head_size / self.sparse_head_dim[0], # kv_cache[0]
|
||||
self.head_size / self.sparse_head_dim[1], # kv_cache[1]
|
||||
self.head_size / self.sparse_head_dim[2], # kv_cache[2]
|
||||
None, # kv_cache[3] does not exist
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be MLAAttentionSpec."
|
||||
)
|
||||
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
|
||||
assert len(cache_dtype_str_set) == 1, (
|
||||
"All attention layers in the same KV cache group must use the same quantization method."
|
||||
)
|
||||
return cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
sparse_head_dim=specs[0].sparse_head_dim,
|
||||
dtype=specs[0].dtype,
|
||||
cache_dtype_str=cache_dtype_str_set.pop(),
|
||||
cache_sparse_c8=specs[0].cache_sparse_c8,
|
||||
)
|
||||
|
||||
|
||||
vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec
|
||||
Reference in New Issue
Block a user