diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index aa93d460..ef4220a5 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -40,6 +40,7 @@ from vllm_ascend.compilation.acl_graph import (
     update_draft_graph_params_workspaces,
     update_graph_params_workspaces,
 )
+from vllm_ascend.device.device_op import DeviceOperator
 from vllm_ascend.ops.layer_shard_linear import (
     is_hidden_layer,
     post_process_after_loading_for_shard_weight_series,
@@ -1075,12 +1076,12 @@ class AscendMLAImpl(MLAAttentionImpl):
         kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
         k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
 
-        torch_npu.atb.npu_paged_cache_load(
+        DeviceOperator.mla_cache_load(
             cache_kv_c,
             cache_k_pe,
             prefill_metadata.block_table,
             context_seq_len_npu,
             seq_starts=prefill_metadata.chunked_context.starts[i],
             key=kv_c_normed,
             value=k_pe,
         )
diff --git a/vllm_ascend/device/device_op.py b/vllm_ascend/device/device_op.py
index 5d95544c..3a2b4fa1 100644
--- a/vllm_ascend/device/device_op.py
+++ b/vllm_ascend/device/device_op.py
@@ -171,6 +171,18 @@ class BaseDeviceAdaptor:
             output_dtype=fallback_output_dtype,
         )[0]
 
+    @staticmethod
+    def mla_cache_load(cache_kv_c, cache_k_pe, block_table, context_seq_len_npu, seq_starts, key, value):
+        torch_npu.atb.npu_paged_cache_load(
+            cache_kv_c,
+            cache_k_pe,
+            block_table,
+            context_seq_len_npu,
+            seq_starts=seq_starts,
+            key=key,
+            value=value,
+        )
+
 
 class A5DeviceAdaptor(BaseDeviceAdaptor):
     @classmethod
@@ -375,6 +387,18 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
             **gmm2_kwargs,
         )[0]
 
+    @staticmethod
+    def mla_cache_load(cache_kv_c, cache_k_pe, block_table, context_seq_len_npu, seq_starts, key, value):
+        torch_npu.npu_gather_pa_kv_cache(
+            cache_kv_c,
+            cache_k_pe,
+            block_table,
+            context_seq_len_npu,
+            seq_offset=seq_starts,
+            key=key,
+            value=value,
+        )
+
 
 def get_device_adaptor() -> type["BaseDeviceAdaptor"]:
     ascend_device_type = get_ascend_device_type()