From 1ccb9acd9ad48e5a857f7f06887ccf66911d0b20 Mon Sep 17 00:00:00 2001 From: weijinqian0 <1184188277@qq.com> Date: Tue, 13 Jan 2026 09:53:26 +0800 Subject: [PATCH] [Refactor] Provide a framework to accommodate operators for different hardware devices (#5735) come from: https://github.com/vllm-project/vllm-ascend/issues/5463 Reason: During the iteration process of the hardware version, there may be a large number of iterations for the operators, which can lead to short-term compatibility differences. Therefore, an intermediate adaptation layer is provided to accommodate the short-term differences in operators. - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d --------- Signed-off-by: weijinqian_v1 Signed-off-by: weijinqian0 <1184188277@qq.com> Co-authored-by: weijinqian_v1 --- vllm_ascend/attention/attention_v1.py | 35 ++++++----------- vllm_ascend/device/__init__.py | 0 vllm_ascend/device/device_op.py | 56 +++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 24 deletions(-) create mode 100644 vllm_ascend/device/__init__.py create mode 100644 vllm_ascend/device/device_op.py diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 8dc013f3..cbc880dd 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -43,9 +43,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.compilation.acl_graph import ( get_draft_graph_params, get_graph_params, update_draft_graph_params_workspaces, update_graph_params_workspaces) +from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager -from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, - weak_ref_tensors) +from vllm_ascend.utils import weak_ref_tensors # default max value of sliding window size SWA_INT_MAX = 2147483647 @@ -693,28 +693,15 @@ class AscendAttentionBackendImpl(AttentionImpl): self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER) - if get_ascend_device_type() == AscendDeviceType.A5: - # TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping. - # Should check if the 0 dim of slot_mapping must equal to the 0 dim of key. - # If it's necessary, the slots should be sliced. - torch_npu.npu_scatter_pa_kv_cache( - key=key[:attn_metadata.num_actual_tokens] - if not encoder_decoder else key, - value=value[:attn_metadata.num_actual_tokens].contiguous() - if not encoder_decoder else value, - key_cache=self.key_cache, - value_cache=self.value_cache, - slot_mapping=slots) - else: - torch_npu._npu_reshape_and_cache( - key=key[:attn_metadata.num_actual_tokens] - if not encoder_decoder else key, - value=value[:attn_metadata.num_actual_tokens] - if not encoder_decoder else value, - key_cache=self.key_cache, - value_cache=self.value_cache, - slot_indices=slots[:attn_metadata.num_actual_tokens] - if not encoder_decoder else slots) + DeviceOperator.reshape_and_cache( + key=key[:attn_metadata.num_actual_tokens] + if not encoder_decoder else key, + value=value[:attn_metadata.num_actual_tokens] + if not encoder_decoder else value, + key_cache=self.key_cache, + value_cache=self.value_cache, + slot_mapping=slots[:attn_metadata.num_actual_tokens] + if not encoder_decoder else slots) if self.is_kv_producer: attn_metadata.reshape_cache_event.record() return key, value diff --git a/vllm_ascend/device/__init__.py b/vllm_ascend/device/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/device/device_op.py b/vllm_ascend/device/device_op.py new file mode 100644 index 00000000..ccd874dd --- /dev/null +++ b/vllm_ascend/device/device_op.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +from typing import Optional, Type + +import torch_npu + +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type + + +class BaseDeviceAdaptor(object): + + @classmethod + def reshape_and_cache(cls, key, value, key_cache, value_cache, + slot_mapping): + torch_npu._npu_reshape_and_cache(key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_indices=slot_mapping) + + +class A5DeviceAdaptor(BaseDeviceAdaptor): + + @classmethod + def reshape_and_cache(cls, key, value, key_cache, value_cache, + slot_mapping): + torch_npu.npu_scatter_pa_kv_cache(key=key, + value=value.contiguous(), + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=slot_mapping) + + +def get_device_adaptor(): + ascend_device_type = get_ascend_device_type() + if ascend_device_type == AscendDeviceType.A5: + return A5DeviceAdaptor + return BaseDeviceAdaptor + + +DeviceOperator: Optional[Type['BaseDeviceAdaptor']] = get_device_adaptor()