[Feat] chunkprefill mla support torchair graph (#1772)
Chunked-prefill MLA currently only supports eager mode; this PR optimizes it by adding torchair graph support. The idea is simple: when all running requests are in the decode phase, handle them with the torchair graph; otherwise (chunked prefill or prefill only), fall back to eager mode.
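A minimal sketch of that dispatch, using the attention-state names from the diff below (the enum stand-in and the helper function are illustrative only, not part of the change):

```python
from enum import Enum, auto


class AscendAttentionState(Enum):
    # Stand-in for the real enum in vllm-ascend, limited to the states used here.
    DecodeOnly = auto()
    SpecDecoding = auto()
    ChunkedPrefill = auto()


def select_mla_mode(torchair_graph_enabled: bool,
                    attn_state: AscendAttentionState) -> tuple[bool, bool]:
    # Pure decode (or spec-decode) batches run through the torchair graph.
    running_in_graph = torchair_graph_enabled and attn_state in (
        AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding)
    # A mixed chunked-prefill batch stays in eager mode overall, but its decode
    # tokens can still take the graph-style KV path (the new flag in this PR).
    running_chunkprefill_with_torchair = (
        torchair_graph_enabled
        and attn_state == AscendAttentionState.ChunkedPrefill)
    return running_in_graph, running_chunkprefill_with_torchair
```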
- vLLM version: v0.10.0
- vLLM main: ebf7605b0d
Signed-off-by: haojiangzheng <justineric096@gmail.com>
Co-authored-by: haojiangzheng <justineric096@gmail.com>
@@ -664,6 +664,7 @@ class TestAscendMLAImpl(TestBase):
     def test_forward_decode_without_graph(self, mock_page_attention_mla,
                                           mock_up_proj):
         self.impl.running_in_graph = False
+        self.impl.running_chunkprefilll_with_torchair = False
         num_tokens = 100
         num_blocks = 256
         block_size = 4
@@ -998,7 +998,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -1112,6 +1112,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1148,18 +1149,25 @@ class AscendMLAImpl(MLAAttentionImpl):
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1183,6 +1191,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
@@ -1221,16 +1231,15 @@ class AscendMLAImpl(MLAAttentionImpl):
             kv_cache
         ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if kv_cache[0].numel(
-            ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            if kv_cache[0].numel() > 0 and has_prefill:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-                torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
-                    num_tokens, self.num_kv_heads, -1),
-                                                 value=prefill_k_pe,
-                                                 key_cache=kv_cache[0],
-                                                 value_cache=kv_cache[1],
-                                                 slot_indices=slots)
+                torch_npu._npu_reshape_and_cache(
+                    key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
+                    value=prefill_k_pe,
+                    key_cache=kv_cache[0],
+                    value_cache=kv_cache[1],
+                    slot_indices=slots[num_decode_tokens:])
         else:
             kv_c_normed = kv_c_normed.view(
                 [num_actual_toks, self.num_kv_heads, -1])
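For reference, the prefill-side cache write above uses `slots[num_decode_tokens:]` because, in a chunked-prefill batch, the decode tokens' latent KV has already been written by `exec_kv` on the sliced `slot_mapping[:num_decode_tokens]`; only the prefill tokens still need `_npu_reshape_and_cache`. A toy illustration of that split, with made-up token counts:

```python
import torch

# Hypothetical mixed batch: decode tokens come first in slot_mapping, then prefill tokens.
num_decode_tokens, num_prefill_tokens = 2, 5
slot_mapping = torch.arange(num_decode_tokens + num_prefill_tokens)

decode_slots = slot_mapping[:num_decode_tokens]    # consumed by exec_kv
prefill_slots = slot_mapping[num_decode_tokens:]   # consumed by _npu_reshape_and_cache
assert decode_slots.numel() + prefill_slots.numel() == slot_mapping.numel()
```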