[Feat] chunkprefill mla support torchair graph (#1772)

chunkprefill mla only support eager mode now,we want to optimaze it by
support torchair graph, the idea is simple, when all the request is
running in decode, use torchair graph to deal with it, else when
chunkprefill or prefill only, use the eager mode

- vLLM version: v0.10.0
- vLLM main:
ebf7605b0d

Signed-off-by: haojiangzheng <justineric096@gmail.com>
Co-authored-by: haojiangzheng <justineric096@gmail.com>
This commit is contained in:
zhenghaojiang
2025-08-11 19:58:59 +08:00
committed by GitHub
parent 881e36d6a9
commit eb43a475f4
2 changed files with 28 additions and 18 deletions

View File

@@ -664,6 +664,7 @@ class TestAscendMLAImpl(TestBase):
def test_forward_decode_without_graph(self, mock_page_attention_mla,
mock_up_proj):
self.impl.running_in_graph = False
self.impl.running_chunkprefilll_with_torchair = False
num_tokens = 100
num_blocks = 256
block_size = 4

View File

@@ -998,7 +998,7 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_meta = attn_metadata.decode
assert decode_meta is not None
num_tokens = q_nope.size(0)
if self.running_in_graph:
if self.running_in_graph or self.running_chunkprefilll_with_torchair:
# shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -1112,6 +1112,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
num_actual_toks = attn_metadata.num_actual_tokens
if k_pe is None and not self.running_in_graph:
kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1148,18 +1149,25 @@ class AscendMLAImpl(MLAAttentionImpl):
if has_decode:
decode_k_nope = None
assert attn_metadata.decode is not None
if self.running_in_graph:
if self.running_in_graph or self.running_chunkprefilll_with_torchair:
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
npu_wait_tensor(hidden_states_or_kv_c_normed,
ckq,
enabled=enable_multistream_mla)
if self.running_chunkprefilll_with_torchair:
decode_hs = (
hidden_states_or_kv_c_normed[:num_decode_tokens])
slots = attn_metadata.slot_mapping[:num_decode_tokens]
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)
decode_hs, cos, sin, kv_cache, slots)
else:
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
npu_wait_tensor(hidden_states_or_kv_c_normed,
ckq,
enabled=enable_multistream_mla)
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)
# Without explicitly controlling the order, IndexByTensor operations
# would be placed after `matmul W_KV_T` hindering the overlapping of
# KvRmsNormRopeCache and SingleRope.
@@ -1183,6 +1191,8 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_k_pe,
enabled=enable_multistream_mla)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
elif self.running_chunkprefilll_with_torchair:
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
else:
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions,
@@ -1221,16 +1231,15 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_cache
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
if self.torchair_graph_enabled:
if kv_cache[0].numel(
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
if kv_cache[0].numel() > 0 and has_prefill:
slots = attn_metadata.slot_mapping
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
num_tokens, self.num_kv_heads, -1),
value=prefill_k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slots)
torch_npu._npu_reshape_and_cache(
key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
value=prefill_k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slots[num_decode_tokens:])
else:
kv_c_normed = kv_c_normed.view(
[num_actual_toks, self.num_kv_heads, -1])