[Feature] Support the KV NZ feature for the DeepSeek decode node in the disagg-prefill scenario (#3072)

By converting the KV cache from ND to NZ format when the decode node
receives it, this PR ensures that the KV NZ feature works correctly
during the decoding phase in the disagg-prefill scenario.
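
For intuition, here is a minimal host-side sketch of the idea — not the PR's actual connector code. On Ascend, an ND tensor can be cast to the NZ (FRACTAL_NZ) format with `torch_npu.npu_format_cast`; the hook point and the helper name below are illustrative assumptions.

```python
import torch
import torch_npu  # requires an Ascend NPU environment

ACL_FORMAT_FRACTAL_NZ = 29  # standard Ascend format id for FRACTAL_NZ


def nd_to_nz(kv_block: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cast a KV block received from the prefill
    node in plain ND layout into NZ, so the NZ-enabled MLA decode
    kernels can consume it."""
    return torch_npu.npu_format_cast(kv_block, ACL_FORMAT_FRACTAL_NZ)
```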

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: ghphotoframe <854746559@qq.com>
Co-authored-by: alex101-ops <alex1015718386@gmail.com>
Author: Jade Zheng
Date: 2025-12-31 14:24:04 +08:00
Committed by: GitHub
Parent: a539ae753a
Commit: 38570cfeb6
8 changed files with 163 additions and 95 deletions


@@ -745,6 +745,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
+        self.enable_kv_nz = ascend_config.enable_kv_nz
         self.ring_mla_mask_size = 512
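
The new attribute is plain config plumbing. A hedged launch sketch follows; the `additional_config` key path is an assumption and has moved between vllm-ascend versions, so treat it as illustrative.

```python
from vllm import LLM

# Assumed wiring: additional_config feeds AscendConfig, where the MLA
# kernels read enable_kv_nz. Key path and model name are illustrative.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    additional_config={"enable_kv_nz": True},
)
```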
@@ -1073,7 +1074,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv_no_split = kv_no_split.view(
             B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        cache_mode = "PA"
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv_no_split,
             self.kv_a_layernorm.weight,
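
For intuition about what `"PA_NZ"` asks the fused op to do, here is a host-side model with made-up sizes, using plain torch. It mirrors the 16-wide fractal grouping that the NZ views in the next hunk imply; the real rearrangement happens inside `npu_kv_rmsnorm_rope_cache` on the NPU.

```python
import torch

# Illustration only: how a [block_size, D] ND tile relates to the
# NZ-style layout [D // 16, block_size, 16] used by the NZ code path.
block_size, D, nz = 128, 64, 16
tile_nd = torch.arange(block_size * D).view(block_size, D)

# Split the head dim into 16-wide column groups, then make block_size
# the inner row index of each group.
tile_nz = tile_nd.view(block_size, D // nz, nz).permute(1, 0, 2).contiguous()
print(tile_nz.shape)  # torch.Size([4, 128, 16])
```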
@@ -1143,37 +1144,57 @@ class AscendMLAImpl(MLAAttentionImpl):
         # shape of knope/k_pe for npu graph mode should be:
         # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
         actual_seq_lengths = None
-        k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                             self.kv_lora_rank)
-        k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                         self.qk_rope_head_dim)
+        if self.enable_kv_nz:
+            nz_fmt_last_dim = 16
+            k_nope = k_nope.view(-1, self.num_kv_heads,
+                                 self.kv_lora_rank // nz_fmt_last_dim,
+                                 block_size, nz_fmt_last_dim)
+            k_pe = k_pe.view(-1, self.num_kv_heads,
+                             self.qk_rope_head_dim // nz_fmt_last_dim,
+                             block_size, nz_fmt_last_dim)
+        else:
+            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                 self.kv_lora_rank)
+            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                             self.qk_rope_head_dim)
+        attn_output_shape: tuple | None = None
         if attn_metadata.attn_state in [
                 AscendAttentionState.SpecDecoding,
                 AscendAttentionState.ChunkedPrefill,
                 AscendAttentionState.DecodeOnly,
         ] and self.speculative_config is not None:
             # Input shape: [num_tokens, num_heads, dim]
             # Output shape: [num_heads, num_tokens, dim]
             # The right part layout indicates the layout of the attention
             # output. It is set to NTD to avoid the need for a transpose
             # operation after attention.
             input_layout = "TND_NTD"
             # TODO: If the driver is upgraded later, the contiguous function can be deleted.
             # Input shape: [num_tokens, num_heads, dim]
             q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
             q_pe = q_pe.view(num_tokens, self.num_heads, -1)
             # Output shape: [num_heads, num_tokens, dim]
+            attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
             sparse_mode = 3
             spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
             actual_seq_lengths = decode_meta.actual_seq_lengths_q
         else:
             # Input shape: [num_reqs, num_heads, seq_len, dim]
             # Output shape: [num_heads, num_reqs, seq_len, dim]
-            # The output layout is set to NBSD to eliminate the need for a
-            # transpose operation after attention.
-            input_layout = "BNSD_NBSD"
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1,
-                                 -1).contiguous()
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            if self.enable_kv_nz:
+                # Input shape: [num_tokens, seq_len, num_heads, dim]
+                input_layout = "BSND_NBSD"
+                q_nope = q_nope.view(num_tokens, 1, self.num_heads,
+                                     -1).contiguous()
+                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+            else:
+                # Input shape: [num_tokens, num_heads, seq_len, dim]
+                input_layout = "BNSD_NBSD"
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1,
+                                     -1).contiguous()
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # Output shape: [num_heads, num_tokens, seq_len, dim]
+            attn_output_shape = (self.num_heads, num_tokens, 1,
+                                 self.kv_lora_rank)
             sparse_mode = 0
             spec_attn_mask = None
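
A quick shape check of the two new layouts above, with made-up sizes (1 KV head, `kv_lora_rank=512`, `qk_rope_head_dim=64`, `block_size=128`). Note that `view()` only reinterprets metadata here; it assumes the cache was already written in NZ order by the `PA_NZ` cache mode.

```python
import torch

num_blocks, num_kv_heads, block_size = 4, 1, 128
kv_lora_rank, qk_rope_head_dim, nz = 512, 64, 16

k_nope = torch.empty(num_blocks, num_kv_heads, block_size * kv_lora_rank)
k_pe = torch.empty(num_blocks, num_kv_heads, block_size * qk_rope_head_dim)

# NZ branch: split the last dim into 16-wide fractal columns.
k_nope_nz = k_nope.view(-1, num_kv_heads, kv_lora_rank // nz, block_size, nz)
k_pe_nz = k_pe.view(-1, num_kv_heads, qk_rope_head_dim // nz, block_size, nz)
print(k_nope_nz.shape)  # torch.Size([4, 1, 32, 128, 16])
print(k_pe_nz.shape)    # torch.Size([4, 1, 4, 128, 16])

# Query layout on the NZ path: BSND [num_tokens, 1, num_heads, d],
# versus BNSD [num_tokens, num_heads, 1, d] on the default path.
num_tokens, num_heads = 8, 16
q_nope = torch.empty(num_tokens, num_heads * kv_lora_rank)
q_bsnd = q_nope.view(num_tokens, 1, num_heads, -1)
q_bnsd = q_nope.view(num_tokens, num_heads, 1, -1)
print(q_bsnd.shape, q_bnsd.shape)
```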
@@ -1215,10 +1236,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         else:
             update_graph_params_workspaces(num_tokens, workspace)
-        attn_output = torch.empty(
-            (q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]),
-            dtype=q_nope.dtype,
-            device=q_nope.device)
+        attn_output = torch.empty(attn_output_shape,
+                                  dtype=q_nope.dtype,
+                                  device=q_nope.device)
         softmax_lse = torch.empty(num_tokens,
                                   dtype=q_nope.dtype,
                                   device=q_nope.device)
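
A sketch of why the output shape becomes explicit here (sizes are made up): the old "swap the first two dims of `q_nope`" heuristic was correct for BNSD input, but the new BSND layout on the NZ path would make it produce the wrong shape.

```python
import torch

num_tokens, num_heads, kv_lora_rank = 8, 16, 512

# NZ path: q is laid out BSND -> [num_tokens, 1, num_heads, d].
q_nope = torch.empty(num_tokens, 1, num_heads, kv_lora_rank)

# Old heuristic: swap q_nope's first two dims. For BSND input this
# yields [1, T, N, d] instead of the NBSD output [N, T, 1, d] that the
# kernel actually writes.
old_shape = (q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:])

# New: the shape is fixed where the input layout is chosen.
attn_output_shape = (num_heads, num_tokens, 1, kv_lora_rank)
print(old_shape, attn_output_shape)  # (1, 8, 16, 512) vs (16, 8, 1, 512)
```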
@@ -1297,7 +1317,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             bias1=self.qb_qt_bias,
             ctkv_scale=self.ctkv_scale,
             q_nope_scale=self.q_nope_scale,
-            cache_mode="krope_ctkv",
+            cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv",
             quant_mode="per_tensor_quant_asymm",
             q_out0=decode_q_nope,
             kv_cache_out0=decode_k_nope,