[BugFix] Fix a bug of running chunked-prefill with torchair. (#1378) (#1844)

This PR fixes the bug `local variable 'decode_hs_or_q_c' referenced
before assignment` when running chunked-prefill with torchair. We should
calculate `decode_hs_or_q_c` whether or not torchair graph mode is
enabled.

backport of #1378
fix https://github.com/vllm-project/vllm-ascend/issues/1369


- vLLM version: v0.10.0
- vLLM main:
0e36abf993

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
Mengqing Cao
2025-07-31 20:08:45 +08:00
committed by GitHub
parent db310c6ec9
commit 4c8842da65
2 changed files with 32 additions and 22 deletions

View File

@@ -31,6 +31,7 @@ def _deepseek_torchair_test_fixture(
additional_config: Dict,
*,
tensor_parallel_size=2,
use_v1_schduler=False,
):
example_prompts = [
"Hello, my name is",
@@ -38,14 +39,14 @@ def _deepseek_torchair_test_fixture(
"The capital of France is",
"The future of AI is",
]
# torchair is only work without chunked-prefill now
kwargs = {
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
}
kwargs = {}
if not use_v1_schduler:
kwargs = {
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
}
additional_config.update(**kwargs)
with VllmRunner(
@@ -95,6 +96,15 @@ def test_e2e_deepseekv3_with_torchair_ms_mla():
_deepseek_torchair_test_fixture(additional_config)
def test_e2e_deepseekv3_with_torchair_v1scheduler():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_torchair_test_fixture(additional_config, use_v1_schduler=True)
def _pangu_torchair_test_fixture(
additional_config: Dict,
*,

View File

@@ -1079,11 +1079,10 @@ class AscendMLAImpl(MLAAttentionImpl):
]
num_actual_toks = attn_metadata.num_actual_tokens
if k_pe is None and not self.running_in_graph:
if not self.torchair_graph_enabled:
kv_c, k_pe = self.kv_a_proj_with_mqa(
hidden_states_or_kv_c_normed)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
kv_c, k_pe = self.kv_a_proj_with_mqa(
hidden_states_or_kv_c_normed)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
else:
kv_c_normed = hidden_states_or_kv_c_normed
assert attn_metadata.num_decodes is not None and \
@@ -1102,12 +1101,13 @@ class AscendMLAImpl(MLAAttentionImpl):
if not self.running_in_graph:
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
if not self.torchair_graph_enabled:
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
k_pe = k_pe[:num_actual_toks, ...]
k_pe = k_pe.unsqueeze(1)
decode_k_pe = k_pe[:num_decode_tokens]
prefill_k_pe = k_pe[num_decode_tokens:]
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
# if not self.torchair_graph_enabled:
k_pe = k_pe[:num_actual_toks, ...]
k_pe = k_pe.unsqueeze(1)
decode_k_pe = k_pe[:num_decode_tokens]
prefill_k_pe = k_pe[num_decode_tokens:]
else:
decode_hs_or_q_c = hidden_states_or_q_c
if has_decode:
@@ -1167,11 +1167,11 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)
prefill_hs, cos, sin, kv_cache,
attn_metadata.slot_mapping[num_decode_tokens:])
kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
prefill_k_c_normed = prefill_k_nope
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
-1)
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)