### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/attention_mask.py` |
| `vllm_ascend/attention/attention_v1.py` |
| `vllm_ascend/attention/context_parallel/attention_cp.py` |
| `vllm_ascend/attention/context_parallel/common_cp.py` |
| `vllm_ascend/attention/context_parallel/mla_cp.py` |
| `vllm_ascend/attention/utils.py` |
| `vllm_ascend/batch_invariant.py` |
| `vllm_ascend/device/device_op.py` |
| `vllm_ascend/device_allocator/camem.py` |
| `vllm_ascend/envs.py` |
- vLLM version: v0.13.0
- vLLM main: 2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
```diff
@@ -15,35 +15,26 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from typing import Optional, Type
 
 import torch_npu
 
 from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
 
 
-class BaseDeviceAdaptor(object):
-
+class BaseDeviceAdaptor:
     @classmethod
-    def reshape_and_cache(cls, key, value, key_cache, value_cache,
-                          slot_mapping):
-        torch_npu._npu_reshape_and_cache(key=key,
-                                         value=value,
-                                         key_cache=key_cache,
-                                         value_cache=value_cache,
-                                         slot_indices=slot_mapping)
+    def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
+        torch_npu._npu_reshape_and_cache(
+            key=key, value=value, key_cache=key_cache, value_cache=value_cache, slot_indices=slot_mapping
+        )
 
 
 class A5DeviceAdaptor(BaseDeviceAdaptor):
-
     @classmethod
-    def reshape_and_cache(cls, key, value, key_cache, value_cache,
-                          slot_mapping):
-        torch_npu.npu_scatter_pa_kv_cache(key=key,
-                                          value=value.contiguous(),
-                                          key_cache=key_cache,
-                                          value_cache=value_cache,
-                                          slot_mapping=slot_mapping)
+    def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
+        torch_npu.npu_scatter_pa_kv_cache(
+            key=key, value=value.contiguous(), key_cache=key_cache, value_cache=value_cache, slot_mapping=slot_mapping
+        )
 
 
 def get_device_adaptor():
@@ -53,4 +44,4 @@ def get_device_adaptor():
     return BaseDeviceAdaptor
 
 
-DeviceOperator: Optional[Type['BaseDeviceAdaptor']] = get_device_adaptor()
+DeviceOperator: type["BaseDeviceAdaptor"] | None = get_device_adaptor()
```
|
||||
|
||||
Reference in New Issue
Block a user