[BugFix][P/D][0.18.0] Bugfix: short sequence has no response (#8142)
### What this PR does / why we need it?

Fixes a bug where short sequences got no response. This pull request refactors the event handling for KV cache reshaping in `mla_v1.py` by centralizing the `reshape_cache_event` creation and recording within the `_mla_preprocess` function, ensuring it covers both decode and prefill operations.

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
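The diff below centralizes a device-event handshake around the KV cache reshape. As a rough orientation, here is a minimal sketch of that pattern, assuming `torch.npu.Event` behaves like the familiar `torch.cuda.Event` (record/synchronize); the consumer-side wait shown at the end is an assumption about how the KV-transfer path uses the event, not something taken from this diff.

```python
import torch

# Minimal sketch of the synchronization pattern this PR centralizes; it is not
# the vllm-ascend code itself. torch.cuda.Event stands in for torch.npu.Event.

class AttnMetadataSketch:
    reshape_cache_event = None  # set only on pure KV-producer nodes


def mla_preprocess_sketch(attn_metadata, is_kv_producer, is_kv_both,
                          has_decode, has_prefill):
    if is_kv_producer and not is_kv_both:
        # Create the event once, before any decode/prefill cache writes.
        attn_metadata.reshape_cache_event = torch.cuda.Event()
    if has_decode:
        pass  # decode preprocessing writes into the KV cache here
    if has_prefill:
        pass  # prefill preprocessing writes into the KV cache here
    if is_kv_producer and not is_kv_both:
        # Record after both paths, so decode-only batches are covered too.
        attn_metadata.reshape_cache_event.record()


def save_kv_sketch(attn_metadata):
    if attn_metadata.reshape_cache_event is not None:
        # Don't hand the KV cache to the connector until the cache writes
        # recorded above have finished on the device.
        attn_metadata.reshape_cache_event.synchronize()
```

The point of creating the event before both preprocessing branches and recording it after them is that a batch containing only decode tokens still ends up with a recorded event, which is exactly what the previous prefill-only placement missed.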
@@ -722,6 +722,11 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.is_kv_producer = (
             self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
         )
+        self.is_kv_both = (
+            self.vllm_config.kv_transfer_config is not None
+            and self.vllm_config.kv_transfer_config.is_kv_producer
+            and self.vllm_config.kv_transfer_config.is_kv_consumer
+        )
         self.layer_name = kwargs.get("layer_name")
         self.fa_quant_layer = enable_fa_quant(self.vllm_config, self.layer_name)
         self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype
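The new `self.is_kv_both` flag distinguishes a worker that is both KV producer and consumer from a pure producer; the event and connector-save logic later in this diff is enabled only for pure producers. A small illustration of how the three roles fall out of the two booleans (the `is_kv_producer` / `is_kv_consumer` attribute names are taken from the diff; the wrapper class is just for the example):

```python
from dataclasses import dataclass

# Illustrative stand-in for vLLM's KV-transfer config object.
@dataclass
class KVTransferConfigSketch:
    is_kv_producer: bool
    is_kv_consumer: bool


def kv_role(cfg):
    is_producer = cfg is not None and cfg.is_kv_producer
    is_both = is_producer and cfg.is_kv_consumer
    if is_both:
        return "both"        # handles prefill and decode itself
    if is_producer:
        return "producer"    # prefill node in a P/D deployment
    if cfg is not None and cfg.is_kv_consumer:
        return "consumer"    # decode node in a P/D deployment
    return "none"            # KV transfer not configured


assert kv_role(KVTransferConfigSketch(True, True)) == "both"
assert kv_role(KVTransferConfigSketch(True, False)) == "producer"
assert kv_role(None) == "none"
```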
@@ -1498,11 +1503,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         sin = attn_metadata.prefill.sin
         prefill_slots = attn_metadata.slot_mapping[num_decode_tokens:num_actual_tokens]
         prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
-        if self.is_kv_producer:
-            attn_metadata.reshape_cache_event = torch.npu.Event()
         prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
-        if self.is_kv_producer:
-            attn_metadata.reshape_cache_event.record()
         prefill_k_nope, prefill_value = (
             self.kv_b_proj(prefill_k_c_normed)[0]
             .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
@@ -1566,11 +1567,15 @@ class AscendMLAImpl(MLAAttentionImpl):
         if has_prefill:
             wait_for_kv_layer_from_connector(layer_name)
         # Preprocess for decode tokens
+        if self.is_kv_producer and not self.is_kv_both:
+            attn_metadata.reshape_cache_event = torch.npu.Event()
         if has_decode:
             decode_preprocess_res = self.mla_preprocess_decode(q_c, kv_no_split, kv_cache, attn_metadata)
         # Preprocess for prefill tokens
         if has_prefill:
             prefill_preprocess_res = self.mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata)
+        if self.is_kv_producer and not self.is_kv_both:
+            attn_metadata.reshape_cache_event.record()
         return decode_preprocess_res, prefill_preprocess_res

     def get_num_actual_tokens(self, attn_metadata: M):
@@ -1622,7 +1627,6 @@ class AscendMLAImpl(MLAAttentionImpl):
             and attn_metadata.num_decode_tokens is not None
         )

-        has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
@@ -1682,6 +1686,6 @@ class AscendMLAImpl(MLAAttentionImpl):

         del o_proj_input

-        if has_prefill:
+        if self.is_kv_producer and not self.is_kv_both:
             maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded
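This last hunk carries the behavioural fix: saving the KV layer to the connector is no longer gated on `has_prefill` but on the node being a pure KV producer. The toy comparison below takes only the two gate expressions from the diff; everything else is illustrative. It shows the case that changes, a producer-node batch containing no prefill tokens, which previously skipped the save, consistent with the "short sequence has no response" symptom this PR addresses.

```python
# Toy comparison of the old and new gates around
# maybe_save_kv_layer_to_connector; illustrative, not vllm-ascend code.
def old_gate(has_prefill, is_kv_producer, is_kv_both):
    return has_prefill


def new_gate(has_prefill, is_kv_producer, is_kv_both):
    return is_kv_producer and not is_kv_both


# A producer-node batch with no prefill tokens: the old gate skips the save,
# the new gate performs it, so the consumer side can still receive the KV.
case = dict(has_prefill=False, is_kv_producer=True, is_kv_both=False)
assert old_gate(**case) is False
assert new_gate(**case) is True
```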