Files
bi_150-vllm/vllm/v1/worker/xpu_model_runner.py

53 lines
1.5 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.torch_utils import supports_xpu_graph
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
# Reserved for imports needed only by static type checkers; currently empty.
if TYPE_CHECKING:
    pass

# Module-level logger, named after this module.
logger = init_logger(__name__)
class XPUModelRunner(GPUModelRunner):
    """Model runner for Intel XPU devices.

    Reuses ``GPUModelRunner`` as-is. Before the parent constructor runs,
    ``_torch_cuda_wrapper`` rebinds selected ``torch.cuda`` entry points to
    their ``torch.xpu`` counterparts so the CUDA-oriented parent code can
    drive an XPU device.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        """Construct the runner for *device* from *vllm_config*.

        The parent constructor is executed inside ``_torch_cuda_wrapper`` so
        that its ``torch.cuda`` calls resolve to XPU implementations.
        """
        cuda_to_xpu = _torch_cuda_wrapper()
        with cuda_to_xpu:
            super().__init__(vllm_config, device)

        # Cascade attention is turned off for XPU until it has been validated.
        # FIXME: To be verified.
        self.cascade_attn_enabled = False

    def _sync_device(self) -> None:
        """Synchronize the XPU device via ``torch.xpu.synchronize``.

        Overrides the parent hook so the XPU runtime, not CUDA, is used.
        """
        torch.xpu.synchronize()
@contextmanager
def _torch_cuda_wrapper():
    """Rebind ``torch.cuda`` stream/sync/memory/graph APIs to ``torch.xpu``.

    Code written against ``torch.cuda`` (e.g. ``GPUModelRunner``) can then run
    unchanged on an XPU device.

    NOTE: nothing is restored when the ``with`` block exits. The rebinding
    stays in effect for the rest of the process. Presumably this is
    intentional, since the runner keeps calling these ``torch.cuda`` APIs
    after construction (TODO confirm).
    """
    # (torch.cuda attribute, torch.xpu source attribute), applied in order.
    # Note that ``default_stream`` is pointed at ``torch.xpu.current_stream``.
    always_patched = (
        ("Stream", "Stream"),
        ("default_stream", "current_stream"),
        ("current_stream", "current_stream"),
        ("stream", "stream"),
        ("mem_get_info", "mem_get_info"),
        ("synchronize", "synchronize"),
    )
    for cuda_attr, xpu_attr in always_patched:
        setattr(torch.cuda, cuda_attr, getattr(torch.xpu, xpu_attr))

    # Graph capture is only redirected when XPU graphs are supported.
    if supports_xpu_graph():
        torch.cuda.graph = torch.xpu.graph
        torch.cuda.CUDAGraph = torch.xpu.XPUGraph

    torch.cuda.empty_cache = torch.xpu.empty_cache
    yield