[BugFix][P/D][0.18.0] Bugfix: short sequence has no response (#8142)
### What this PR does / why we need it? Bugfix: a short sequence produced no response. This pull request refactors the event handling for KV cache reshaping in mla_v1.py by centralizing the creation and recording of `reshape_cache_event` within the `_mla_preprocess` function, ensuring it covers both decode and prefill operations. Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
@@ -722,6 +722,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.is_kv_producer = (
|
||||
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||
)
|
||||
self.is_kv_both = (
|
||||
self.vllm_config.kv_transfer_config is not None
|
||||
and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||
and self.vllm_config.kv_transfer_config.is_kv_consumer
|
||||
)
|
||||
self.layer_name = kwargs.get("layer_name")
|
||||
self.fa_quant_layer = enable_fa_quant(self.vllm_config, self.layer_name)
|
||||
self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype
|
||||
@@ -1498,11 +1503,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
sin = attn_metadata.prefill.sin
|
||||
prefill_slots = attn_metadata.slot_mapping[num_decode_tokens:num_actual_tokens]
|
||||
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event = torch.npu.Event()
|
||||
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
prefill_k_nope, prefill_value = (
|
||||
self.kv_b_proj(prefill_k_c_normed)[0]
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
@@ -1566,11 +1567,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
if has_prefill:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
# Preprocess for decode tokens
|
||||
if self.is_kv_producer and not self.is_kv_both:
|
||||
attn_metadata.reshape_cache_event = torch.npu.Event()
|
||||
if has_decode:
|
||||
decode_preprocess_res = self.mla_preprocess_decode(q_c, kv_no_split, kv_cache, attn_metadata)
|
||||
# Preprocess for prefill tokens
|
||||
if has_prefill:
|
||||
prefill_preprocess_res = self.mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata)
|
||||
if self.is_kv_producer and not self.is_kv_both:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
return decode_preprocess_res, prefill_preprocess_res
|
||||
|
||||
def get_num_actual_tokens(self, attn_metadata: M):
|
||||
@@ -1622,7 +1627,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
and attn_metadata.num_decode_tokens is not None
|
||||
)
|
||||
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
@@ -1682,6 +1686,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
del o_proj_input
|
||||
|
||||
if has_prefill:
|
||||
if self.is_kv_producer and not self.is_kv_both:
|
||||
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
|
||||
return output_padded
|
||||
|
||||
Reference in New Issue
Block a user