vllm-ascend vnpu v1

This commit is contained in:
starkwj
2025-12-26 07:37:35 +00:00
parent 2f1aed98cc
commit 135cc0a505
168 changed files with 28337 additions and 9 deletions

View File

@@ -20,10 +20,12 @@ import dataclasses
import os
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Tuple, Union
import time
import torch
from acl.rt import memcpy # type: ignore # noqa: F401
from acl.rt import memcpy, memset # type: ignore # noqa: F401
from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
from vllm_ascend.platform import NPUPlatform
@@ -56,8 +58,20 @@ def find_loaded_library(lib_name) -> Optional[str]:
camem_available = False
try:
from vllm_ascend.vllm_ascend_C import ( # type: ignore # noqa: F401
init_module, python_create_and_map, python_unmap_and_release)
if envs_ascend.VLLM_ASCEND_ENABLE_IDLE_OFFLOAD:
from vllm_ascend.vllm_ascend_C import ( # type: ignore # noqa: F401
init_module_offload as init_module,
python_create_and_map_offload as python_create_and_map,python_unmap_and_release_offload as python_unmap_and_release,
python_get_mem_info_offload as python_get_mem_info,
python_lock_gpu_offload as python_lock_gpu,
python_unlock_gpu_offload as python_unlock_gpu
)
else:
from vllm_ascend.vllm_ascend_C import ( # type: ignore # noqa: F401
init_module, python_create_and_map, python_unmap_and_release)
python_get_mem_info = None
python_lock_gpu = None
python_unlock_gpu = None
lib_name = find_loaded_library("vllm_ascend_C")
camem_available = True
except ImportError as e:
@@ -66,6 +80,9 @@ except ImportError as e:
init_module = None
python_create_and_map = None
python_unmap_and_release = None
python_get_mem_info = None
python_lock_gpu = None
python_unlock_gpu = None
lib_name = None
libcudart = None
@@ -93,8 +110,14 @@ def get_pluggable_allocator(
python_free_func: Callable[[int], tuple[int, int, int, int]]
) -> torch.npu.memory.NPUPluggableAllocator:
init_module(python_malloc_fn, python_free_func)
new_alloc = torch.npu.memory.NPUPluggableAllocator(lib_name, 'my_malloc',
'my_free')
if envs_ascend.VLLM_ASCEND_ENABLE_IDLE_OFFLOAD:
new_alloc = torch.npu.memory.NPUPluggableAllocator(
lib_name, 'my_malloc_offload', 'my_free_offload'
)
else:
new_alloc = torch.npu.memory.NPUPluggableAllocator(
lib_name, 'my_malloc', 'my_free'
)
return new_alloc
@@ -153,6 +176,7 @@ class CaMemAllocator:
self.pointer_to_data: Dict[int, AllocationData] = {}
self.current_tag: str = CaMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {}
# self.requested_vram_size = 0
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
@@ -254,6 +278,9 @@ class CaMemAllocator:
# to avoid the issue, we keep a reference of the data.
# see https://github.com/pytorch/pytorch/issues/146431 .
self.allocator_and_pools[tag] = data
# lock gpu
if envs_ascend.VLLM_ASCEND_ENABLE_IDLE_OFFLOAD:
self.vnpu_lock_gpu()
yield
# PyTorch's bug, calling torch.cuda.empty_cache() will error
# when using pluggable allocator, see
@@ -265,6 +292,8 @@ class CaMemAllocator:
# allocate memory.
# TODO: we need to find a way to release the memory,
# i.e. calling torch.cuda.empty_cache()
if envs_ascend.VLLM_ASCEND_ENABLE_IDLE_OFFLOAD:
self.vnpu_unlock_gpu()
self.current_tag = old_tag
def get_current_usage(self) -> int:
@@ -276,3 +305,100 @@ class CaMemAllocator:
handle = data.handle
sum_bytes += handle[1]
return sum_bytes
def vnpu_lock_gpu(self) -> bool:
    """Try to acquire the NPU lock via the idle-offload extension.

    Returns the extension's result, or ``False`` when the offload
    build is unavailable (``python_lock_gpu`` is ``None``).
    """
    return python_lock_gpu() if python_lock_gpu else False
def vnpu_unlock_gpu(self):
    """Release the NPU lock through the idle-offload extension.

    No-op when the offload build is unavailable
    (``python_unlock_gpu`` is ``None``).
    """
    if python_unlock_gpu is not None:
        python_unlock_gpu()
