This PR fixes the bug `local variable 'decode_hs_or_q_c' referenced
before assignment` when running chunked prefill with torchair. We should
calculate `decode_hs_or_q_c` whether or not torchair graph mode is
enabled.
backport of #1378
fix https://github.com/vllm-project/vllm-ascend/issues/1369
- vLLM version: v0.10.0
- vLLM main:
0e36abf993
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -1079,11 +1079,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
]
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
if k_pe is None and not self.running_in_graph:
|
||||
if not self.torchair_graph_enabled:
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||
hidden_states_or_kv_c_normed)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||
hidden_states_or_kv_c_normed)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
else:
|
||||
kv_c_normed = hidden_states_or_kv_c_normed
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
@@ -1102,12 +1101,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
if not self.running_in_graph:
|
||||
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
||||
if not self.torchair_graph_enabled:
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
decode_k_pe = k_pe[:num_decode_tokens]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||
prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
|
||||
# if not self.torchair_graph_enabled:
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
decode_k_pe = k_pe[:num_decode_tokens]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
else:
|
||||
decode_hs_or_q_c = hidden_states_or_q_c
|
||||
if has_decode:
|
||||
@@ -1167,11 +1167,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
||||
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
|
||||
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
||||
attn_metadata.slot_mapping)
|
||||
prefill_hs, cos, sin, kv_cache,
|
||||
attn_metadata.slot_mapping[num_decode_tokens:])
|
||||
|
||||
kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
|
||||
prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
|
||||
prefill_k_c_normed = prefill_k_nope
|
||||
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
|
||||
-1)
|
||||
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user