This PR fixes the bug `local variable 'decode_hs_or_q_c' referenced
before assignment` when running chunked-prefill with torchair. We should
calculate `decode_hs_or_q_c` whether or not torchair graph mode is
enabled.
Backport of #1378.
Fixes https://github.com/vllm-project/vllm-ascend/issues/1369
- vLLM version: v0.10.0
- vLLM main:
0e36abf993
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -31,6 +31,7 @@ def _deepseek_torchair_test_fixture(
|
|||||||
additional_config: Dict,
|
additional_config: Dict,
|
||||||
*,
|
*,
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
|
use_v1_schduler=False,
|
||||||
):
|
):
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
@@ -38,14 +39,14 @@ def _deepseek_torchair_test_fixture(
|
|||||||
"The capital of France is",
|
"The capital of France is",
|
||||||
"The future of AI is",
|
"The future of AI is",
|
||||||
]
|
]
|
||||||
|
kwargs = {}
|
||||||
# torchair is only work without chunked-prefill now
|
if not use_v1_schduler:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"ascend_scheduler_config": {
|
"ascend_scheduler_config": {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
},
|
},
|
||||||
"refresh": True,
|
"refresh": True,
|
||||||
}
|
}
|
||||||
additional_config.update(**kwargs)
|
additional_config.update(**kwargs)
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
@@ -95,6 +96,15 @@ def test_e2e_deepseekv3_with_torchair_ms_mla():
|
|||||||
_deepseek_torchair_test_fixture(additional_config)
|
_deepseek_torchair_test_fixture(additional_config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_e2e_deepseekv3_with_torchair_v1scheduler():
|
||||||
|
additional_config = {
|
||||||
|
"torchair_graph_config": {
|
||||||
|
"enabled": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_deepseek_torchair_test_fixture(additional_config, use_v1_schduler=True)
|
||||||
|
|
||||||
|
|
||||||
def _pangu_torchair_test_fixture(
|
def _pangu_torchair_test_fixture(
|
||||||
additional_config: Dict,
|
additional_config: Dict,
|
||||||
*,
|
*,
|
||||||
|
|||||||
@@ -1079,11 +1079,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
]
|
]
|
||||||
num_actual_toks = attn_metadata.num_actual_tokens
|
num_actual_toks = attn_metadata.num_actual_tokens
|
||||||
if k_pe is None and not self.running_in_graph:
|
if k_pe is None and not self.running_in_graph:
|
||||||
if not self.torchair_graph_enabled:
|
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
hidden_states_or_kv_c_normed)[0].split(
|
||||||
hidden_states_or_kv_c_normed)[0].split(
|
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
|
||||||
else:
|
else:
|
||||||
kv_c_normed = hidden_states_or_kv_c_normed
|
kv_c_normed = hidden_states_or_kv_c_normed
|
||||||
assert attn_metadata.num_decodes is not None and \
|
assert attn_metadata.num_decodes is not None and \
|
||||||
@@ -1102,12 +1101,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
if not self.running_in_graph:
|
if not self.running_in_graph:
|
||||||
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
||||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
||||||
if not self.torchair_graph_enabled:
|
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
|
||||||
k_pe = k_pe[:num_actual_toks, ...]
|
# if not self.torchair_graph_enabled:
|
||||||
k_pe = k_pe.unsqueeze(1)
|
k_pe = k_pe[:num_actual_toks, ...]
|
||||||
decode_k_pe = k_pe[:num_decode_tokens]
|
k_pe = k_pe.unsqueeze(1)
|
||||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
decode_k_pe = k_pe[:num_decode_tokens]
|
||||||
|
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||||
else:
|
else:
|
||||||
decode_hs_or_q_c = hidden_states_or_q_c
|
decode_hs_or_q_c = hidden_states_or_q_c
|
||||||
if has_decode:
|
if has_decode:
|
||||||
@@ -1167,11 +1167,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
||||||
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
|
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
|
||||||
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
prefill_hs, cos, sin, kv_cache,
|
||||||
attn_metadata.slot_mapping)
|
attn_metadata.slot_mapping[num_decode_tokens:])
|
||||||
|
|
||||||
kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
|
kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
|
||||||
prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
|
prefill_k_c_normed = prefill_k_nope
|
||||||
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
|
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
|
||||||
-1)
|
-1)
|
||||||
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
|
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
|
||||||
|
|||||||
Reference in New Issue
Block a user