This PR fixes the bug `local variable 'decode_hs_or_q_c' referenced
before assignment` when running chunked-prefill with torchair. We should
calculate `decode_hs_or_q_c` whether or not torchair graphics mode is
enabled.
backport of #1378
fix https://github.com/vllm-project/vllm-ascend/issues/1369
- vLLM version: v0.10.0
- vLLM main:
0e36abf993
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -31,6 +31,7 @@ def _deepseek_torchair_test_fixture(
|
||||
additional_config: Dict,
|
||||
*,
|
||||
tensor_parallel_size=2,
|
||||
use_v1_schduler=False,
|
||||
):
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
@@ -38,14 +39,14 @@ def _deepseek_torchair_test_fixture(
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# torchair is only work without chunked-prefill now
|
||||
kwargs = {
|
||||
"ascend_scheduler_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"refresh": True,
|
||||
}
|
||||
kwargs = {}
|
||||
if not use_v1_schduler:
|
||||
kwargs = {
|
||||
"ascend_scheduler_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"refresh": True,
|
||||
}
|
||||
additional_config.update(**kwargs)
|
||||
|
||||
with VllmRunner(
|
||||
@@ -95,6 +96,15 @@ def test_e2e_deepseekv3_with_torchair_ms_mla():
|
||||
_deepseek_torchair_test_fixture(additional_config)
|
||||
|
||||
|
||||
def test_e2e_deepseekv3_with_torchair_v1scheduler():
|
||||
additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
_deepseek_torchair_test_fixture(additional_config, use_v1_schduler=True)
|
||||
|
||||
|
||||
def _pangu_torchair_test_fixture(
|
||||
additional_config: Dict,
|
||||
*,
|
||||
|
||||
@@ -1079,11 +1079,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
]
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
if k_pe is None and not self.running_in_graph:
|
||||
if not self.torchair_graph_enabled:
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||
hidden_states_or_kv_c_normed)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||
hidden_states_or_kv_c_normed)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
else:
|
||||
kv_c_normed = hidden_states_or_kv_c_normed
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
@@ -1102,12 +1101,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
if not self.running_in_graph:
|
||||
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
||||
if not self.torchair_graph_enabled:
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
decode_k_pe = k_pe[:num_decode_tokens]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||
prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
|
||||
# if not self.torchair_graph_enabled:
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
decode_k_pe = k_pe[:num_decode_tokens]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
else:
|
||||
decode_hs_or_q_c = hidden_states_or_q_c
|
||||
if has_decode:
|
||||
@@ -1167,11 +1167,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
||||
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
|
||||
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
||||
attn_metadata.slot_mapping)
|
||||
prefill_hs, cos, sin, kv_cache,
|
||||
attn_metadata.slot_mapping[num_decode_tokens:])
|
||||
|
||||
kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
|
||||
prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
|
||||
prefill_k_c_normed = prefill_k_nope
|
||||
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
|
||||
-1)
|
||||
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user