diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 3e064ec..537c700 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.ops.cache import concat_and_cache_mla
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
@@ -214,7 +213,7 @@ class AscendMLAMetadataBuilder:
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            device, non_blocking=True).long()
+            device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
@@ -537,14 +536,14 @@ class AscendMLAImpl(MLAAttentionImpl):
                 prefill_q_pe.contiguous(),
                 prefill_k_pe)
         if kv_cache.numel() > 0:
-            concat_and_cache_mla(k_c_normed, k_pe, kv_cache,
-                                 attn_metadata.slot_mapping.flatten())
-            # TODO: replaced back to ascend ops
-            # key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2)
-            # torch_npu._npu_reshape_and_cache_siso(
-            #     key=key,
-            #     key_cache=kv_cache,
-            #     slot_indices=attn_metadata.slot_mapping.flatten())
+            key = torch.cat([
+                k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe
+            ],
+                            dim=2)
+            torch_npu._npu_reshape_and_cache_siso(
+                key=key,
+                key_cache=kv_cache,
+                slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
             output[num_decode_tokens:] = self._forward_prefill(
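
For reference, a minimal pure-PyTorch sketch of what the restored `torch_npu._npu_reshape_and_cache_siso` call appears to do on the MLA cache-write path: concatenate the normalized compressed KV (`k_c_normed`) with the rotary key slice (`k_pe`) into one per-token key, then scatter each key into the cache at its mapped slot. All shapes below are hypothetical example values rather than real model-config dims, and the scatter semantics are an assumption about the fused op, expressed here as a plain `index_copy_`. The second hunk's removal of `.long()` suggests the NPU op consumes the slot mapping in its native integer dtype.

```python
import torch

# Hypothetical example dims; real values come from the model config.
num_actual_toks = 4        # tokens written this step
num_kv_heads = 1           # MLA keeps a single latent KV head
kv_lora_rank = 512         # compressed latent width
qk_rope_head_dim = 64      # rotary positional slice
num_slots = 1024           # num_blocks * block_size, flattened

k_c_normed = torch.randn(num_actual_toks, kv_lora_rank)
k_pe = torch.randn(num_actual_toks, num_kv_heads, qk_rope_head_dim)
kv_cache = torch.zeros(num_slots, num_kv_heads,
                       kv_lora_rank + qk_rope_head_dim)
slot_mapping = torch.tensor([3, 7, 8, 42])  # one cache slot per token

# Same concatenation as the added lines in the diff:
# [toks, heads, kv_lora_rank] ++ [toks, heads, rope_dim] -> [toks, heads, 576]
key = torch.cat(
    [k_c_normed.view([num_actual_toks, num_kv_heads, -1]), k_pe], dim=2)

# Assumed semantics of the fused op: write each token's key into its slot.
kv_cache.index_copy_(0, slot_mapping, key)
```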