Enable kvcache_nz for the decode process in torchair graph mode (#1098)

What this PR does / why we need it?
Enable kvcache_nz for the decode process in torchair graph mode, which reduces the time consumed by the fused attention (FA) operator on long sequences.

Does this PR introduce any user-facing change?
Yes. To enable kvcache_nz, set `additional_config.torchair_graph_config.enable_kv_nz=True` (the option defaults to `False` and only takes effect in torchair graph mode).
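
For example, a minimal offline-inference sketch (the model name and generation arguments are illustrative assumptions, not part of this PR; only the `enable_kv_nz` key is new):

```python
# Sketch of enabling kvcache_nz via additional_config.
# Assumed, typical vllm-ascend setup; only enable_kv_nz comes from this PR.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # assumed example model
    additional_config={
        "torchair_graph_config": {
            "enabled": True,        # kvcache_nz only applies in torchair graph mode
            "enable_kv_nz": True,   # new option introduced by this PR
        },
    },
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=16)))
```

The same dict can be passed to `vllm serve` as a JSON string via `--additional-config`.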

How was this patch tested?
1. Tested on a DeepSeek model: with batch size 64 and seq_len 1k+3k, the total FA time over the 61 layers improves from 20.80 ms to 19.76 ms (about a 5% reduction).
2. Operator precision test:

[aclnnFusedInferAttentionScoreV3_result.csv](https://github.com/user-attachments/files/20664138/aclnnFusedInferAttentionScoreV3_result.csv)
3. TPOT test from @ttanzhiqiang; a single curl request also returns a normal result:

https://github.com/vllm-project/vllm-ascend/pull/1098#issuecomment-2948542159

https://github.com/vllm-project/vllm-ascend/pull/1098#issuecomment-2954496588

---------

Signed-off-by: chenwaner <861645847@qq.com>
Author: chenwaner
Date: 2025-06-11 14:09:28 +08:00 (committed by GitHub)
Commit: e46dc142bf (parent: 4153a5091b)
3 changed files with 96 additions and 47 deletions


@@ -44,6 +44,7 @@ The details of each config option are as follows:
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
+| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout |
 
 **ascend_scheduler_config**
@@ -64,7 +65,8 @@ A full example of additional configuration is as follows:
         "use_cached_graph": true,
         "graph_batch_sizes": [1, 2, 4, 8],
         "graph_batch_sizes_init": false,
-        "enable_multistream_moe": false
+        "enable_multistream_moe": false,
+        "enable_kv_nz": false
     },
     "ascend_scheduler_config": {
         "enabled": true,


@@ -58,6 +58,7 @@ class TorchairGraphConfig:
             "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
             "enable_view_optimize", True)
+        self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
         if not isinstance(self.graph_batch_sizes, list):
             raise TypeError("graph_batch_sizes must be list[int]")


@@ -480,6 +480,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
@@ -662,6 +663,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
@@ -671,7 +673,37 @@ class AscendMLAImpl(MLAAttentionImpl):
             kv_cache[1],
             kv_cache[0],
             epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope
+
+    def exec_kv_prefill(
+        self,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        kv_cache: Tuple,
+        slots: torch.Tensor,
+    ):
+        B = hidden_states.shape[0]
+        N = self.num_kv_heads
+        S = 1
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
+        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
+        _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+            is_output_kv=True,
+        )
         )
         return k_pe, k_nope
@@ -709,34 +741,42 @@ class AscendMLAImpl(MLAAttentionImpl):
         # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
         if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
             assert num_tokens % self.spec_token_num == 0
-            q_nope = (q_nope.view(
-                num_tokens // (self.spec_token_num + 1),
-                self.spec_token_num + 1,
-                self.num_heads,
-                -1,
-            ).transpose(1, 2).contiguous())
-            q_pe = (q_pe.view(
-                num_tokens // (self.spec_token_num + 1),
-                self.spec_token_num + 1,
-                self.num_heads,
-                -1,
-            ).transpose(1, 2).contiguous())
+            q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
+                                 self.spec_token_num + 1, self.num_heads,
+                                 -1)
+            q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
+                             self.spec_token_num + 1, self.num_heads, -1)
+            if not self.enable_kv_nz:
+                q_nope = q_nope.transpose(1, 2).contiguous()
+                q_pe = q_pe.transpose(1, 2).contiguous()
             sparse_mode = 3
             spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
         else:
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            if self.enable_kv_nz:
+                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
             sparse_mode = 0
             spec_attn_mask = None
         # shape of knope/k_pe for npu graph mode should be:
         # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
         block_size = kv_c_and_k_pe_cache[0].shape[1]
-        k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                             self.kv_lora_rank)
-        k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                         self.qk_rope_head_dim)
+        if self.enable_kv_nz:
+            k_nope = k_nope.view(-1, self.num_kv_heads,
+                                 self.kv_lora_rank // 16, block_size, 16)
+            k_pe = k_pe.view(-1, self.num_kv_heads,
+                             self.qk_rope_head_dim // 16, block_size, 16)
+            input_layout = "BSND"
+        else:
+            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                 self.kv_lora_rank)
+            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                             self.qk_rope_head_dim)
+            input_layout = "BNSD"
 
-        attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
             q_nope,
             k_nope,
             k_nope,
@@ -744,7 +784,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             key_rope=k_pe,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BNSD",
+            input_layout=input_layout,
             atten_mask=spec_attn_mask,
             sparse_mode=sparse_mode,
             scale=self.scale,
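
In the NZ branch above, the cached `k_nope`/`k_pe` are viewed with the feature dimension split into 16-element fragments, `[num_blocks, num_kv_heads, dim // 16, block_size, 16]`, and the fused-attention call switches `input_layout` from `"BNSD"` to `"BSND"` to match. A minimal sketch of that re-tiling, with illustrative sizes (all concrete numbers below are assumptions, not values from this PR):

```python
import torch

# Assumed example sizes; real values come from the model and cache config.
num_blocks, num_kv_heads, block_size, kv_lora_rank = 8, 1, 128, 512

# A contiguous paged-cache buffer for k_nope.
k_nope = torch.randn(num_blocks, block_size, num_kv_heads * kv_lora_rank)

# Default (ND) view used when enable_kv_nz is off:
# [num_blocks, num_kv_heads, block_size, kv_lora_rank]
k_nd = k_nope.view(-1, num_kv_heads, block_size, kv_lora_rank)

# NZ view used when enable_kv_nz is on: the feature dimension is split into
# kv_lora_rank // 16 fragments of 16 elements, with block_size in between:
# [num_blocks, num_kv_heads, kv_lora_rank // 16, block_size, 16]
k_nz = k_nope.view(-1, num_kv_heads, kv_lora_rank // 16, block_size, 16)

# Both views reinterpret the same buffer; on the NPU the cache op
# (cache_mode="PA_NZ"/"PA_BLK_NZ") writes the data in NZ order, so the
# 16-wide innermost fragment is contiguous in memory.
assert k_nd.numel() == k_nz.numel()
```

The fragment width of 16 presumably matches the NZ fractal tile size for 16-bit dtypes; this sketch only illustrates the shape arithmetic, not the physical NPU layout.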
@@ -793,10 +833,11 @@ class AscendMLAImpl(MLAAttentionImpl):
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(
-                hidden_states_or_kv_c_normed)[0].split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.torchair_graph_enabled:
+                kv_c, k_pe = self.kv_a_proj_with_mqa(
+                    hidden_states_or_kv_c_normed)[0].split(
+                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
@@ -809,16 +850,18 @@ class AscendMLAImpl(MLAAttentionImpl):
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+        if not self.torchair_graph_enabled:
+            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            k_pe = k_pe[:num_actual_toks, ...]
-            k_pe = k_pe.unsqueeze(1)
-            decode_k_pe = k_pe[:num_decode_tokens]
-            prefill_k_pe = k_pe[num_decode_tokens:]
+            if not self.torchair_graph_enabled:
+                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+                k_pe = k_pe[:num_actual_toks, ...]
+                k_pe = k_pe.unsqueeze(1)
+                decode_k_pe = k_pe[:num_decode_tokens]
+                prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
@@ -855,22 +898,25 @@ class AscendMLAImpl(MLAAttentionImpl):
         prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
         if self.torchair_graph_enabled:
             num_tokens = prefill_hs_or_q_c.shape[0]
+            seq_len = self.rotary_emb.max_position_embeddings
+            cos = self.rotary_emb.cos_cached[:seq_len].to(
+                dtype=prefill_q_pe.dtype)
+            sin = self.rotary_emb.sin_cached[:seq_len].to(
+                dtype=prefill_q_pe.dtype)
+            cos = cos[attn_metadata.prefill.input_positions]
+            sin = sin[attn_metadata.prefill.input_positions]
+            cos = cos[:, None, None, :]
+            sin = sin[:, None, None, :]
+            prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+            prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
+                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                attn_metadata.slot_mapping)
+            kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
+            prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
             prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                              -1)
-            if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
-                # NOTE: When scaling not specified
-                ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
-                prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
-                prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
-                prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                    attn_metadata.prefill.input_positions, prefill_q_pe,
-                    prefill_k_pe)
-                prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
-                prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
-            else:
-                prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                    attn_metadata.prefill.input_positions, prefill_q_pe,
-                    prefill_k_pe)
             prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
         else:
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(