diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 0e62df4b..7a704a9b 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -722,6 +722,11 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.is_kv_producer = (
             self.vllm_config.kv_transfer_config is not None
             and self.vllm_config.kv_transfer_config.is_kv_producer
         )
+        self.is_kv_both = (
+            self.vllm_config.kv_transfer_config is not None
+            and self.vllm_config.kv_transfer_config.is_kv_producer
+            and self.vllm_config.kv_transfer_config.is_kv_consumer
+        )
         self.layer_name = kwargs.get("layer_name")
         self.fa_quant_layer = enable_fa_quant(self.vllm_config, self.layer_name)
         self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype
@@ -1498,11 +1503,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             sin = attn_metadata.prefill.sin
             prefill_slots = attn_metadata.slot_mapping[num_decode_tokens:num_actual_tokens]
             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
-            if self.is_kv_producer:
-                attn_metadata.reshape_cache_event = torch.npu.Event()
             prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
-            if self.is_kv_producer:
-                attn_metadata.reshape_cache_event.record()
             prefill_k_nope, prefill_value = (
                 self.kv_b_proj(prefill_k_c_normed)[0]
                 .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
@@ -1566,11 +1567,15 @@ class AscendMLAImpl(MLAAttentionImpl):
         if has_prefill:
             wait_for_kv_layer_from_connector(layer_name)
         # Preprocess for decode tokens
+        if self.is_kv_producer and not self.is_kv_both:
+            attn_metadata.reshape_cache_event = torch.npu.Event()
         if has_decode:
             decode_preprocess_res = self.mla_preprocess_decode(q_c, kv_no_split, kv_cache, attn_metadata)
         # Preprocess for prefill tokens
         if has_prefill:
             prefill_preprocess_res = self.mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata)
+        if self.is_kv_producer and not self.is_kv_both:
+            attn_metadata.reshape_cache_event.record()
         return decode_preprocess_res, prefill_preprocess_res
 
     def get_num_actual_tokens(self, attn_metadata: M):
@@ -1622,7 +1627,6 @@ class AscendMLAImpl(MLAAttentionImpl):
             and attn_metadata.num_decode_tokens is not None
         )
-        has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
@@ -1682,6 +1686,6 @@ class AscendMLAImpl(MLAAttentionImpl):
         del o_proj_input
 
-        if has_prefill:
+        if self.is_kv_producer and not self.is_kv_both:
             maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded