[Feature] Support kv nz feature for DeepSeek decode node in disagg-prefill scenario (#3072)
By converting the KV cache from ND to NZ format when the decode node
receives it, this PR ensures that the KV NZ feature works correctly
during the decoding phase in disagg-prefill scenario.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: ghphotoframe <854746559@qq.com>
Co-authored-by: alex101-ops <alex1015718386@gmail.com>
This commit is contained in:
@@ -745,6 +745,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
|
||||
self.enable_kv_nz = ascend_config.enable_kv_nz
|
||||
|
||||
self.ring_mla_mask_size = 512
|
||||
|
||||
@@ -1073,7 +1074,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv_no_split = kv_no_split.view(
|
||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA"
|
||||
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
@@ -1143,37 +1144,57 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# shape of knope/k_pe for npu graph mode should be:
|
||||
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
||||
actual_seq_lengths = None
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
||||
self.kv_lora_rank)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||
self.qk_rope_head_dim)
|
||||
if self.enable_kv_nz:
|
||||
nz_fmt_last_dim = 16
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads,
|
||||
self.kv_lora_rank // nz_fmt_last_dim,
|
||||
block_size, nz_fmt_last_dim)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads,
|
||||
self.qk_rope_head_dim // nz_fmt_last_dim,
|
||||
block_size, nz_fmt_last_dim)
|
||||
else:
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
||||
self.kv_lora_rank)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||
self.qk_rope_head_dim)
|
||||
|
||||
attn_output_shape: tuple | None = None
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.SpecDecoding,
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.DecodeOnly,
|
||||
] and self.speculative_config is not None:
|
||||
# Input shape: [num_tokens, num_heads, dim]
|
||||
# Output shape: [num_heads, num_tokens, dim]
|
||||
# The right part layout indicates the layout of the attention
|
||||
# output. It is set to NTD to avoid the need for a transpose
|
||||
# operation after attention.
|
||||
input_layout = "TND_NTD"
|
||||
# TODO: If the driver is upgraded later, the contiguous function can be deleted.
|
||||
# Input shape: [num_tokens, num_heads, dim]
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
|
||||
# Output shape: [num_heads, num_tokens, dim]
|
||||
attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
|
||||
sparse_mode = 3
|
||||
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
||||
else:
|
||||
# Input shape: [num_reqs, num_heads, seq_len, dim]
|
||||
# Output shape: [num_heads, num_reqs, seq_len, dim]
|
||||
# The output layout is set to NBSD to eliminate the need for a
|
||||
# transpose operation after attention.
|
||||
input_layout = "BNSD_NBSD"
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
|
||||
-1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
if self.enable_kv_nz:
|
||||
# Input shape: [num_tokens, seq_len, num_heads, dim]
|
||||
input_layout = "BSND_NBSD"
|
||||
q_nope = q_nope.view(num_tokens, 1, self.num_heads,
|
||||
-1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
|
||||
else:
|
||||
# Input shape: [num_tokens, num_heads, seq_len, dim]
|
||||
input_layout = "BNSD_NBSD"
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
|
||||
-1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
# Output shape: [num_heads, num_tokens, seq_len, dim]
|
||||
attn_output_shape = (self.num_heads, num_tokens, 1,
|
||||
self.kv_lora_rank)
|
||||
sparse_mode = 0
|
||||
spec_attn_mask = None
|
||||
|
||||
@@ -1215,10 +1236,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
else:
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
attn_output = torch.empty(
|
||||
(q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]),
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
attn_output = torch.empty(attn_output_shape,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
softmax_lse = torch.empty(num_tokens,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
@@ -1297,7 +1317,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
bias1=self.qb_qt_bias,
|
||||
ctkv_scale=self.ctkv_scale,
|
||||
q_nope_scale=self.q_nope_scale,
|
||||
cache_mode="krope_ctkv",
|
||||
cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv",
|
||||
quant_mode="per_tensor_quant_asymm",
|
||||
q_out0=decode_q_nope,
|
||||
kv_cache_out0=decode_k_nope,
|
||||
|
||||
Reference in New Issue
Block a user