Sync from v0.13
This commit is contained in:
48
vllm/v1/worker/xpu_model_runner.py
Normal file
48
vllm/v1/worker/xpu_model_runner.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUModelRunner(GPUModelRunner):
|
||||
"""A model runner for XPU devices."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
with _torch_cuda_wrapper():
|
||||
super().__init__(vllm_config, device)
|
||||
# FIXME: To be verified.
|
||||
self.cascade_attn_enabled = False
|
||||
|
||||
def _init_device_properties(self) -> None:
|
||||
self.num_sms = None
|
||||
|
||||
def _sync_device(self) -> None:
|
||||
torch.xpu.synchronize()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _torch_cuda_wrapper():
|
||||
try:
|
||||
# replace cuda APIs with xpu APIs, this should work by default
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
torch.cuda.default_stream = torch.xpu.current_stream
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
Reference in New Issue
Block a user