# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager from typing import TYPE_CHECKING import torch from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils.torch_utils import supports_xpu_graph from vllm.v1.worker.gpu_model_runner import GPUModelRunner if TYPE_CHECKING: pass logger = init_logger(__name__) class XPUModelRunner(GPUModelRunner): """A model runner for XPU devices.""" def __init__( self, vllm_config: VllmConfig, device: torch.device, ): with _torch_cuda_wrapper(): super().__init__(vllm_config, device) # FIXME: To be verified. self.cascade_attn_enabled = False def _sync_device(self) -> None: torch.xpu.synchronize() @contextmanager def _torch_cuda_wrapper(): try: # replace cuda APIs with xpu APIs, this should work by default torch.cuda.Stream = torch.xpu.Stream torch.cuda.default_stream = torch.xpu.current_stream torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.stream = torch.xpu.stream torch.cuda.mem_get_info = torch.xpu.mem_get_info torch.cuda.synchronize = torch.xpu.synchronize if supports_xpu_graph(): torch.cuda.graph = torch.xpu.graph torch.cuda.CUDAGraph = torch.xpu.XPUGraph torch.cuda.empty_cache = torch.xpu.empty_cache yield finally: pass