Enable kvcache_nz for the decode process in torchair graph mode (#1098)
What this PR does / why we need it?
Enable kvcache_nz for the decode process in torchair graph mode, which reduces the time consumed by FA on long sequences.

Does this PR introduce any user-facing change?
To enable kvcache_nz, set additional_config.torchair_graph_config.enable_kv_nz=True.

How was this patch tested?
1. Tested on a DeepSeek model: with batch size 64 and seq_len 1k+3k, the total FA time across 61 layers improves from 20.80 ms to 19.76 ms.
2. Operator precision test: aclnnFusedInferAttentionScoreV3_result.csv (https://github.com/user-attachments/files/20664138/aclnnFusedInferAttentionScoreV3_result.csv)
3. TPOT test from @ttanzhiqiang; a single curl request returns a normal result:
   https://github.com/vllm-project/vllm-ascend/pull/1098#issuecomment-2948542159
   https://github.com/vllm-project/vllm-ascend/pull/1098#issuecomment-2954496588

---------

Signed-off-by: chenwaner <861645847@qq.com>
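For reference, a minimal Python sketch of turning the new option on from user code. It assumes vLLM's standard LLM entry point and its additional_config engine argument; the model name is illustrative, not taken from this PR:

    from vllm import LLM

    # Hypothetical example: enable torchair graph mode together with the
    # NZ kv-cache layout via vllm-ascend's additional_config (key names
    # are the ones added by this PR).
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative model choice
        additional_config={
            "torchair_graph_config": {
                "enabled": True,
                "enable_kv_nz": True,
            },
        },
    )

The flag only takes effect together with torchair graph mode, since the decode path changed below is only reached when torchair_graph_enabled is set.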
@@ -44,6 +44,7 @@ The details of each config option are as follows:
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
+| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout |

 **ascend_scheduler_config**

@@ -64,7 +65,8 @@ A full example of additional configuration is as follows:
         "use_cached_graph": true,
         "graph_batch_sizes": [1, 2, 4, 8],
         "graph_batch_sizes_init": false,
-        "enable_multistream_moe": false
+        "enable_multistream_moe": false,
+        "enable_kv_nz": false
     },
     "ascend_scheduler_config": {
         "enabled": true,
@@ -58,6 +58,7 @@ class TorchairGraphConfig:
             "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
             "enable_view_optimize", True)
+        self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)

         if not isinstance(self.graph_batch_sizes, list):
             raise TypeError("graph_batch_sizes must be list[int]")
@@ -480,6 +480,7 @@ class AscendMLAImpl(MLAAttentionImpl):

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
@@ -662,6 +663,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
@@ -671,7 +673,37 @@ class AscendMLAImpl(MLAAttentionImpl):
             kv_cache[1],
             kv_cache[0],
             epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
+            cache_mode=cache_mode,
         )
         return k_pe, k_nope
+
+    def exec_kv_prefill(
+        self,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        kv_cache: Tuple,
+        slots: torch.Tensor,
+    ):
+
+        B = hidden_states.shape[0]
+        N = self.num_kv_heads
+        S = 1
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
+        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
+        _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+            is_output_kv=True,
+        )
+        return k_pe, k_nope
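Both cache writers select the layout at runtime: the decode path uses "PA_NZ" and the new prefill path uses "PA_BLK_NZ" when enable_kv_nz is set, and both fall back to the plain paged-attention "PA" layout otherwise. Below is a CPU-runnable sketch of the [B, N, S, D] view and the latent/rope split that npu_kv_rmsnorm_rope_cache operates over; the dimension values are DeepSeek-style sizes chosen purely for illustration:

    import torch

    # Illustrative MLA sizes (not taken from this PR).
    B, N, S = 4, 1, 1                       # batch, kv heads (1 for MLA), steps
    kv_lora_rank, qk_rope_head_dim = 512, 64

    # Output of the kv_a projection: one latent vector per token.
    kv = torch.randn(B, kv_lora_rank + qk_rope_head_dim)

    # npu_kv_rmsnorm_rope_cache needs [B, N, S, D].
    kv = kv.view(B, N, S, kv_lora_rank + qk_rope_head_dim)

    # Conceptually, the op splits D into a latent part (RMSNorm'd, cached as
    # k_nope) and a rope part (rotated, cached as k_pe).
    k_nope, k_pe = kv.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
    print(k_nope.shape, k_pe.shape)  # [4, 1, 1, 512] and [4, 1, 1, 64]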
@@ -709,34 +741,42 @@ class AscendMLAImpl(MLAAttentionImpl):
         # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
         if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
             assert num_tokens % self.spec_token_num == 0
-            q_nope = (q_nope.view(
-                num_tokens // (self.spec_token_num + 1),
-                self.spec_token_num + 1,
-                self.num_heads,
-                -1,
-            ).transpose(1, 2).contiguous())
-            q_pe = (q_pe.view(
-                num_tokens // (self.spec_token_num + 1),
-                self.spec_token_num + 1,
-                self.num_heads,
-                -1,
-            ).transpose(1, 2).contiguous())
+            q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
+                                 self.spec_token_num + 1, self.num_heads,
+                                 -1)
+            q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
+                             self.spec_token_num + 1, self.num_heads, -1)
+            if not self.enable_kv_nz:
+                q_nope = q_nope.transpose(1, 2).contiguous()
+                q_pe = q_pe.transpose(1, 2).contiguous()
             sparse_mode = 3
             spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
         else:
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            if self.enable_kv_nz:
+                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
             sparse_mode = 0
             spec_attn_mask = None
         # shape of knope/k_pe for npu graph mode should be:
         # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
         block_size = kv_c_and_k_pe_cache[0].shape[1]
-        k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                             self.kv_lora_rank)
-        k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                         self.qk_rope_head_dim)
+        if self.enable_kv_nz:
+            k_nope = k_nope.view(-1, self.num_kv_heads,
+                                 self.kv_lora_rank // 16, block_size, 16)
+            k_pe = k_pe.view(-1, self.num_kv_heads,
+                             self.qk_rope_head_dim // 16, block_size, 16)
+            input_layout = "BSND"
+        else:
+            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                 self.kv_lora_rank)
+            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                             self.qk_rope_head_dim)
+            input_layout = "BNSD"

-        attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
             q_nope,
             k_nope,
             k_nope,
@@ -744,7 +784,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             key_rope=k_pe,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BNSD",
+            input_layout=input_layout,
             atten_mask=spec_attn_mask,
             sparse_mode=sparse_mode,
             scale=self.scale,
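When the NZ layout is active, the last dimension of the cached keys is regrouped into 16-element fractals and the layout flag passed to the fused attention op switches from "BNSD" to "BSND". A small sketch of the reshape arithmetic (runnable on CPU; sizes are illustrative, and in the real code the cache memory was already written in NZ order by the "PA_NZ" cache mode, so a plain view suffices):

    import torch

    # Illustrative paged kv-cache shape: [num_blocks, kv_heads, block_size, dim].
    num_blocks, num_kv_heads, block_size, kv_lora_rank = 8, 1, 128, 512
    k_nope = torch.randn(num_blocks, num_kv_heads, block_size, kv_lora_rank)

    # ND view used by the "BNSD" path.
    nd = k_nope.view(-1, num_kv_heads, block_size, kv_lora_rank)

    # NZ view: dim is split into dim // 16 fractals of 16 contiguous elements.
    nz = k_nope.view(-1, num_kv_heads, kv_lora_rank // 16, block_size, 16)

    assert nd.numel() == nz.numel()  # same data, different tiling
    print(nz.shape)                  # torch.Size([8, 1, 32, 128, 16])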
@@ -793,10 +833,11 @@ class AscendMLAImpl(MLAAttentionImpl):
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
-        if k_pe is None and not self.running_in_graph:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(
-                hidden_states_or_kv_c_normed)[0].split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+        if not self.torchair_graph_enabled:
+            kv_c, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states_or_kv_c_normed)[0].split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
@@ -809,16 +850,18 @@ class AscendMLAImpl(MLAAttentionImpl):
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+        if not self.torchair_graph_enabled:
+            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            k_pe = k_pe[:num_actual_toks, ...]
-            k_pe = k_pe.unsqueeze(1)
-            decode_k_pe = k_pe[:num_decode_tokens]
-            prefill_k_pe = k_pe[num_decode_tokens:]
+            if not self.torchair_graph_enabled:
+                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+                k_pe = k_pe[:num_actual_toks, ...]
+                k_pe = k_pe.unsqueeze(1)
+                decode_k_pe = k_pe[:num_decode_tokens]
+                prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
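For context on the slicing above: decode tokens always precede prefill tokens in the flattened batch, so a single num_decode_tokens offset splits every per-token tensor. A tiny illustration with made-up shapes:

    import torch

    # Flattened batch: decode tokens first, then prefill tokens.
    num_decode_tokens, num_prefill_tokens = 3, 5
    hidden = torch.randn(num_decode_tokens + num_prefill_tokens, 16)

    decode_part = hidden[:num_decode_tokens]   # rows 0..2
    prefill_part = hidden[num_decode_tokens:]  # rows 3..7
    assert decode_part.shape[0] + prefill_part.shape[0] == hidden.shape[0]

With torchair graph mode enabled, the k_pe/kv_c slicing is skipped here because the graph-mode path derives those tensors from exec_kv / exec_kv_prefill instead, as the next hunk shows.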
@@ -855,22 +898,25 @@ class AscendMLAImpl(MLAAttentionImpl):
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
                 seq_len = self.rotary_emb.max_position_embeddings
+                cos = self.rotary_emb.cos_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                sin = self.rotary_emb.sin_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                cos = cos[attn_metadata.prefill.input_positions]
+                sin = sin[attn_metadata.prefill.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]

+                prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+                prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
+                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping)
+
+                kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
+                prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
+                prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
+                                                 -1)
-                if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
-                    # NOTE: When scaling not specified
-                    ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
-                    prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
-                    prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
-                    prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
-                    prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
-                else:
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
                 prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
             else:
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(