[Feat] chunkprefill mla support torchair graph (#1772)

chunkprefill mla only support eager mode now,we want to optimaze it by
support torchair graph, the idea is simple, when all the request is
running in decode, use torchair graph to deal with it, else when
chunkprefill or prefill only, use the eager mode

- vLLM version: v0.10.0
- vLLM main:
ebf7605b0d

Signed-off-by: haojiangzheng <justineric096@gmail.com>
Co-authored-by: haojiangzheng <justineric096@gmail.com>
This commit is contained in:
zhenghaojiang
2025-08-11 19:58:59 +08:00
committed by GitHub
parent 881e36d6a9
commit eb43a475f4
2 changed files with 28 additions and 18 deletions

View File

@@ -664,6 +664,7 @@ class TestAscendMLAImpl(TestBase):
def test_forward_decode_without_graph(self, mock_page_attention_mla, def test_forward_decode_without_graph(self, mock_page_attention_mla,
mock_up_proj): mock_up_proj):
self.impl.running_in_graph = False self.impl.running_in_graph = False
self.impl.running_chunkprefilll_with_torchair = False
num_tokens = 100 num_tokens = 100
num_blocks = 256 num_blocks = 256
block_size = 4 block_size = 4

View File

@@ -998,7 +998,7 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_meta = attn_metadata.decode decode_meta = attn_metadata.decode
assert decode_meta is not None assert decode_meta is not None
num_tokens = q_nope.size(0) num_tokens = q_nope.size(0)
if self.running_in_graph: if self.running_in_graph or self.running_chunkprefilll_with_torchair:
# shape of knope/k_pe for npu graph mode should be: # shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
block_size = kv_c_and_k_pe_cache[0].shape[1] block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -1112,6 +1112,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [ self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
] ]
self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
num_actual_toks = attn_metadata.num_actual_tokens num_actual_toks = attn_metadata.num_actual_tokens
if k_pe is None and not self.running_in_graph: if k_pe is None and not self.running_in_graph:
kv_c, k_pe = self.kv_a_proj_with_mqa( kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1148,18 +1149,25 @@ class AscendMLAImpl(MLAAttentionImpl):
if has_decode: if has_decode:
decode_k_nope = None decode_k_nope = None
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if self.running_in_graph: if self.running_in_graph or self.running_chunkprefilll_with_torchair:
cos = attn_metadata.decode.cos cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin sin = attn_metadata.decode.sin
with npu_stream_switch("mla_secondary", if self.running_chunkprefilll_with_torchair:
0, decode_hs = (
enabled=enable_multistream_mla): hidden_states_or_kv_c_normed[:num_decode_tokens])
npu_wait_tensor(hidden_states_or_kv_c_normed, slots = attn_metadata.slot_mapping[:num_decode_tokens]
ckq,
enabled=enable_multistream_mla)
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv( decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache, decode_hs, cos, sin, kv_cache, slots)
attn_metadata.slot_mapping) else:
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
npu_wait_tensor(hidden_states_or_kv_c_normed,
ckq,
enabled=enable_multistream_mla)
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)
# Without explicitly controlling the order, IndexByTensor operations # Without explicitly controlling the order, IndexByTensor operations
# would be placed after `matmul W_KV_T` hindering the overlapping of # would be placed after `matmul W_KV_T` hindering the overlapping of
# KvRmsNormRopeCache and SingleRope. # KvRmsNormRopeCache and SingleRope.
@@ -1183,6 +1191,8 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_k_pe, decode_k_pe,
enabled=enable_multistream_mla) enabled=enable_multistream_mla)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin) decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
elif self.running_chunkprefilll_with_torchair:
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
else: else:
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions, attn_metadata.decode.input_positions,
@@ -1221,16 +1231,15 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_cache kv_cache
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
if self.torchair_graph_enabled: if self.torchair_graph_enabled:
if kv_cache[0].numel( if kv_cache[0].numel() > 0 and has_prefill:
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
slots = attn_metadata.slot_mapping slots = attn_metadata.slot_mapping
# NOTE: Separate the kv cache in advance to avoid OOM or other issues # NOTE: Separate the kv cache in advance to avoid OOM or other issues
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view( torch_npu._npu_reshape_and_cache(
num_tokens, self.num_kv_heads, -1), key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
value=prefill_k_pe, value=prefill_k_pe,
key_cache=kv_cache[0], key_cache=kv_cache[0],
value_cache=kv_cache[1], value_cache=kv_cache[1],
slot_indices=slots) slot_indices=slots[num_decode_tokens:])
else: else:
kv_c_normed = kv_c_normed.view( kv_c_normed = kv_c_normed.view(
[num_actual_toks, self.num_kv_heads, -1]) [num_actual_toks, self.num_kv_heads, -1])