def get_pool_mem_info(self) -> int:
    """Return the available memory (bytes) in the reserved pool.

    :raises RuntimeError: if the idle-offload extension is not loaded
        (``python_get_mem_info`` is ``None``, e.g. when
        ``VLLM_ASCEND_ENABLE_IDLE_OFFLOAD`` is disabled). Raising an
        explicit error here replaces the opaque
        ``TypeError: 'NoneType' object is not callable`` the bare call
        would produce; the sibling ``vnpu_lock_gpu``/``vnpu_unlock_gpu``
        guard the same way.
    """
    if python_get_mem_info is None:
        raise RuntimeError(
            "python_get_mem_info is unavailable; enable "
            "VLLM_ASCEND_ENABLE_IDLE_OFFLOAD and rebuild vllm_ascend_C")
    return python_get_mem_info()
def offload_vram(
        self,
        offload_tags: Optional[Union[Tuple[str, ...], str]] = None) -> None:
    """
    Put the allocator in sleep mode.

    Allocations whose tag is in ``offload_tags`` are copied to pinned
    CPU memory before their device memory is released; every other
    allocation is released without a backup (its contents are
    discarded).

    :param offload_tags: The tags of the memory allocation that will be
        offloaded. The rest of the memory allocation will be discarded.
    """
    if offload_tags is None:
        # by default, allocated tensors are offloaded
        # when the allocator sleeps
        offload_tags = (CaMemAllocator.default_tag, )
    elif isinstance(offload_tags, str):
        offload_tags = (offload_tags, )
    assert isinstance(offload_tags, tuple)

    ACL_MEMCPY_DEVICE_TO_HOST = 2
    sz_weights = 0
    sz_kvcache = 0
    for ptr, data in self.pointer_to_data.items():
        handle = data.handle
        size_in_bytes = handle[1]
        if data.tag in offload_tags:
            if data.cpu_backup_tensor is None:
                cpu_backup_tensor = torch.empty(
                    size_in_bytes,
                    dtype=torch.uint8,
                    device='cpu',
                    pin_memory=NPUPlatform.is_pin_memory_available())
                cpu_ptr = cpu_backup_tensor.data_ptr()
                # acl.rt.memcpy(dst, destMax, src, count, kind):
                # destMax is the destination buffer size in BYTES, not
                # an address. The previous ``cpu_ptr + size_in_bytes*2``
                # only passed the destMax >= count check because any
                # pointer value dwarfs the copy size.
                memcpy(cpu_ptr, size_in_bytes, ptr, size_in_bytes,
                       ACL_MEMCPY_DEVICE_TO_HOST)
                data.cpu_backup_tensor = cpu_backup_tensor
            unmap_and_release(handle)
            sz_weights += size_in_bytes
        else:
            # No backup requested for this tag: just drop the mapping.
            unmap_and_release(handle)
            sz_kvcache += size_in_bytes
    # Give the device back while we are asleep.
    self.vnpu_unlock_gpu()
    logger.info("offload tags %s: %.2f GiB backed up, %.2f GiB discarded",
                offload_tags, sz_weights / (1024**3), sz_kvcache / (1024**3))
def reload_vram(self, tags: Optional[list[str]] = None) -> bool:
    """
    Wake the allocator from sleep mode.

    Re-maps device memory for allocations whose tag is in ``tags``
    (all allocations when ``tags`` is ``None``) and copies any CPU
    backup back to the device. The CPU backup is intentionally kept
    alive so a later offload_vram() can skip the device-to-host copy.

    :param tags: tags to reload; ``None`` reloads everything.
    :return: ``True`` if this allocator already held the NPU lock
        (memory was never released, nothing to do); ``False`` after a
        real reload.
    """
    prev_is_self = self.vnpu_lock_gpu()
    if prev_is_self:
        # We never released the device, so all memory is still mapped.
        return True
    ACL_MEMCPY_HOST_TO_DEVICE = 1
    for ptr, data in self.pointer_to_data.items():
        handle = data.handle
        if tags is not None and data.tag not in tags:
            # NOTE(review): skipped allocations stay unmapped and are
            # not zeroed here — presumably handled by a later
            # reset_prefix_cache; confirm before relying on it.
            continue
        create_and_map(handle)
        backup = data.cpu_backup_tensor
        if backup is not None:
            size_in_bytes = backup.numel() * backup.element_size()
            # acl.rt.memcpy(dst, destMax, src, count, kind): destMax is
            # the destination buffer size in BYTES, not an address; the
            # previous ``ptr + size_in_bytes*2`` only passed the
            # destMax >= count check because any pointer value dwarfs
            # the copy size.
            memcpy(ptr, size_in_bytes, backup.data_ptr(), size_in_bytes,
                   ACL_MEMCPY_HOST_TO_DEVICE)
    return False