From eb43a475f429192e7509e85e28b1c65d5097f373 Mon Sep 17 00:00:00 2001
From: zhenghaojiang
Date: Mon, 11 Aug 2025 19:58:59 +0800
Subject: [PATCH] [Feat] chunkprefill mla support torchair graph (#1772)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Chunked prefill MLA only supports eager mode for now; we want to optimize
it by supporting the torchair graph. The idea is simple: when all requests
are running in decode, use the torchair graph to handle them; otherwise,
for chunked prefill or prefill-only batches, use eager mode.

- vLLM version: v0.10.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ebf7605b0dd58ff5d572d1918e52ca732025eee0

Signed-off-by: haojiangzheng
Co-authored-by: haojiangzheng
---
 tests/ut/attention/test_mla_v1.py |  1 +
 vllm_ascend/attention/mla_v1.py   | 45 ++++++++++++++++++-------------
 2 files changed, 28 insertions(+), 18 deletions(-)

diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 2ecc3f7..652cff3 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -664,6 +664,7 @@ class TestAscendMLAImpl(TestBase):
     def test_forward_decode_without_graph(self, mock_page_attention_mla,
                                           mock_up_proj):
         self.impl.running_in_graph = False
+        self.impl.running_chunkprefilll_with_torchair = False
         num_tokens = 100
         num_blocks = 256
         block_size = 4
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index a8f8ae8..48713fc 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -998,7 +998,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -1112,6 +1112,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1148,18 +1149,25 @@ class AscendMLAImpl(MLAAttentionImpl):
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1183,6 +1191,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
@@ -1221,16 +1231,15 @@ class AscendMLAImpl(MLAAttentionImpl):
             kv_cache
         ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if kv_cache[0].numel(
-            ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            if kv_cache[0].numel() > 0 and has_prefill:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-                torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
-                    num_tokens, self.num_kv_heads, -1),
-                                                 value=prefill_k_pe,
-                                                 key_cache=kv_cache[0],
-                                                 value_cache=kv_cache[1],
-                                                 slot_indices=slots)
+                torch_npu._npu_reshape_and_cache(
+                    key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
+                    value=prefill_k_pe,
+                    key_cache=kv_cache[0],
+                    value_cache=kv_cache[1],
+                    slot_indices=slots[num_decode_tokens:])
             else:
                 kv_c_normed = kv_c_normed.view(
                     [num_actual_toks, self.num_kv_heads, -1])
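
As an illustration of the dispatch described in the commit message, here is a
minimal, self-contained sketch. It is not vllm-ascend code: AttnState and
choose_execution_path are hypothetical stand-ins for AscendAttentionState and
the checks performed in AscendMLAImpl.forward. A decode-only (or
spec-decoding) batch is replayed through the torchair graph; chunked-prefill
and prefill-only batches stay on the eager path.

from enum import Enum, auto


class AttnState(Enum):
    # Hypothetical mirror of the relevant AscendAttentionState values.
    DECODE_ONLY = auto()      # every scheduled request is decoding
    SPEC_DECODING = auto()    # speculative decoding, also graph-friendly
    CHUNKED_PREFILL = auto()  # batch mixes prefill chunks with decodes
    PREFILL_ONLY = auto()


def choose_execution_path(attn_state: AttnState,
                          torchair_graph_enabled: bool) -> str:
    # Decode-only batches have regular shapes, so the captured torchair
    # graph can be replayed; everything else keeps the eager path.
    if torchair_graph_enabled and attn_state in (AttnState.DECODE_ONLY,
                                                 AttnState.SPEC_DECODING):
        return "torchair_graph"
    return "eager"


assert choose_execution_path(AttnState.DECODE_ONLY, True) == "torchair_graph"
assert choose_execution_path(AttnState.CHUNKED_PREFILL, True) == "eager"
assert choose_execution_path(AttnState.DECODE_ONLY, False) == "eager"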
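
The mla_v1.py changes also rely on decode tokens being packed at the front of
the flattened batch: exec_kv consumes slot_mapping[:num_decode_tokens] for the
decode part, while _npu_reshape_and_cache writes the remaining prefill slots
via slots[num_decode_tokens:]. A tiny sketch of that split, with made-up
values for illustration:

import torch

# Hypothetical flattened batch: 3 decode tokens followed by 4 prefill tokens.
num_decode_tokens = 3
slot_mapping = torch.tensor([10, 11, 12, 40, 41, 42, 43])

decode_slots = slot_mapping[:num_decode_tokens]   # passed to exec_kv
prefill_slots = slot_mapping[num_decode_tokens:]  # passed to _npu_reshape_and_cache

assert decode_slots.tolist() == [10, 11, 12]
assert prefill_slots.tolist() == [40, 41, 42, 43]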