# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import os
from typing import Any

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.v1.worker.workspace import init_workspace_manager
from vllm.v1.worker.xpu_model_runner import XPUModelRunner

from .utils import request_memory

logger = init_logger(__name__)


class XPUWorker(Worker):
    """An XPU worker class."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(
            vllm_config, local_rank, rank, distributed_init_method, is_driver_worker
        )
        device_config = self.device_config
        assert device_config.device_type == "xpu"
        assert current_platform.is_xpu()

        # Torch profiler. Enabled and configured through profiler_config.
        self.profiler: Any | None = None
        profiler_config = vllm_config.profiler_config
        if profiler_config.profiler == "torch":
            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
            self.profiler = TorchProfilerWrapper(
                profiler_config,
                worker_name=worker_name,
                local_rank=self.local_rank,
                activities=["CPU", "XPU"],
            )

    def init_device(self):
        device = self.device_config.device
        if (
            isinstance(device, torch.device)
            and device.type == "xpu"
            and current_platform.is_xpu()
        ):
            # Bind this worker to its XPU device and record total device memory.
            self.device = torch.device(f"xpu:{self.local_rank}")
            current_platform.set_device(self.device)
            current_platform.check_if_supports_dtype(self.model_config.dtype)
            torch.xpu.empty_cache()
            self.init_gpu_memory = torch.xpu.get_device_properties(
                self.local_rank
            ).total_memory
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}"
            )

        # Configure oneCCL environment variables for the XPU distributed backend.
        ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
        ENV_LOCAL_WORLD_SIZE = os.getenv(
            "LOCAL_WORLD_SIZE", str(self.parallel_config.world_size)
        )
        os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
        os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
        os.environ["LOCAL_RANK"] = str(self.local_rank)
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Take a memory snapshot only after the distributed backend is
        # initialized, so its allocations are accounted for.
        gc.collect()
        torch.xpu.empty_cache()

        # Take the current memory snapshot and compute the requested budget.
        self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
        self.requested_memory = request_memory(init_snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # Initialize workspace manager.
        num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
        init_workspace_manager(self.device, num_ubatches)

        # Construct the model runner.
        self.model_runner = XPUModelRunner(  # type: ignore
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)