130 lines
5.5 KiB
Python
130 lines
5.5 KiB
Python
import gc
import os
from contextlib import nullcontext

import torch

from vllm.logger import logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes, MemorySnapshot
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment

import vllm_kunlun.platforms.envs as xenvs
from vllm_kunlun.device_allocator.xpumem import XpuMemAllocator
|
|
|
|
|
|
class KunlunWorker(Worker):
    """vLLM v1 GPU worker specialized for Kunlun XPU devices.

    Extends the stock ``Worker`` with optional vXPU memory-pool support,
    gated by ``VLLM_KUNLUN_ENABLE_VXPU``: model weights and KV-cache tensors
    are then allocated from tagged pools managed by ``XpuMemAllocator``,
    which also backs the offload/reload helpers at the bottom of the class.
    """

    def _vxpu_pool_context(self, tag: str):
        """Return a context manager routing allocations into the vXPU pool
        tagged *tag*, or a no-op context when vXPU mode is disabled.
        """
        if xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            allocator = XpuMemAllocator.get_instance()
            return allocator.use_memory_pool(tag=tag)
        return nullcontext()

    def init_device(self):
        """Bind the device, bring up the distributed env, snapshot memory.

        Order matters here: the distributed environment is initialized
        BEFORE the memory snapshot so NCCL buffers are already allocated
        when available memory is measured.

        Raises:
            RuntimeError: if the configured device type is not ``cuda``.
            ValueError: if free device memory at startup is below the
                requested ``gpu_memory_utilization`` budget.
        """
        if self.device_config.device.type != "cuda":
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)

        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking the memory
        # snapshot, so NCCL buffers are allocated before we measure
        # available memory.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank,
                                            current_platform.dist_backend)

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Now take the memory snapshot, after NCCL is initialized.
        gc.collect()
        torch.cuda.empty_cache()
        self.init_snapshot = MemorySnapshot()

        if xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            # In vXPU mode, report the allocator pool's view of memory
            # instead of the raw device counters.
            allocator = XpuMemAllocator.get_instance()
            free_mem, total_mem = allocator.get_pool_mem_info()
            self.init_snapshot.free_memory = free_mem
            # NOTE(review): total_memory is set to *free_mem*, not total_mem,
            # which makes the budget check below vacuous for any
            # gpu_memory_utilization < 1.0. This may be intentional (the real
            # vXPU budget is computed in determine_available_memory), but
            # confirm it is not a typo for total_mem.
            self.init_snapshot.total_memory = free_mem
            self.init_snapshot.cuda_memory = total_mem - free_mem

        self.requested_memory = (self.init_snapshot.total_memory *
                                 self.cache_config.gpu_memory_utilization)
        if self.init_snapshot.free_memory < self.requested_memory:

            def gib(b):
                return round(b / GiB_bytes, 2)

            raise ValueError(
                f"Free memory on device "
                f"({gib(self.init_snapshot.free_memory)}/"
                f"{gib(self.init_snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{gib(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        # Construct the model runner.
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    def load_model(self) -> None:
        """Load model weights, placed in the vXPU 'weights' pool if enabled."""
        if xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            # The pool must be untouched: vXPU supports a single model
            # instance per process.
            allocator = XpuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "vXPU mode can only be "
                "used for one instance per process.")
        with self._vxpu_pool_context("weights"):
            self.model_runner.load_model()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        with self._vxpu_pool_context("kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def determine_available_memory(self) -> int:
        """Return the number of bytes available for the KV cache.

        In vXPU mode the budget is derived from the allocator pool: the
        utilization share of the pool total, minus memory already in use,
        clamped at zero. Otherwise defers to the base Worker's estimate.
        """
        if not xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            return super().determine_available_memory()

        allocator = XpuMemAllocator.get_instance()
        free, total = allocator.get_pool_mem_info()
        used = total - free
        budget = int(total * self.cache_config.gpu_memory_utilization - used)
        budget = max(budget, 0)
        # Lazy %-formatting: avoids building the string when INFO is off.
        logger.info(
            "Available memory (vxpu mode): %.2f GiB, total memory: %.2f GiB",
            budget / GiB_bytes, total / GiB_bytes)
        return budget

    def offload_vram(self) -> None:
        """Offload vXPU pool contents tagged 'weights' out of device memory."""
        allocator = XpuMemAllocator.get_instance()
        allocator.offload_vram(offload_tags=("weights",))

    def try_reload_vram(self) -> tuple[bool, bool]:
        """Attempt to reload previously offloaded pool contents (all tags).

        Returns:
            The allocator's flag pair; exact semantics defined by
            ``XpuMemAllocator.try_reload_vram``.
        """
        allocator = XpuMemAllocator.get_instance()
        return allocator.try_reload_vram(tags=None)

    def vxpu_unlock_gpu(self) -> None:
        """Release the allocator's vXPU lock on the GPU."""
        allocator = XpuMemAllocator.get_instance()
        allocator.vxpu_unlock_gpu()