From 4c8842da659b31af8aad1905b8d7a4f0736283ea Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Thu, 31 Jul 2025 20:08:45 +0800 Subject: [PATCH] [BugFix] Fix a bug of running chunked-prefill with torchair. (#1378) (#1844) This PR fixes the bug `local variable 'decode_hs_or_q_c' referenced before assignment` when running chunked-prefill with torchair. We should calculate `decode_hs_or_q_c` whether or not torchair graph mode is enabled. backport of #1378 fix https://github.com/vllm-project/vllm-ascend/issues/1369 - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/0e36abf9931baa070609376debb4fb3772f4a3fe --------- Signed-off-by: whx-sjtu <2952154980@qq.com> Signed-off-by: MengqingCao Co-authored-by: whx-sjtu <2952154980@qq.com> --- .../e2e/multicard/test_torchair_graph_mode.py | 26 +++++++++++------ vllm_ascend/attention/mla_v1.py | 28 +++++++++---------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py index 9ad336c..71d33f0 100644 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ b/tests/e2e/multicard/test_torchair_graph_mode.py @@ -31,6 +31,7 @@ def _deepseek_torchair_test_fixture( additional_config: Dict, *, tensor_parallel_size=2, + use_v1_schduler=False, ): example_prompts = [ "Hello, my name is", @@ -38,14 +39,14 @@ def _deepseek_torchair_test_fixture( "The capital of France is", "The future of AI is", ] - - # torchair is only work without chunked-prefill now - kwargs = { - "ascend_scheduler_config": { - "enabled": True, - }, - "refresh": True, - } + kwargs = {} + if not use_v1_schduler: + kwargs = { + "ascend_scheduler_config": { + "enabled": True, + }, + "refresh": True, + } additional_config.update(**kwargs) with VllmRunner( @@ -95,6 +96,15 @@ def test_e2e_deepseekv3_with_torchair_ms_mla(): _deepseek_torchair_test_fixture(additional_config) +def test_e2e_deepseekv3_with_torchair_v1scheduler(): + 
additional_config = { + "torchair_graph_config": { + "enabled": True, + }, + } + _deepseek_torchair_test_fixture(additional_config, use_v1_schduler=True) + + def _pangu_torchair_test_fixture( additional_config: Dict, *, diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 4e24756..7771632 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1079,11 +1079,10 @@ class AscendMLAImpl(MLAAttentionImpl): ] num_actual_toks = attn_metadata.num_actual_tokens if k_pe is None and not self.running_in_graph: - if not self.torchair_graph_enabled: - kv_c, k_pe = self.kv_a_proj_with_mqa( - hidden_states_or_kv_c_normed)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + kv_c, k_pe = self.kv_a_proj_with_mqa( + hidden_states_or_kv_c_normed)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) else: kv_c_normed = hidden_states_or_kv_c_normed assert attn_metadata.num_decodes is not None and \ @@ -1102,12 +1101,13 @@ class AscendMLAImpl(MLAAttentionImpl): if not self.running_in_graph: hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] - if not self.torchair_graph_enabled: - decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] - k_pe = k_pe[:num_actual_toks, ...] - k_pe = k_pe.unsqueeze(1) - decode_k_pe = k_pe[:num_decode_tokens] - prefill_k_pe = k_pe[num_decode_tokens:] + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:] + # if not self.torchair_graph_enabled: + k_pe = k_pe[:num_actual_toks, ...] 
+ k_pe = k_pe.unsqueeze(1) + decode_k_pe = k_pe[:num_decode_tokens] + prefill_k_pe = k_pe[num_decode_tokens:] else: decode_hs_or_q_c = hidden_states_or_q_c if has_decode: @@ -1167,11 +1167,11 @@ class AscendMLAImpl(MLAAttentionImpl): prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( - hidden_states_or_kv_c_normed, cos, sin, kv_cache, - attn_metadata.slot_mapping) + prefill_hs, cos, sin, kv_cache, + attn_metadata.slot_mapping[num_decode_tokens:]) kv_c_normed = prefill_k_nope[:num_actual_toks, ...] - prefill_k_c_normed = prefill_k_nope[num_decode_tokens:] + prefill_k_c_normed = prefill_k_nope prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, -1) prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)