Files
xc-llm-kunlun/vllm_kunlun/v1/worker/worker_v1.py
2026-02-12 10:46:37 +08:00

130 lines
5.5 KiB
Python

import os
import gc
import torch
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.logger import logger
from vllm.platforms import current_platform
from vllm.model_executor import set_random_seed
from vllm.utils import GiB_bytes, MemorySnapshot
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.kv_cache_interface import KVCacheConfig
import vllm_kunlun.platforms.envs as xenvs
from vllm_kunlun.device_allocator.xpumem import XpuMemAllocator
class KunlunWorker(Worker):
    """vLLM v1 worker specialized for Kunlun XPU devices.

    Extends the stock GPU ``Worker`` with optional vXPU pooled-memory
    management (enabled via ``xenvs.VLLM_KUNLUN_ENABLE_VXPU``): when
    active, model-weight and KV-cache allocations are routed through the
    ``XpuMemAllocator`` memory pool, and memory accounting is taken from
    the pool rather than from the raw device.
    """

    def __init__(self, *args, **kwargs):
        # No Kunlun-specific construction state; defer entirely to Worker.
        super().__init__(*args, **kwargs)

    @staticmethod
    def _pool_context(tag: str):
        """Return a context manager routing allocations into the vXPU
        memory pool under *tag*, or a no-op context when vXPU is disabled.

        Extracted because the identical selection logic was duplicated in
        ``load_model`` and ``initialize_from_config``.
        """
        if xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            return XpuMemAllocator.get_instance().use_memory_pool(tag=tag)
        from contextlib import nullcontext
        return nullcontext()

    def init_device(self):
        """Bind this worker to its device, initialize the distributed
        environment, and take a baseline memory snapshot.

        Raises:
            RuntimeError: if the configured device type is not "cuda".
            ValueError: if free device memory at startup is below the
                budget implied by ``gpu_memory_utilization``.
        """
        # Guard clause: anything but a cuda-style device is unsupported.
        if self.device_config.device.type != "cuda":
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking the memory
        # snapshot: this ensures NCCL buffers are already allocated when
        # we measure available memory.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank,
                                            current_platform.dist_backend)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Now take the memory snapshot, after NCCL is initialized.
        gc.collect()
        torch.cuda.empty_cache()
        self.init_snapshot = MemorySnapshot()
        if xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            allocator = XpuMemAllocator.get_instance()
            free_mem, total_mem = allocator.get_pool_mem_info()
            self.init_snapshot.free_memory = free_mem
            # NOTE(review): total_memory is set to the pool's *free* size,
            # which makes the utilization check below always pass in vXPU
            # mode (requested = free * util <= free). If the pool total
            # was intended here it should be ``total_mem`` -- confirm
            # against the allocator's semantics before changing.
            self.init_snapshot.total_memory = free_mem
            self.init_snapshot.cuda_memory = total_mem - free_mem

        self.requested_memory = (self.init_snapshot.total_memory *
                                 self.cache_config.gpu_memory_utilization)
        if self.init_snapshot.free_memory < self.requested_memory:
            def gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)
            raise ValueError(
                f"Free memory on device "
                f"({gib(self.init_snapshot.free_memory)}/"
                f"{gib(self.init_snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{gib(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        # Construct the model runner.
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)
        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    def load_model(self) -> None:
        """Load model weights, routed through the vXPU "weights" pool
        when vXPU mode is enabled."""
        if xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            # Pooled weight tracking requires a pristine pool.
            assert XpuMemAllocator.get_instance().get_current_usage() == 0, (
                "vXPU mode can only be "
                "used for one instance per process.")
        with self._pool_context("weights"):
            self.model_runner.load_model()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        with self._pool_context("kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def determine_available_memory(self) -> int:
        """Return the number of bytes available for the KV cache.

        In vXPU mode this is the pool budget (total * utilization) minus
        memory already consumed from the pool, clamped at zero; otherwise
        defers to the base ``Worker`` implementation.
        """
        if not xenvs.VLLM_KUNLUN_ENABLE_VXPU:
            return super().determine_available_memory()

        allocator = XpuMemAllocator.get_instance()
        free, total = allocator.get_pool_mem_info()
        budget = int(total * self.cache_config.gpu_memory_utilization
                     - (total - free))
        available_kv_cache_memory = max(budget, 0)
        # Lazy %-style args: the message is only formatted if INFO is on.
        logger.info(
            "Available memory (vxpu mode): %.2f GiB, total memory: %.2f GiB",
            available_kv_cache_memory / GiB_bytes, total / GiB_bytes)
        return available_kv_cache_memory

    def offload_vram(self) -> None:
        """Offload pooled "weights" allocations out of device VRAM."""
        allocator = XpuMemAllocator.get_instance()
        allocator.offload_vram(offload_tags=("weights",))

    def try_reload_vram(self) -> tuple[bool, bool]:
        """Try to reload previously offloaded pool memory back to VRAM.

        Returns the allocator's flag pair; see
        ``XpuMemAllocator.try_reload_vram`` for the exact semantics.
        """
        allocator = XpuMemAllocator.get_instance()
        return allocator.try_reload_vram(tags=None)

    def vxpu_unlock_gpu(self) -> None:
        """Release the vXPU allocator's lock on the GPU."""
        allocator = XpuMemAllocator.get_instance()
        allocator.vxpu_unlock_gpu()