import gc
import os
from contextlib import nullcontext

import torch

from vllm.logger import logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes, MemorySnapshot
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu_worker import (Worker,
                                       init_worker_distributed_environment)

import vllm_kunlun.platforms.envs as xenvs
from vllm_kunlun.device_allocator.xpumem import XpuMemAllocator


class KunlunWorker(Worker):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def init_device(self):
        if self.device_config.device.type == "cuda":
            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)

            current_platform.check_if_supports_dtype(self.model_config.dtype)

            # Initialize the distributed environment BEFORE taking the memory
            # snapshot, so that NCCL buffers are already allocated when we
            # measure available memory.
            init_worker_distributed_environment(self.vllm_config, self.rank,
                                                self.distributed_init_method,
                                                self.local_rank,
                                                current_platform.dist_backend)

            # Set random seed.
            set_random_seed(self.model_config.seed)

            # Take the memory snapshot now that NCCL is initialized.
            gc.collect()
            torch.cuda.empty_cache()
            self.init_snapshot = MemorySnapshot()
            if xenvs.VLLM_KUNLUN_ENABLE_VXPU:
                # In vXPU mode the pool's free memory is treated as the usable
                # total, so gpu_memory_utilization is applied to the pool
                # rather than to the whole device.
                allocator = XpuMemAllocator.get_instance()
                free_mem, total_mem = allocator.get_pool_mem_info()
                self.init_snapshot.free_memory = free_mem
                self.init_snapshot.total_memory = free_mem
                self.init_snapshot.cuda_memory = total_mem - free_mem
            self.requested_memory = (self.init_snapshot.total_memory *
                                     self.cache_config.gpu_memory_utilization)
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other "
                    f"processes.")
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")

        # Construct the model runner.
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stats are enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    def load_model(self) -> None:
        if xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            allocator = XpuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "vXPU mode can only be used for one instance per process.")
            # Attribute weight allocations to the "weights" pool so they can
            # be offloaded and reloaded independently of the KV cache.
            context = allocator.use_memory_pool(tag="weights")
        else:
            context = nullcontext()
        with context:
            self.model_runner.load_model()

    def initialize_from_config(self,
                               kv_cache_config: KVCacheConfig) -> None:
        """Allocate the GPU KV cache with the specified kv_cache_config."""
        if xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            allocator = XpuMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
        else:
            context = nullcontext()
        with context:
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def determine_available_memory(self) -> int:
        if xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            allocator = XpuMemAllocator.get_instance()
            free, total = allocator.get_pool_mem_info()
            # KV cache gets the utilization budget minus what the pool has
            # already allocated, clamped at zero.
            available_kv_cache_memory = int(
                total * self.cache_config.gpu_memory_utilization -
                (total - free))
            available_kv_cache_memory = max(available_kv_cache_memory, 0)
            logger.info(
                "Available memory (vXPU mode): %.2f GiB, "
                "total memory: %.2f GiB",
                available_kv_cache_memory / GiB_bytes, total / GiB_bytes)
            return available_kv_cache_memory
        return super().determine_available_memory()

    def offload_vram(self) -> None:
        allocator = XpuMemAllocator.get_instance()
        allocator.offload_vram(offload_tags=("weights",))

    def try_reload_vram(self) -> tuple[bool, bool]:
        allocator = XpuMemAllocator.get_instance()
        return allocator.try_reload_vram(tags=None)

    def vxpu_unlock_gpu(self) -> None:
        allocator = XpuMemAllocator.get_instance()
        allocator.vxpu_unlock_gpu()
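

# Usage sketch, illustrative only: how the tagged pools above fit together in
# vXPU mode. It mirrors the pattern that load_model(), offload_vram(), and
# try_reload_vram() implement, and assumes a Kunlun device with
# VLLM_KUNLUN_ENABLE_VXPU=1; the semantics of routing allocations made inside
# the context into the named pool are an assumption based on the methods
# above, not a documented guarantee. Not part of the worker API; runs only
# when this module is executed directly.
if __name__ == "__main__":
    allocator = XpuMemAllocator.get_instance()
    free, total = allocator.get_pool_mem_info()
    print(f"pool: {free / GiB_bytes:.2f}/{total / GiB_bytes:.2f} GiB free")
    # Allocations made inside the context are (assumed to be) attributed to
    # the "weights" pool, so they can later be evicted and restored as a unit.
    with allocator.use_memory_pool(tag="weights"):
        pass  # e.g. model weight loading happens here
    allocator.offload_vram(offload_tags=("weights",))  # evict weights only
    print("reload result:", allocator.try_reload_vram(tags=None))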