[Refactor] Provide a framework to accommodate operators for different hardware devices (#5735)
come from: https://github.com/vllm-project/vllm-ascend/issues/5463
Reason:
During the iteration process of the hardware version, there may be a
large number of iterations for the operators, which can lead to
short-term compatibility differences. Therefore, an intermediate
adaptation layer is provided to accommodate the short-term differences
in operators.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: weijinqian0 <1184188277@qq.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -43,9 +43,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.compilation.acl_graph import (
|
||||
get_draft_graph_params, get_graph_params,
|
||||
update_draft_graph_params_workspaces, update_graph_params_workspaces)
|
||||
from vllm_ascend.device.device_op import DeviceOperator
|
||||
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
|
||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
|
||||
weak_ref_tensors)
|
||||
from vllm_ascend.utils import weak_ref_tensors
|
||||
|
||||
# default max value of sliding window size
|
||||
SWA_INT_MAX = 2147483647
|
||||
@@ -693,28 +693,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER)
|
||||
if get_ascend_device_type() == AscendDeviceType.A5:
|
||||
# TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
|
||||
# Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
|
||||
# If it's necessary, the slots should be sliced.
|
||||
torch_npu.npu_scatter_pa_kv_cache(
|
||||
key=key[:attn_metadata.num_actual_tokens]
|
||||
if not encoder_decoder else key,
|
||||
value=value[:attn_metadata.num_actual_tokens].contiguous()
|
||||
if not encoder_decoder else value,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_mapping=slots)
|
||||
else:
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=key[:attn_metadata.num_actual_tokens]
|
||||
if not encoder_decoder else key,
|
||||
value=value[:attn_metadata.num_actual_tokens]
|
||||
if not encoder_decoder else value,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots[:attn_metadata.num_actual_tokens]
|
||||
if not encoder_decoder else slots)
|
||||
DeviceOperator.reshape_and_cache(
|
||||
key=key[:attn_metadata.num_actual_tokens]
|
||||
if not encoder_decoder else key,
|
||||
value=value[:attn_metadata.num_actual_tokens]
|
||||
if not encoder_decoder else value,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_mapping=slots[:attn_metadata.num_actual_tokens]
|
||||
if not encoder_decoder else slots)
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
return key, value
|
||||
|
||||
Reference in New Issue
Block a user