[BugFix][P/D][0.18.0] bugfix: short sequence has no response (#8142)

### What this PR does / why we need it?
Fixes a bug where short sequences received no response. This pull request refactors the event handling for KV-cache reshaping in mla_v1.py by centralizing the creation and recording of reshape_cache_event inside the _mla_preprocess function, so that the event covers both the decode and the prefill preprocessing paths; previously it was created and recorded only in the prefill path, as the second hunk below shows.
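
At a glance, the fix moves the Event create/record pair out of the prefill-only path and wraps it around the whole preprocess step. Below is a minimal sketch of the resulting pattern, not the exact vllm-ascend code: the function name, arguments, and elided bodies are illustrative, and it assumes torch_npu is installed so that torch.npu.Event is available.

```python
import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend

def _mla_preprocess_sketch(self, attn_metadata, has_decode, has_prefill):
    # Pure KV producers (producer but not also consumer) publish an event
    # that downstream KV-transfer code can wait on.
    if self.is_kv_producer and not self.is_kv_both:
        attn_metadata.reshape_cache_event = torch.npu.Event()
    if has_decode:
        ...  # decode-side preprocessing writes into the KV cache
    if has_prefill:
        ...  # prefill-side preprocessing writes into the KV cache
    # Recording after *both* branches is the fix: batches without a prefill
    # segment now get a created-and-recorded event too, instead of downstream
    # code referencing an event that was never set up.
    if self.is_kv_producer and not self.is_kv_both:
        attn_metadata.reshape_cache_event.record()
```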

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Author: wangxiaoteng888
Date: 2026-04-12 23:25:01 +08:00 (committed by GitHub)
parent 31186a3a9d
commit 4adc6a68f5


@@ -722,6 +722,11 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.is_kv_producer = (
             self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
         )
+        self.is_kv_both = (
+            self.vllm_config.kv_transfer_config is not None
+            and self.vllm_config.kv_transfer_config.is_kv_producer
+            and self.vllm_config.kv_transfer_config.is_kv_consumer
+        )
         self.layer_name = kwargs.get("layer_name")
         self.fa_quant_layer = enable_fa_quant(self.vllm_config, self.layer_name)
         self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype
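
The hunk above adds is_kv_both: true when kv_transfer_config marks this instance as both KV producer and KV consumer. The guards added later combine the two flags so that only a pure producer does the event bookkeeping; a hypothetical helper (name invented for illustration) makes the predicate explicit:

```python
def needs_reshape_cache_event(is_kv_producer: bool, is_kv_both: bool) -> bool:
    """True only for a *pure* KV producer.

    An instance that is both producer and consumer serves prefill and
    decode itself, so it skips the cross-instance event and (per the last
    hunk below) also skips pushing KV layers to the connector.
    """
    return is_kv_producer and not is_kv_both
```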
@@ -1498,11 +1503,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         sin = attn_metadata.prefill.sin
         prefill_slots = attn_metadata.slot_mapping[num_decode_tokens:num_actual_tokens]
         prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
-        if self.is_kv_producer:
-            attn_metadata.reshape_cache_event = torch.npu.Event()
         prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
-        if self.is_kv_producer:
-            attn_metadata.reshape_cache_event.record()
         prefill_k_nope, prefill_value = (
             self.kv_b_proj(prefill_k_c_normed)[0]
             .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
@@ -1566,11 +1567,15 @@ class AscendMLAImpl(MLAAttentionImpl):
         if has_prefill:
             wait_for_kv_layer_from_connector(layer_name)
         # Preprocess for decode tokens
+        if self.is_kv_producer and not self.is_kv_both:
+            attn_metadata.reshape_cache_event = torch.npu.Event()
         if has_decode:
             decode_preprocess_res = self.mla_preprocess_decode(q_c, kv_no_split, kv_cache, attn_metadata)
         # Preprocess for prefill tokens
         if has_prefill:
             prefill_preprocess_res = self.mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata)
+        if self.is_kv_producer and not self.is_kv_both:
+            attn_metadata.reshape_cache_event.record()
         return decode_preprocess_res, prefill_preprocess_res

     def get_num_actual_tokens(self, attn_metadata: M):
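
For context, the device-event pattern the hunk above relies on, as a generic sketch: record() marks completion of work already enqueued on the stream, and synchronize() blocks the host until that work finishes. This assumes torch.npu mirrors the torch.cuda Event API, which the record() call in the diff itself suggests.

```python
import torch
import torch_npu  # noqa: F401

event = torch.npu.Event()
# Producer side: enqueue the reshape-and-cache kernels on the current
# stream, then record the event behind them.
# ... kernels enqueued here ...
event.record()

# Consumer side (e.g. before shipping a KV layer to the connector):
event.synchronize()  # host blocks until everything recorded above is done
```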
@@ -1622,7 +1627,6 @@ class AscendMLAImpl(MLAAttentionImpl):
             and attn_metadata.num_decode_tokens is not None
         )
         has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens

         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
@@ -1682,6 +1686,6 @@ class AscendMLAImpl(MLAAttentionImpl):
         del o_proj_input

         if has_prefill:
-            if self.is_kv_producer:
+            if self.is_kv_producer and not self.is_kv_both:
                 maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded