[Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)
### What this PR does / why we need it?
This PR adds W8A8C8 support for dsv3.2/glm5 via the lightning_indexer_quant ops, targeting the pd-mix stage first.
Because the code for the PD-disaggregated scenario is still being refactored and cleaned up, this PR prioritizes the C8 functionality in the pd-mix scenario.
The next steps are planned in three parts:
① Once the optimized scatter operator is available, we will replace the original operator to improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable, we will carry out more comprehensive validation and make the necessary adaptations.
③ Enabling C8 currently introduces several new operators whose performance still needs improvement, so performance may regress in some scenarios. Only after all the operators are fully ready can we ensure that this feature causes no performance degradation; at that point, we will enable it by default and remove the switch in `additional_config` (see the enabling example sketched below).
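A minimal sketch of how enabling the switch might look. The key name is an assumption inferred from `ascend_config.enable_sparse_c8` in the diff below, and the checkpoint path is a placeholder:

```python
from vllm import LLM

# Opt-in switch for the C8 sparse-indexer cache; it stays off by default
# until the related operators are fully optimized.
llm = LLM(
    model="/path/to/deepseek-v3.2-w8a8",  # placeholder path, not from this PR
    additional_config={"enable_sparse_c8": True},  # assumed key name
)
```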
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
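To make the diff below easier to follow: for sparse attention, each layer's flat KV buffer is split among the latent k cache (`kv_lora_rank`), the RoPE cache (`qk_rope_head_dim`), and the indexer k cache (`index_head_dim`) in proportion to their head dims; with C8 enabled, a fourth slice holds the per-token k_scale. A minimal sketch of that arithmetic, with illustrative dims rather than values taken from this PR:

```python
# Illustrative head dims; the real values come from hf_text_config.
kv_lora_rank, qk_rope_head_dim, index_head_dim = 512, 64, 128
total = kv_lora_rank + qk_rope_head_dim + index_head_dim  # 704

buffer_size = 1 << 30  # hypothetical flat per-layer buffer, in bytes
# Mirrors the diff's `split_factor = total / head_dim`; each cache then
# gets `buffer_size // split_factor` bytes of the shared buffer.
k_size = int(buffer_size // (total / kv_lora_rank))        # ~745 MiB
v_size = int(buffer_size // (total / qk_rope_head_dim))    # ~93 MiB
dsa_k_size = int(buffer_size // (total / index_head_dim))  # ~186 MiB
```

With C8 enabled, the new code reads four split factors from `kv_cache_spec.sparse_kv_cache_ratio`, the fourth covering the float16 `dsa_k_scale` slice.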
@@ -84,6 +84,7 @@ from vllm.v1.worker.ubatch_utils import (
 )
 from vllm.v1.worker.utils import AttentionGroup
 
+# yapf: enable
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention
@@ -96,8 +97,6 @@ from vllm_ascend.compilation.acl_graph import (
     set_graph_params,
     update_full_graph_params,
 )
-
-# yapf: enable
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
 from vllm_ascend.eplb.core.eplb_worker import EplbProcess
@@ -274,7 +273,21 @@ class NPUModelRunner(GPUModelRunner):
         self.is_multimodal_model = self.model_config.is_multimodal_model
         self.block_size = vllm_config.cache_config.block_size
         # Set up Attention
-        self.use_sparse = hasattr(self.vllm_config.model_config.hf_text_config, "index_topk")
+        self.use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
+            vllm_config.model_config.hf_text_config, "index_topk"
+        )
+        if self.use_sparse:
+            self.sparse_head_dim = (
+                self.model_config.hf_text_config.kv_lora_rank,
+                self.model_config.hf_text_config.qk_rope_head_dim,
+                self.model_config.hf_text_config.index_head_dim,
+            )
+        # dsa c8
+        self.use_sparse_c8_indexer = self.ascend_config.enable_sparse_c8
+        if self.use_sparse_c8_indexer:
+            self.c8_k_cache_dtype = torch.int8
+            self.c8_k_scale_cache_dtype = torch.float16
+
         self.attn_backend = get_attn_backend(
             0,
             self.dtype,
@@ -2623,7 +2636,7 @@ class NPUModelRunner(GPUModelRunner):
         to their corresponding memory buffer for K cache and V cache.
         """
         # init kv cache tensors
-        kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {}
+        kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None | None] = {}
         # prefill disaggregation need the addr of cache tensor be aligned with 2M
         alignment = 2 * 1024 * 1024
         layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
@@ -2670,19 +2683,18 @@ class NPUModelRunner(GPUModelRunner):
                 + self.model_config.hf_text_config.kv_lora_rank
             )
 
-            dsa_k_cache_factor = None
-            dsa_k_cache_size = None
             if not self.model_config.use_mla:
                 # for non-mla model, use FullAttentionSpec
-                k_tensor_split_factor = 2
-                v_tensor_split_factor = 2
+                k_tensor_split_factor = 2.0
+                v_tensor_split_factor = 2.0
             elif self.use_sparse:
                 # for deepseek v3.2, we split the kv cache according to the corresponding ratio
-                sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
-                k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [ # type: ignore
-                    sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio()
-                ]
-                dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor)
+                kv_cache_spec = layer_kv_cache_spec[layer_name]
+                sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio
+                k_tensor_split_factor = sparse_kv_cache_ratio[0]
+                v_tensor_split_factor = sparse_kv_cache_ratio[1]
+                dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2]
+                dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
             else:
                 # for other deepseek models, use MLAAttentionSpec
                 k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
@@ -2690,35 +2702,56 @@ class NPUModelRunner(GPUModelRunner):
 
             k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor)
             v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor)
+            dsa_k_tensor_size = None
+            dsa_k_scale_tensor_size = None
+            #### for deepseek sparse attention
+            if self.use_sparse:
+                dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor)
+                if self.use_sparse_c8_indexer:
+                    dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor)
 
             # for other attentions, e.g., self_attn, sliding window attn
             if self.vllm_config.kv_transfer_config is None:
                 k_tensor = torch.zeros(k_tensor_size, dtype=torch.int8, device=self.device)
                 v_tensor = torch.zeros(v_tensor_size, dtype=torch.int8, device=self.device)
-                #### k cache: for deepseek sparse attention
-                if dsa_k_cache_factor is not None:
-                    dsa_k_cache_tensor = torch.zeros(dsa_k_cache_size, dtype=torch.int8, device=self.device)
+                #### for deepseek sparse attention
+                if dsa_k_tensor_size is not None:
+                    dsa_k_tensor = torch.zeros(dsa_k_tensor_size, dtype=torch.int8, device=self.device)
+                if dsa_k_scale_tensor_size is not None:
+                    dsa_k_scale_tensor = torch.zeros(
+                        dsa_k_scale_tensor_size, dtype=torch.int8, device=self.device
+                    )
             else:
                 k_tensor = torch.zeros(k_tensor_size + alignment, dtype=torch.int8, device=self.device)
                 v_tensor = torch.zeros(v_tensor_size + alignment, dtype=torch.int8, device=self.device)
                 k_tensor = self._align_memory(k_tensor, alignment)[:k_tensor_size]
                 v_tensor = self._align_memory(v_tensor, alignment)[:v_tensor_size]
-                #### k cache: for deepseek sparse attention
-                if dsa_k_cache_factor is not None and dsa_k_cache_size is not None:
-                    dsa_k_cache_tensor = torch.zeros(
-                        dsa_k_cache_size + alignment, dtype=torch.int8, device=self.device
+                #### for deepseek sparse attention
+                if dsa_k_tensor_size is not None:
+                    dsa_k_tensor = torch.zeros(
+                        dsa_k_tensor_size + alignment, dtype=torch.int8, device=self.device
                     )
-                    dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size]
+                    dsa_k_tensor = self._align_memory(dsa_k_tensor, alignment)[:dsa_k_tensor_size]
+                if dsa_k_scale_tensor_size is not None:
+                    dsa_k_scale_tensor = torch.zeros(
+                        dsa_k_scale_tensor_size + alignment, dtype=torch.int8, device=self.device
+                    )
+                    dsa_k_scale_tensor = self._align_memory(
+                        dsa_k_scale_tensor, alignment
+                    )[:dsa_k_scale_tensor_size]
 
             for layer_name_inner in kv_cache_tensor.shared_by:
                 # shared the attn kvcache for all shared layers
                 if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
-                    kv_cache_raw_tensors[layer_name_inner] = (
-                        (k_tensor, v_tensor)
-                        if not self.use_sparse
-                        else (k_tensor, v_tensor, dsa_k_cache_tensor)
-                    )
-
+                    if self.use_sparse:
+                        if self.use_sparse_c8_indexer:
+                            kv_cache_raw_tensors[layer_name_inner] = (
+                                k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor
+                            )
+                        else:
+                            kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor, dsa_k_tensor)
+                    else:
+                        kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor)
         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
             for layer_name in group.layer_names:
@@ -2760,13 +2793,23 @@ class NPUModelRunner(GPUModelRunner):
             # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
             # encounter OOM issue
             if isinstance(kv_cache_spec, AttentionSpec):
-                raw_dsa_k_tensor = None
                 if self.use_sparse:
-                    raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
-                        layer_name
-                    ]
-                    assert raw_dsa_k_tensor is not None
-                    sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
+                    if self.use_sparse_c8_indexer:
+                        raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore
+                            layer_name]
+                        assert raw_dsa_k_tensor is not None
+                        assert raw_dsa_k_scale_tensor is not None
+                        sum_page_size_bytes = (
+                            raw_k_tensor.numel()
+                            + raw_v_tensor.numel()
+                            + raw_dsa_k_tensor.numel()
+                            + raw_dsa_k_scale_tensor.numel()
+                        )
+                    else:
+                        raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
+                            layer_name]
+                        assert raw_dsa_k_tensor is not None
+                        sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
             elif self.use_hybrid_blocks and self.hybrid_with_attn_and_mamba:
                 # Currently, we ensure that the same kvcache format is used even if there
                 # is no shared layer, such as the full attention mtp layer of qwen3.5, etc.
@@ -2813,7 +2856,7 @@ class NPUModelRunner(GPUModelRunner):
                 kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                     num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
                 )
-                dtype = kv_cache_spec.dtype
+
                 if not self.model_config.use_mla:
                     k_shape = kv_cache_shape[1:]
                     v_shape = k_shape
@@ -2832,19 +2875,37 @@ class NPUModelRunner(GPUModelRunner):
                         num_kv_heads,
                         self.model_config.hf_text_config.qk_rope_head_dim,
                     ]
-                    k_cache = raw_k_tensor.view(dtype).view(k_shape)
-                    v_cache = raw_v_tensor.view(dtype).view(v_shape)
+                    k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape)
+                    v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape)
 
-                    if self.use_sparse and raw_dsa_k_tensor is not None:
-                        index_head_dim = self._get_sparse_kv_cache_ratio()[-1]
+                    if self.use_sparse:
                         dsa_k_cache_shape = (
                             num_blocks,
                             kv_cache_spec.block_size,
                             kv_cache_spec.num_kv_heads,
-                            index_head_dim,
+                            self.model_config.hf_text_config.index_head_dim,
                         )
-                        dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape)
-                        kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
+                        if self.use_sparse_c8_indexer:
+                            # dsa_k
+                            dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape)
+                            # dsa_k_scale
+                            dsa_k_scale_cache_shape = (
+                                num_blocks,
+                                kv_cache_spec.block_size,
+                                kv_cache_spec.num_kv_heads,
+                                1,
+                            )
+                            assert raw_dsa_k_scale_tensor is not None
+                            dsa_k_scale_cache = (
+                                raw_dsa_k_scale_tensor
+                                .view(self.c8_k_scale_cache_dtype)
+                                .view(dsa_k_scale_cache_shape)
+                            )
+                            kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache)
+                        else:
+                            # dsa_k
+                            dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape)
+                            kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
                     else:
                         kv_caches[layer_name] = (k_cache, v_cache)
             elif isinstance(kv_cache_spec, MambaSpec):
@@ -3098,18 +3159,31 @@ class NPUModelRunner(GPUModelRunner):
 
             elif isinstance(attn_module, MLAAttention):
                 if self.use_sparse:
-                    # TODO(cmq): This is a hack way to fix deepseek kvcache when
-                    # using DSA. Fix the spec in vLLM is the final way.
-                    sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
-                    kv_cache_spec[layer_name] = MLAAttentionSpec(
+                    # `MLAAttentionSpec` is temporarily patched to `AscendMLAAttentionSpec`.
+                    # Re-importing it at runtime will therefore resolve to the patched class.
+                    # Rename it here to make this behavior explicit.
+                    from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
+                    # TODO(rjg-lyh): when kv_cache_spec's refactor is ready,
+                    # implement it by creating a new kv_cache_spec class
+                    kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
                         block_size=self.block_size,
                         num_kv_heads=1,
-                        head_size=sparse_sum_head_size,
+                        head_size=sum(self.sparse_head_dim),
+                        sparse_head_dim=self.sparse_head_dim,
                         dtype=self.kv_cache_dtype,
                         cache_dtype_str=self.vllm_config.cache_config.cache_dtype,
+                        cache_sparse_c8=self.use_sparse_c8_indexer,
                     )
                 elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
-                    kv_cache_spec[layer_name] = spec
+                    assert isinstance(spec, MLAAttentionSpec)
+                    from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
+                    kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
+                        block_size=spec.block_size,
+                        num_kv_heads=spec.num_kv_heads,
+                        head_size=spec.head_size,
+                        dtype=spec.dtype,
+                        cache_dtype_str=spec.cache_dtype_str,
+                    )
 
             elif isinstance(attn_module, MambaBase):
                 mamba_layers[layer_name] = attn_module
@@ -3129,16 +3203,6 @@ class NPUModelRunner(GPUModelRunner):
 
         return kv_cache_spec
 
-    def _get_sparse_kv_cache_ratio(self) -> list[int]:
-        # TODO:If C8 is supported, we need to consider the number of bytes occupied by different dtypes
-        # when calculating the ratio,for example:
-        # [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...]
-        return [
-            self.model_config.hf_text_config.kv_lora_rank,
-            self.model_config.hf_text_config.qk_rope_head_dim,
-            self.model_config.hf_text_config.index_head_dim,
-        ]
-
     def _check_and_update_cudagraph_mode(
         self,
         attention_backends: list[set[type[AttentionBackend]]],