# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager from typing import TYPE_CHECKING import torch from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.worker.gpu_model_runner import GPUModelRunner if TYPE_CHECKING: pass logger = init_logger(__name__) class XPUModelRunner(GPUModelRunner): """A model runner for XPU devices.""" def __init__( self, vllm_config: VllmConfig, device: torch.device, ): with _torch_cuda_wrapper(): super().__init__(vllm_config, device) # FIXME: To be verified. self.cascade_attn_enabled = False def _init_device_properties(self) -> None: self.num_sms = None def _sync_device(self) -> None: torch.xpu.synchronize() @contextmanager def _torch_cuda_wrapper(): try: # replace cuda APIs with xpu APIs, this should work by default torch.cuda.Stream = torch.xpu.Stream torch.cuda.default_stream = torch.xpu.current_stream torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.stream = torch.xpu.stream yield finally: pass