[Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)

### What this PR does / why we need it?
This PR supports W8A8C8 in dsv3.2/glm5 with lightning_indexer_quant ops
in pd-mix stage mainly.

Because the code for the current PD-disaggregated scenario is still
under refactoring and cleanup, this PR prioritizes ensuring the C8
functionality in the pd-mix scenario.

The next steps are planned in two parts:
① Once the optimized scatter operator is updated, we will replace the
original operator to improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable,
we will carry out more comprehensive validation and make appropriate
adaptations.
③ Because enabling C8 currently introduces several new operators whose
performance still needs improvement, performance may regress in some
scenarios. Therefore, only after all the operators are fully ready can
we ensure that this feature does not cause any performance degradation.
At that point, we will enable this feature by default and remove the
switch in `additional_config`.


### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2026-03-13 14:47:42 +08:00
committed by GitHub
parent df1ee8070d
commit 7ed9e9de69
24 changed files with 4279 additions and 77 deletions

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar
import scipy # type: ignore
import torch
import torch_npu
import vllm.envs as envs_vllm
@@ -355,6 +356,9 @@ class AscendSFAImpl(MLAAttentionImpl):
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
o_proj_full_pool: torch.Tensor | None = None
# qk_hadamard tensor shared when dsa c8 enabled
qk_hadamard: torch.Tensor | None = None
def __init__(
self,
num_heads: int,
@@ -425,6 +429,12 @@ class AscendSFAImpl(MLAAttentionImpl):
self.is_rope_neox_style = False
self.use_torch_npu_lightning_indexer = True
# dsa c8
self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
if self.use_sparse_c8_indexer:
self.c8_k_cache_dtype = torch.int8
self.c8_k_scale_cache_dtype = torch.float16
# Effective in SFA when FlashComm is enabled.
self.enable_dsa_cp = enable_dsa_cp()
@@ -515,6 +525,11 @@ class AscendSFAImpl(MLAAttentionImpl):
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
128**0.5
)
# Processing the input parameters for MLAPO by reordering and transposing
# QKV(and part of Q) weight, applying RoPE-related dimension transformations,
# and handling quantization parameters.
@@ -874,7 +889,15 @@ class AscendSFAImpl(MLAAttentionImpl):
k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128]
return k_li
if self.use_sparse_c8_indexer:
k_li = k_li @ AscendSFAImpl.qk_hadamard
k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype) # [b*s,]
k_li_scale = k_li_scale.unsqueeze(-1) # [b*s,1]
else:
k_li_scale = None
return k_li, k_li_scale
def indexer_select_post_process(
self,
@@ -905,10 +928,35 @@ class AscendSFAImpl(MLAAttentionImpl):
q_li_pe = q_li_pe.squeeze(2)
q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128]
if self.use_sparse_c8_indexer:
q_li_shape_ori = q_li.shape
q_li = q_li @ AscendSFAImpl.qk_hadamard
q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)
# DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer.
# So two branches are maintained temporarily.
# TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
if self.use_torch_npu_lightning_indexer:
if self.use_sparse_c8_indexer:
assert len(kv_cache) == 4
weights = weights.to(torch.float16)
topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant(
query=q_li.view(q_li_shape_ori),
key=kv_cache[2],
weights=weights,
query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]),
key_dequant_scale=kv_cache[3].squeeze(2), # B S N D -> B S D
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
block_table=attn_metadata.block_table,
query_quant_mode=0,
key_quant_mode=0,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3,
)
elif self.use_torch_npu_lightning_indexer:
topk_indices, _ = torch_npu.npu_lightning_indexer(
query=q_li,
key=kv_cache[2],
@@ -1031,7 +1079,7 @@ class AscendSFAImpl(MLAAttentionImpl):
assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
q_c = self.q_a_layernorm(q_c)
k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
wait_for_kv_layer_from_connector(layer_name)
@@ -1044,20 +1092,46 @@ class AscendSFAImpl(MLAAttentionImpl):
if self.enable_dsa_cp:
assert k_pe is not None
assert k_nope is not None
assert k_li is not None
async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
# support all_gather kv async for communication calculation overlap
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k_li.view(-1, k_li.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
if not self.use_sparse_c8_indexer:
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k_li.view(-1, k_li.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
else:
# due to different dtypes, we have to split commu pass
assert k_li_scale is not None
fused_kv_no_split, _ = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
k_li, _ = all_gather_async(
k_li,
get_tp_group(),
async_op=async_op,
)
k_li_scale, kv_ag_handle = all_gather_async(
k_li_scale,
get_tp_group(),
async_op=async_op,
)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
@@ -1077,9 +1151,12 @@ class AscendSFAImpl(MLAAttentionImpl):
if kv_cache is not None:
assert fused_kv_no_split is not None
k_pe, k_nope, k_li = fused_kv_no_split.split(
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
)
if not self.use_sparse_c8_indexer:
k_pe, k_nope, k_li = fused_kv_no_split.split(
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
)
else:
k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1)
k_nope = k_nope.view(k_nope.shape[0], 1, -1)
k_pe = k_pe.view(k_pe.shape[0], 1, -1)
DeviceOperator.reshape_and_cache(
@@ -1098,6 +1175,13 @@ class AscendSFAImpl(MLAAttentionImpl):
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1])
) # b, s, n, d
if self.use_sparse_c8_indexer:
assert len(kv_cache) == 4
torch_npu.npu_scatter_nd_update_(
kv_cache[3].view(-1, k_li_scale.shape[-1]),
slot_mapping.view(-1, 1),
k_li_scale.view(-1, k_li_scale.shape[-1]),
)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()