init
This commit is contained in:
263
vllm/platforms/__init__.py
Normal file
263
vllm/platforms/__init__.py
Normal file
@@ -0,0 +1,263 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
import traceback
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm import envs
|
||||
from vllm.plugins import load_plugins_by_group
|
||||
from vllm.utils import resolve_obj_by_qualname, supports_xccl
|
||||
|
||||
from .interface import _Backend # noqa: F401
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def vllm_version_matches_substr(substr: str) -> bool:
    """Return True if the installed vLLM version string contains ``substr``.

    Used by the platform-detection helpers below to recognize special
    builds (e.g. a "cpu" build) from the installed package metadata.

    Raises:
        PackageNotFoundError: if the vLLM package metadata cannot be found.
    """
    from importlib.metadata import PackageNotFoundError, version
    try:
        vllm_version = version("vllm")
    except PackageNotFoundError:
        logger.warning(
            "The vLLM package was not found, so its version could not be "
            "inspected. This may cause platform detection to fail.")
        # Bare `raise` re-raises with the original traceback intact.
        raise
    return substr in vllm_version
|
||||
|
||||
|
||||
def tpu_platform_plugin() -> Optional[str]:
    """Detect TPU support and return the TPU platform qualname, or None."""
    logger.debug("Checking if TPU platform is available.")

    # A Pathways TPU proxy counts as TPU support even without local libtpu.
    if envs.VLLM_TPU_USING_PATHWAYS:
        logger.debug("Confirmed TPU platform is available via Pathways proxy.")
        return "tpu_commons.platforms.tpu_jax.TpuPlatform"

    # Otherwise treat an importable libtpu as evidence of local TPUs.
    # While it's technically possible to install libtpu on a non-TPU
    # machine, this is a very uncommon scenario. Therefore, we assume
    # that libtpu is installed only if the machine has TPUs.
    try:
        import libtpu  # noqa: F401
    except Exception as e:
        logger.debug("TPU platform is not available because: %s", str(e))
        return None
    logger.debug("Confirmed TPU platform is available.")
    return "vllm.platforms.tpu.TpuPlatform"
|
||||
|
||||
|
||||
def cuda_platform_plugin() -> Optional[str]:
    """Detect CUDA support and return the CUDA platform qualname, or None.

    Uses NVML to count GPUs; falls back to a filesystem probe on Jetson
    boards, where CUDA exists but NVML may not.
    """
    is_cuda = False
    logger.debug("Checking if CUDA platform is available.")
    try:
        from vllm.utils import import_pynvml
        pynvml = import_pynvml()
        pynvml.nvmlInit()
        try:
            # NOTE: Edge case: vllm cpu build on a GPU machine.
            # Third-party pynvml can be imported in cpu build,
            # we need to check if vllm is built with cpu too.
            # Otherwise, vllm will always activate cuda plugin
            # on a GPU machine, even if in a cpu build.
            is_cuda = (pynvml.nvmlDeviceGetCount() > 0
                       and not vllm_version_matches_substr("cpu"))
            if pynvml.nvmlDeviceGetCount() <= 0:
                logger.debug(
                    "CUDA platform is not available because no GPU is found.")
            if vllm_version_matches_substr("cpu"):
                logger.debug("CUDA platform is not available because"
                             " vLLM is built with CPU.")
            if is_cuda:
                logger.debug("Confirmed CUDA platform is available.")
        finally:
            # Always release NVML, even if the count/version checks raise.
            pynvml.nvmlShutdown()
    except Exception as e:
        logger.debug("Exception happens when checking CUDA platform: %s",
                     str(e))
        # Only NVML-related failures are expected here; anything else is a
        # real error that must propagate to the caller.
        if "nvml" not in e.__class__.__name__.lower():
            # If the error is not related to NVML, re-raise it.
            raise e

        # CUDA is supported on Jetson, but NVML may not be.
        import os

        def cuda_is_jetson() -> bool:
            # Jetson boards expose one of these Tegra markers on disk.
            return os.path.isfile("/etc/nv_tegra_release") \
                or os.path.exists("/sys/class/tegra-firmware")

        if cuda_is_jetson():
            logger.debug("Confirmed CUDA platform is available on Jetson.")
            is_cuda = True
        else:
            logger.debug("CUDA platform is not available because: %s", str(e))

    return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None
|
||||
|
||||
|
||||
def rocm_platform_plugin() -> Optional[str]:
    """Detect ROCm support and return the ROCm platform qualname, or None."""
    rocm_found = False
    logger.debug("Checking if ROCm platform is available.")
    try:
        import amdsmi
        amdsmi.amdsmi_init()
        try:
            handle_count = len(amdsmi.amdsmi_get_processor_handles())
            if handle_count > 0:
                rocm_found = True
                logger.debug("Confirmed ROCm platform is available.")
            else:
                logger.debug("ROCm platform is not available because"
                             " no GPU is found.")
        finally:
            # Always release the amdsmi library, even on failure.
            amdsmi.amdsmi_shut_down()
    except Exception as e:
        logger.debug("ROCm platform is not available because: %s", str(e))

    if rocm_found:
        return "vllm.platforms.rocm.RocmPlatform"
    return None
|
||||
|
||||
|
||||
def xpu_platform_plugin() -> Optional[str]:
    """Detect Intel XPU support and return the XPU platform qualname, or None."""
    xpu_available = False
    logger.debug("Checking if XPU platform is available.")
    try:
        # installed IPEX if the machine has XPUs.
        import intel_extension_for_pytorch  # noqa: F401
        import torch

        # Prefer the native xccl backend when the torch build supports it;
        # otherwise fall back to the oneCCL bindings.
        if supports_xccl():
            dist_backend = "xccl"
        else:
            dist_backend = "ccl"
            import oneccl_bindings_for_pytorch  # noqa: F401

        if hasattr(torch, 'xpu') and torch.xpu.is_available():
            xpu_available = True
            # Record the chosen distributed backend on the platform class.
            from vllm.platforms.xpu import XPUPlatform
            XPUPlatform.dist_backend = dist_backend
            logger.debug("Confirmed %s backend is available.",
                         XPUPlatform.dist_backend)
            logger.debug("Confirmed XPU platform is available.")
    except Exception as e:
        logger.debug("XPU platform is not available because: %s", str(e))

    if xpu_available:
        return "vllm.platforms.xpu.XPUPlatform"
    return None
|
||||
|
||||
|
||||
def cpu_platform_plugin() -> Optional[str]:
    """Detect the CPU platform and return its qualname, or None."""
    cpu_detected = False
    logger.debug("Checking if CPU platform is available.")
    try:
        # A dedicated "cpu" build of vLLM always targets the CPU platform.
        cpu_detected = vllm_version_matches_substr("cpu")
        if cpu_detected:
            logger.debug("Confirmed CPU platform is available because"
                         " vLLM is built with CPU.")
        else:
            # macOS machines default to the CPU platform.
            import sys
            cpu_detected = sys.platform.startswith("darwin")
            if cpu_detected:
                logger.debug("Confirmed CPU platform is available"
                             " because the machine is MacOS.")

    except Exception as e:
        logger.debug("CPU platform is not available because: %s", str(e))

    if cpu_detected:
        return "vllm.platforms.cpu.CpuPlatform"
    return None
|
||||
|
||||
|
||||
# Mapping from builtin platform name to its zero-argument detection function.
# Each function returns the platform class qualname when that platform is
# available on this machine, otherwise None.
builtin_platform_plugins = {
    'tpu': tpu_platform_plugin,
    'cuda': cuda_platform_plugin,
    'rocm': rocm_platform_plugin,
    'xpu': xpu_platform_plugin,
    'cpu': cpu_platform_plugin,
}
|
||||
|
||||
|
||||
def resolve_current_platform_cls_qualname() -> str:
    """Resolve the fully-qualified class name of the active platform.

    Runs every builtin and out-of-tree (OOT) platform detection function
    and selects the single activated plugin; OOT plugins take precedence
    over builtin ones.

    Returns:
        The qualname of the detected platform class, or the qualname of
        ``UnspecifiedPlatform`` when nothing is detected.

    Raises:
        RuntimeError: if more than one OOT plugin, or more than one builtin
            plugin, is activated at the same time.
    """
    platform_plugins = load_plugins_by_group('vllm.platform_plugins')

    activated_plugins = []

    for name, func in chain(builtin_platform_plugins.items(),
                            platform_plugins.items()):
        try:
            assert callable(func)
            platform_cls_qualname = func()
            if platform_cls_qualname is not None:
                activated_plugins.append(name)
        except Exception:
            # Detection is best-effort: a failing plugin is treated as
            # "not available", but log the failure to aid debugging
            # instead of swallowing it silently.
            logger.debug("Platform plugin %s failed during detection.",
                         name,
                         exc_info=True)

    activated_builtin_plugins = list(
        set(activated_plugins) & set(builtin_platform_plugins.keys()))
    activated_oot_plugins = list(
        set(activated_plugins) & set(platform_plugins.keys()))

    if len(activated_oot_plugins) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{activated_oot_plugins}")
    elif len(activated_oot_plugins) == 1:
        # An out-of-tree plugin wins over any builtin detection.
        platform_cls_qualname = platform_plugins[activated_oot_plugins[0]]()
        logger.info("Platform plugin %s is activated",
                    activated_oot_plugins[0])
    elif len(activated_builtin_plugins) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{activated_builtin_plugins}")
    elif len(activated_builtin_plugins) == 1:
        platform_cls_qualname = builtin_platform_plugins[
            activated_builtin_plugins[0]]()
        logger.info("Automatically detected platform %s.",
                    activated_builtin_plugins[0])
    else:
        # Nothing detected: fall back to a platform that raises on use.
        platform_cls_qualname = "vllm.platforms.interface.UnspecifiedPlatform"
        logger.info(
            "No platform detected, vLLM is running on UnspecifiedPlatform")
    return platform_cls_qualname
|
||||
|
||||
|
||||
# Lazily-initialized singleton holding the resolved Platform instance;
# populated on first access of `current_platform` via __getattr__.
_current_platform = None
# Stack trace captured at the moment the platform was first resolved —
# useful for debugging what triggered platform initialization.
_init_trace: str = ''

if TYPE_CHECKING:
    # Help type checkers: `current_platform` is provided dynamically at
    # runtime, not as a regular module attribute.
    current_platform: Platform
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Module-level attribute hook (PEP 562) implementing lazy resolution
    of ``current_platform``; other names fall back to module globals.

    Raises:
        AttributeError: if ``name`` is not ``current_platform`` and is not
            present in the module globals.
    """
    if name == 'current_platform':
        # lazy init current_platform.
        # 1. out-of-tree platform plugins need `from vllm.platforms import
        # Platform` so that they can inherit `Platform` class. Therefore,
        # we cannot resolve `current_platform` during the import of
        # `vllm.platforms`.
        # 2. when users use out-of-tree platform plugins, they might run
        # `import vllm`, some vllm internal code might access
        # `current_platform` during the import, and we need to make sure
        # `current_platform` is only resolved after the plugins are loaded
        # (we have tests for this, if any developer violate this, they will
        # see the test failures).
        global _current_platform
        if _current_platform is None:
            platform_cls_qualname = resolve_current_platform_cls_qualname()
            _current_platform = resolve_obj_by_qualname(
                platform_cls_qualname)()
            # Remember who triggered the initialization, for debugging.
            global _init_trace
            _init_trace = "".join(traceback.format_stack())
        return _current_platform
    elif name in globals():
        return globals()[name]
    else:
        raise AttributeError(
            f"No attribute named '{name}' exists in {__name__}.")
|
||||
|
||||
|
||||
# Public API of vllm.platforms; `_init_trace` is exported for debugging.
__all__ = [
    'Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum',
    "_init_trace"
]
|
||||
BIN
vllm/platforms/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm/platforms/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/platforms/__pycache__/interface.cpython-312.pyc
Normal file
BIN
vllm/platforms/__pycache__/interface.cpython-312.pyc
Normal file
Binary file not shown.
340
vllm/platforms/cpu.py
Normal file
340
vllm/platforms/cpu.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
|
||||
def get_max_threads(pid=0):
    """Return the number of CPUs usable by process *pid* (0 = current).

    Raises:
        NotImplementedError: on platforms with neither ``sched_getaffinity``
            nor a Darwin fallback.
    """
    # Linux exposes the scheduler affinity mask, which respects cgroup /
    # taskset restrictions; macOS lacks it, so fall back to the CPU count.
    affinity_getter = getattr(os, 'sched_getaffinity', None)
    if affinity_getter is not None:
        return len(affinity_getter(pid))
    if platform.system() == 'Darwin':
        return os.cpu_count()
    raise NotImplementedError("Unsupported OS")
|
||||
|
||||
|
||||
@dataclass
class LogicalCPUInfo:
    """Topology record for one logical CPU as reported by lscpu.

    A value of -1 in any field means "unknown / unparsable".
    """
    id: int = -1  # logical CPU id
    physical_core: int = -1  # physical core this logical CPU belongs to
    numa_node: int = -1  # NUMA node this logical CPU belongs to

    @classmethod
    def _int(cls, value: str) -> int:
        """Parse *value* as an int, returning -1 when parsing fails."""
        try:
            return int(value)
        except Exception:
            return -1

    @staticmethod
    def json_decoder(obj_dict: dict):
        """``json.loads`` object hook: build a LogicalCPUInfo from an lscpu
        entry, or return the dict unchanged when any key is missing."""
        cpu_id = obj_dict.get("cpu")
        core = obj_dict.get("core")
        node = obj_dict.get("node")

        if cpu_id is None or core is None or node is None:
            return obj_dict
        return LogicalCPUInfo(
            id=LogicalCPUInfo._int(cpu_id),
            physical_core=LogicalCPUInfo._int(core),
            numa_node=LogicalCPUInfo._int(node))
|
||||
|
||||
|
||||
class CpuPlatform(Platform):
    """Platform implementation for running vLLM on CPUs."""

    _enum = PlatformEnum.CPU
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"
    # Default torch.distributed backend for CPU.
    dist_backend: str = "gloo"
    # Comma-separated list of NUMA memory nodes this process may use.
    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the model dtypes supported by the host CPU architecture."""
        if self.get_cpu_architecture() == CpuArchEnum.POWERPC:
            # POWER: no fp16 in the supported set, only bf16 and fp32.
            return [torch.bfloat16, torch.float32]
        elif (self.get_cpu_architecture() == CpuArchEnum.ARM
              and sys.platform.startswith("darwin")):
            # On Apple Silicon, probe the FEAT_BF16 CPU feature flag to
            # decide whether bfloat16 is offered.
            if (subprocess.check_output(
                ["sysctl -n hw.optional.arm.FEAT_BF16"],
                    shell=True).strip() == b"1"):
                return [torch.bfloat16, torch.float16, torch.float32]
            return [torch.float16, torch.float32]
        # x86/aarch64 CPU has supported both bf16 and fp16 natively.
        return [torch.bfloat16, torch.float16, torch.float32]

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name; always "cpu" on this platform."""
        return "cpu"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool, use_sparse: bool) -> str:
        """Return the attention backend qualname for CPU (Torch SDPA only).

        Raises:
            NotImplementedError: if MLA or sparse attention is requested.
            ValueError: if the V0 engine is requested (CPU is V1-only).
        """
        if selected_backend and selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on CPU.")
        logger.info("Using Torch SDPA backend.")
        if not use_v1:
            raise ValueError("CPU backend only supports V1.")
        return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the KV-cache space in bytes (VLLM_CPU_KVCACHE_SPACE GiB,
        defaulting to 4 GiB when the env var is unset)."""
        import vllm.envs as envs
        from vllm.utils import GiB_bytes

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
        if kv_cache_space is None:
            kv_cache_space = 4 * GiB_bytes  # type: ignore
            logger.warning_once(
                "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
                "for CPU backend is not set, using 4 by default.")
        else:
            # Env var is expressed in GiB; convert to bytes.
            kv_cache_space *= GiB_bytes

        return kv_cache_space

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cpu.set_device(device)

    @classmethod
    def inference_mode(cls):
        """Return the no-grad context used for inference on CPU."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate and mutate ``vllm_config`` in place for the CPU backend.

        Adjusts cache/scheduler/parallel/compilation settings and exports
        the environment variables the CPU executor relies on.

        Raises:
            RuntimeError: for block sizes requiring IPEX when it is absent,
                or for FP8 KV cache combined with chunked prefill /
                prefix caching.
        """
        model_config = vllm_config.model_config

        if model_config is not None:
            # Cascade attention is a GPU-only optimization.
            model_config.disable_cascade_attn = True

        cache_config = vllm_config.cache_config

        ipex_available = find_spec("intel_extension_for_pytorch") is not None

        if cache_config and cache_config.block_size is None:
            # IPEX kernels work with larger blocks; plain torch needs 16.
            cache_config.block_size = 128 if ipex_available else 16

        if not ipex_available and cache_config.block_size != 16:
            raise RuntimeError(
                f"--block-size={cache_config.block_size} requires"
                " intel_extension_for_pytorch")

        scheduler_config = vllm_config.scheduler_config
        if ((scheduler_config.chunked_prefill_enabled
             or cache_config.enable_prefix_caching)
                and cache_config.cache_dtype != "auto"):
            raise RuntimeError("Chunked-prefill and prefix-cache on the CPU "
                               "backend is not compatible with FP8 KV cache.")

        if cache_config.cache_dtype == "fp8_e4m3":
            # Only e5m2 FP8 KV cache is implemented for CPU.
            cache_config.cache_dtype = "fp8_e5m2"
            logger.warning(
                "CPU backend doesn't support fp8_e4m3 KV cache type, "
                "cast to fp8_e5m2.")

        if (cache_config.cache_dtype != "auto" and model_config is not None
                and model_config.dtype == torch.half):
            # FP8 KV cache + fp16 model is unsupported; force bf16.
            logger.warning("FP8 KV cache on the CPU backend only does not"
                           " support fp16 for now, cast to bf16.")
            model_config.dtype = torch.bfloat16

        cache_config.cpu_kvcache_space_bytes = \
            CpuPlatform.get_device_total_memory()

        parallel_config = vllm_config.parallel_config
        if (parallel_config.world_size > 1
                and parallel_config.distributed_executor_backend is not None
                and parallel_config.distributed_executor_backend != "mp"):
            # Only the multiprocessing executor is supported on CPU.
            logger.warning(("%s is not supported on CPU, fallback to mp "
                            "distributed executor backend."),
                           parallel_config.distributed_executor_backend)
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
        # Disable DBO
        if parallel_config.enable_dbo:
            logger.warning(
                "Dual-Batch Overlap is not supported on CPU, disabled.")
            parallel_config.enable_dbo = False

        # Note: workaround for v1 gpu_model_runner
        from vllm.config import CompilationLevel
        vllm_config.compilation_config.cudagraph_capture_sizes = []

        compilation_config = vllm_config.compilation_config
        if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:

            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
                backend = "eager"
            else:
                backend = "inductor"

            compilation_config.level = CompilationLevel.DYNAMO_ONCE
            compilation_config.backend = backend
            compilation_config.inductor_compile_config.update({
                "dce": True,
                "size_asserts": False,
                "nan_asserts": False,
                "epilogue_fusion": True,
            })
            if compilation_config.use_inductor:
                compilation_config.custom_ops = ["none"]

        if vllm_config.lora_config is not None:
            # LoRA on CPU is incompatible with torch.compile.
            compilation_config.level = CompilationLevel.NO_COMPILATION

        assert vllm_config.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        # variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        # Set default threads num for OpenMP parallel
        os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())

        # Disable torch async compiling which won't work with daemonic
        # processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP setting
        ld_prealod_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_prealod_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ['KMP_BLOCKTIME'] = "1"
            # Prevents the CPU to run into low performance state
            os.environ['KMP_TPAUSE'] = "0"
            # Provides fine granularity parallelism
            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            vllm_config.parallel_config.tensor_parallel_size)

        if model_config is not None and model_config.use_mla:
            # MLA on CPU cannot use chunked prefill / prefix caching; give
            # the scheduler a batch budget large enough for full prompts.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def get_allowed_cpu_core_node_list(
            cls) -> tuple[list[int], list[LogicalCPUInfo]]:
        """Return (allowed NUMA node ids, allowed logical CPUs), derived
        from lscpu output, the process affinity mask, and the
        CPU_VISIBLE_MEMORY_NODES env var. Linux only."""
        assert platform.system() == "Linux"

        # Init LogicalCPUInfo from lscpu
        lscpu_output = subprocess.check_output("lscpu -J -e=CPU,CORE,NODE",
                                               shell=True,
                                               text=True)
        logical_cpu_list: list[LogicalCPUInfo] = json.loads(
            lscpu_output, object_hook=LogicalCPUInfo.json_decoder)['cpus']

        # Filter CPUs with invalid attributes
        logical_cpu_list = [
            x for x in logical_cpu_list
            if -1 not in (x.id, x.physical_core, x.numa_node)
        ]

        # Filter allowed CPUs
        allowed_cpu_id_list = os.sched_getaffinity(0)
        logical_cpu_list = [
            x for x in logical_cpu_list if x.id in allowed_cpu_id_list
        ]

        # Get allowed NUMA nodes
        allowed_numa_nodes = set()
        for x in logical_cpu_list:
            allowed_numa_nodes.add(x.numa_node)  # type: ignore
        allowed_numa_nodes_list = sorted(allowed_numa_nodes)

        env_key = CpuPlatform.device_control_env_var
        if (env_key in os.environ and os.environ[env_key] != ""):
            visible_nodes = [int(s) for s in os.environ[env_key].split(',')]
            # NOTE(review): the visible NUMA node ids are filtered against
            # allowed_cpu_id_list (logical CPU ids), not against the set of
            # allowed NUMA nodes computed above — this looks like it should
            # use allowed_numa_nodes instead; confirm intent.
            allowed_numa_nodes_list = [
                x for x in visible_nodes if x in allowed_cpu_id_list
            ]

        return allowed_numa_nodes_list, logical_cpu_list

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Pinned memory does not apply to CPU; always False."""
        logger.warning("Pin memory is not supported on CPU.")
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the qualname of the CPU LoRA punica wrapper."""
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa

    @classmethod
    def supports_structured_output(cls) -> bool:
        """Structured output is supported on CPU."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """CPU attention is exposed as an opaque op to the compiler."""
        return True

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache layouts are supported on CPU."""
        return True
|
||||
668
vllm/platforms/cuda.py
Normal file
668
vllm/platforms/cuda.py
Normal file
@@ -0,0 +1,668 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Code inside this file can safely assume cuda platform, e.g. importing
|
||||
pynvml. However, it should not initialize cuda context.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from functools import cache, wraps
|
||||
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
from torch.distributed.distributed_c10d import is_nccl_available
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
# import custom ops, trigger op registration
|
||||
import vllm._C # noqa
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cuda_device_count_stateless, import_pynvml
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Type variables for the decorator below: parameters and return type of the
# wrapped function are forwarded unchanged.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# Module-level NVML handle shared by the helpers in this file.
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)
|
||||
|
||||
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator wrapping *fn* in an NVML init/shutdown pair."""

    @wraps(fn)
    def nvml_guarded(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # Initialize NVML before the call, and always shut it down
        # afterwards — even when the wrapped function raises.
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()
        return result

    return nvml_guarded
|
||||
|
||||
|
||||
class CudaPlatformBase(Platform):
|
||||
_enum = PlatformEnum.CUDA
|
||||
device_name: str = "cuda"
|
||||
device_type: str = "cuda"
|
||||
dispatch_key: str = "CUDA"
|
||||
ray_device_key: str = "GPU"
|
||||
dist_backend: str = "nccl"
|
||||
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
|
||||
|
||||
@property
|
||||
def supported_dtypes(self) -> list[torch.dtype]:
|
||||
if self.has_device_capability(80):
|
||||
# Ampere and Hopper or later NVIDIA GPUs.
|
||||
return [torch.bfloat16, torch.float16, torch.float32]
|
||||
if self.has_device_capability(60):
|
||||
# Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
|
||||
return [torch.float16, torch.float32]
|
||||
# Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
|
||||
# though vLLM doesn't support these GPUs.
|
||||
return [torch.float32]
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
"""
|
||||
Set the device for the current platform.
|
||||
"""
|
||||
torch.cuda.set_device(device)
|
||||
# With this trick we can force the device to be set eagerly
|
||||
# see https://github.com/pytorch/pytorch/issues/155668
|
||||
# for why and when it is needed
|
||||
_ = torch.zeros(1, device=device)
|
||||
|
||||
@classmethod
|
||||
def get_device_capability(cls,
|
||||
device_id: int = 0
|
||||
) -> Optional[DeviceCapability]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def is_fully_connected(cls, device_ids: list[int]) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def log_warnings(cls):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if vllm_config.speculative_config:
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding is not supported on vLLM V0.")
|
||||
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 16
|
||||
|
||||
# TODO(lucas): handle this more gracefully
|
||||
# Note: model_config may be None during testing
|
||||
if model_config is not None and model_config.use_mla:
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_config,
|
||||
"index_topk")
|
||||
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
|
||||
# then we default to FlashMLA backend for non-blackwell GPUs,
|
||||
# else we default to CutlassMLA. For each case, we force the
|
||||
# required block_size.
|
||||
use_flashmla = False
|
||||
use_cutlass_mla = False
|
||||
use_flashinfer_mla = False
|
||||
|
||||
if envs.VLLM_ATTENTION_BACKEND is None:
|
||||
# Default case
|
||||
if cls.is_device_capability(100):
|
||||
# Blackwell => Force CutlassMLA.
|
||||
use_cutlass_mla = True
|
||||
# TODO: This does not work, because the
|
||||
# global_force_attn_backend_context_manager is not set.
|
||||
# See vllm/attention/selector.py:_cached_get_attn_backend
|
||||
envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
|
||||
else:
|
||||
# Not Blackwell
|
||||
use_flashmla = True
|
||||
else:
|
||||
# Forced case
|
||||
use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
|
||||
use_cutlass_mla = (
|
||||
envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
|
||||
use_flashinfer_mla = (
|
||||
envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA")
|
||||
|
||||
from vllm.attention.ops.flashmla import is_flashmla_supported
|
||||
if use_flashmla and is_flashmla_supported()[0] \
|
||||
and cache_config.block_size != 64:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLA backend.")
|
||||
|
||||
if use_cutlass_mla and cache_config.block_size != 128:
|
||||
cache_config.block_size = 128
|
||||
logger.info("Forcing kv cache block size to 128 for "
|
||||
"CUTLASS_MLA backend.")
|
||||
|
||||
if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashInferMLA "
|
||||
"backend.")
|
||||
|
||||
# TODO(Chen): remove this hacky code
|
||||
if use_sparse and cache_config.block_size != 64:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLASparse "
|
||||
"backend.")
|
||||
# lazy import to avoid circular import
|
||||
from vllm.config import CUDAGraphMode
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
|
||||
and parallel_config.data_parallel_size > 1
|
||||
and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
|
||||
# TODO: Piecewise Cuda graph might be enabled
|
||||
# if torch compile cache key issue fixed
|
||||
# See https://github.com/vllm-project/vllm/pull/25093
|
||||
logger.info(
|
||||
"WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
|
||||
"kernels are optimized for prefill and are incompatible with "
|
||||
"CUDA Graphs. "
|
||||
"In order to use CUDA Graphs for decode-optimized workloads, "
|
||||
"set VLLM_ALL2ALL_BACKEND to another option, such as "
|
||||
"deepep_low_latency, pplx, or allgather_reducescatter.")
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(cls,
|
||||
device: Optional[torch.types.Device] = None
|
||||
) -> float:
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
return torch.cuda.max_memory_allocated(device)
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int,
|
||||
dtype: torch.dtype) -> _Backend:
|
||||
|
||||
# For Blackwell GPUs, force TORCH_SDPA for now.
|
||||
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
|
||||
if cls.has_device_capability(100):
|
||||
return _Backend.TORCH_SDPA
|
||||
|
||||
if dtype not in (torch.float16, torch.bfloat16):
|
||||
return _Backend.XFORMERS
|
||||
|
||||
if cls.has_device_capability(80):
|
||||
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
from vllm.attention.selector import is_attn_backend_supported
|
||||
is_default_fa_supported = is_attn_backend_supported(
|
||||
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False)
|
||||
if is_default_fa_supported:
|
||||
return _Backend.FLASH_ATTN
|
||||
else:
|
||||
# Fallback to XFORMERS
|
||||
return _Backend.XFORMERS
|
||||
else:
|
||||
# Fallback for Volta/Turing GPUs or FA not supported
|
||||
return _Backend.XFORMERS
|
||||
|
||||
    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink, use_sparse) -> str:
        """Resolve the fully-qualified attention-backend class name.

        Selection happens in three tiers:

        1. MLA models (``use_mla``): sparse > Cutlass > FlashInfer >
           FlashMLA > FlashAttention-MLA > Triton MLA, honoring an
           explicit ``selected_backend`` first.
        2. V1 engine non-MLA: an explicitly selected backend is returned
           directly; otherwise the default is FlashInfer on Blackwell,
           FlashAttention on SM 8.0+, FlexAttention on older GPUs or when
           the default backend rejects this head size/dtype.
        3. Anything else (V0) raises, as V0 backends were removed.

        Returns the dotted import path of the backend class. May set the
        global KV-cache layout to "HND" as a side effect when a
        FlashInfer backend is chosen.
        """
        if use_mla:
            # MLA backends only exist on the V1 engine.
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them.")

            from vllm.attention.ops.flashmla import is_flashmla_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla

            if use_sparse:
                logger.info_once("Using Sparse MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla.flashmla_sparse."
                        "FlashMLASparseBackend")

            # Each flag is true when the user explicitly selected that
            # backend, or when no backend was selected and the hardware /
            # block-size default applies. Checked in priority order below.
            use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                selected_backend is None and cls.is_device_capability(100)
                and block_size == 128)
            use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
                selected_backend is None and cls.is_device_capability(100)
                and block_size in [32, 64])
            use_flashmla = selected_backend == _Backend.FLASHMLA or (
                selected_backend is None and is_flashmla_supported()[0])
            use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
                selected_backend is None and flash_attn_supports_mla())
            use_triton = selected_backend == _Backend.TRITON_MLA or (
                selected_backend is None)

            if use_cutlassmla:
                logger.info_once("Using Cutlass MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "cutlass_mla.CutlassMLABackend")
            if use_flashinfermla:
                from vllm.v1.attention.backends.utils import (
                    set_kv_cache_layout)
                set_kv_cache_layout("HND")
                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "flashinfer_mla.FlashInferMLABackend")
            if use_flashmla:
                # FlashMLA requires block size 64; otherwise fall through
                # to the next candidate.
                if block_size != 64:
                    logger.warning(
                        "FlashMLA backend is not supported for block size %d"
                        " (currently only supports block size 64).",
                        block_size)
                else:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashmla.FlashMLABackend")
            if use_flashattn:
                logger.info_once(
                    "Using FlashAttention MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "flashattn_mla.FlashAttnMLABackend")
            if use_triton:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "triton_mla.TritonMLABackend")
        if use_v1:
            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
            FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
            TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

            use_fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))

            # Honor an explicit user selection before applying defaults.
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                if cls.has_device_capability(100):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)
                    set_kv_cache_layout("HND")
                return FLASHINFER_V1
            elif selected_backend == _Backend.FLEX_ATTENTION:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1
            elif selected_backend == _Backend.TRITON_ATTN:
                logger.info_once("Using Triton backend on V1 engine.")
                return TRITON_ATTN
            elif selected_backend == _Backend.FLASH_ATTN:
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return FLASH_ATTN_V1
            elif selected_backend == _Backend.TREE_ATTN:
                logger.info_once("Using Tree Attention backend on V1 engine.")
                return TREE_ATTN_V1
            elif selected_backend == _Backend.XFORMERS:
                logger.info_once("Using XFormers backend on V1 engine.")
                return XFORMERS_V1

            from vllm.attention.selector import is_attn_backend_supported

            # Default backends for V1 engine
            # Prefer FlashInfer for Blackwell GPUs if installed
            if cls.is_device_capability(100):
                if is_default_backend_supported := is_attn_backend_supported(
                        FLASHINFER_V1, head_size, dtype):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)

                    logger.info_once(
                        "Using FlashInfer backend with HND KV cache layout on "
                        "V1 engine by default for Blackwell (SM 10.0) GPUs.")
                    set_kv_cache_layout("HND")

                    return FLASHINFER_V1

                if not is_default_backend_supported.can_import:
                    logger.warning_once(
                        "FlashInfer failed to import for V1 engine on "
                        "Blackwell (SM 10.0) GPUs; it is recommended to "
                        "install FlashInfer for better performance.")

            # FlashAttention is the default for SM 8.0+ GPUs
            if cls.has_device_capability(80):
                # Attention sinks / fp8 KV cache route to Triton except on
                # SM 9.0 (Hopper).
                if (has_sink or
                        use_fp8_kv_cache) and not cls.is_device_capability(90):
                    logger.info_once("Using Triton backend on V1 engine.")
                    return TRITON_ATTN
                elif is_default_backend_supported := is_attn_backend_supported(
                        FLASH_ATTN_V1, head_size, dtype,
                        allow_import_error=False):
                    logger.info_once("Using Flash Attention backend on "
                                     "V1 engine.")
                    return FLASH_ATTN_V1

            # FlexAttention is the default for older GPUs
            else:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1

            # Reaching here means the default backend rejected this
            # head size/dtype; fall back to FlexAttention and explain why.
            assert not is_default_backend_supported

            use_flex_attention_reason = {}
            if not is_default_backend_supported.head_size:
                use_flex_attention_reason["head_size"] = head_size
            if not is_default_backend_supported.dtype:
                use_flex_attention_reason["dtype"] = dtype

            logger.info_once(
                "Using FlexAttention backend for %s on V1 engine.",
                ", ".join(f"{k}={v}"
                          for k, v in use_flex_attention_reason.items()),
            )
            return FLEX_ATTENTION_V1

        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend.")
|
||||
|
||||
    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Dotted path of the LoRA punica wrapper for CUDA."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Dotted path of the CUDA distributed communicator class."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """True on GPUs with compute capability >= 8.9."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """CUDA supports vLLM's custom allreduce kernels."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Register attention as a single opaque custom op on CUDA."""
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Dotted path of the CUDA-graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
||||
|
||||
    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a NCCL-backed ProcessGroup without touching global c10d
        state.

        NOTE: the ``backend`` argument is ignored — this path hard-codes
        NCCL (and asserts its availability). The ordering of the private
        c10d calls below mirrors torch's own group construction; keep it.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        # _timeout is a private field; there is no public setter for a
        # per-group timeout on Options.
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg
|
||||
|
||||
    @classmethod
    def device_count(cls) -> int:
        """Number of CUDA devices, via the stateless helper (avoids
        initializing the CUDA context in this process)."""
        return cuda_device_count_stateless()
|
||||
|
||||
    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Return whether ``kv_cache_dtype`` works with the attention
        backend that would be selected for this model on this GPU.

        The backend is taken from ``VLLM_ATTENTION_BACKEND`` when set,
        otherwise the per-hardware default is assumed.
        """
        fp8_attention = kv_cache_dtype.startswith("fp8")
        attention_backend = envs.VLLM_ATTENTION_BACKEND

        supported = False
        if model_config is not None and model_config.use_mla:
            # Default to CutlassMLA for blackwell,
            # FlashMLA otherwise
            if attention_backend is None:
                if cls.is_device_capability(100):
                    attention_backend = "CUTLASS_MLA"
                else:
                    attention_backend = "FLASHMLA"

            # Only FlashMLA, CUTLASS_MLA and FlashInfer MLA support fp8
            if attention_backend in [
                    "FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"
            ]:
                supported = True
            else:
                supported = (not fp8_attention)
        else:
            # Default to FlashAttention
            if attention_backend is None:
                attention_backend = "FLASH_ATTN"

            # All Blackwell backends support fp8
            if cls.is_device_capability(100):
                supported = True
            elif attention_backend == "FLASH_ATTN":
                if fp8_attention:
                    from vllm.attention.utils.fa_utils import (
                        flash_attn_supports_fp8)
                    supported = flash_attn_supports_fp8()
                else:
                    supported = True
            elif attention_backend == "FLASHINFER":
                supported = True
            elif attention_backend == "TRITON_ATTN":
                supported = cls.supports_fp8()
        # NOTE(review): any other explicitly-set backend falls through
        # with supported=False, even for non-fp8 cache dtypes.
        return supported
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not cls.has_device_capability(80):
|
||||
capability = cls.get_device_capability()
|
||||
gpu_name = cls.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs "
|
||||
"with compute capability of at least 8.0. "
|
||||
f"Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
|
||||
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """CUDA supports hybrid KV-cache layouts."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """CUDA supports static (captured) graph mode."""
        return True
|
||||
|
||||
|
||||
# NVML utils
|
||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||
# all the related functions work on real physical device ids.
|
||||
# the major benefit of using NVML is that it will not initialize CUDA
|
||||
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform backed by NVML queries.

    NVML works on *physical* device ids (it ignores
    `CUDA_VISIBLE_DEVICES`) and does not initialize a CUDA context,
    which is the main reason to prefer it when available.
    """

    @classmethod
    @cache
    @with_nvml_context
    # NOTE(review): @cache on a classmethod caches per (cls, device_id)
    # and keeps the class alive for the process lifetime — intentional
    # here since platforms are module-level singletons, but see ruff B019.
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Compute capability of ``device_id``; None if NVML errors."""
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Like the base implementation, but NVML failures mean False."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Marketing name of the (logical) device."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """NVML UUID of the (logical) device."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total device memory in bytes."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        # Check every unordered pair; any non-OK or erroring P2P query
        # means the set is not fully connected.
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle,
                            peer_handle,
                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception(
                            "NVLink detection failed. This is normal if"
                            " your machine has no NVLink equipped.")
                        return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Takes a *physical* id; callers translate logical ids first.
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when a heterogeneous multi-GPU box lacks
        CUDA_DEVICE_ORDER=PCI_BUS_ID (device ordering would be
        ambiguous otherwise)."""
        device_ids: int = pynvml.nvmlDeviceGetCount()
        if device_ids > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(device_ids)
            ]
            if (len(set(device_names)) > 1
                    and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )
|
||||
|
||||
|
||||
class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform queried through ``torch.cuda``.

    Used when NVML is unavailable (e.g. Jetson). Unlike the NVML
    variant, these queries go through the CUDA runtime.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Compute capability of ``device_id`` via torch.cuda."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Marketing name of ``device_id`` via torch.cuda."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total device memory in bytes via torch.cuda."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """NVLink topology cannot be queried without NVML; assume none.

        Fix: this method is not called from an exception handler, so the
        previous ``logger.exception`` call appended a spurious
        "NoneType: None" traceback to every log line (per the logging
        docs, ``exception()`` is only for use inside handlers). A plain
        warning carries the same message without the bogus traceback.
        """
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False
|
||||
|
||||
|
||||
# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
    try:
        # Probe NVML by initializing it once; success means the NVML
        # driver library is present and usable.
        pynvml.nvmlInit()
        nvml_available = True
    except Exception:
        # On Jetson, NVML is not supported.
        nvml_available = False
finally:
    # Do not keep the probe's NVML context alive; per-call contexts are
    # managed by the with_nvml_context decorator.
    if nvml_available:
        pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

# One-time import-time warning about heterogeneous device setups.
CudaPlatform.log_warnings()
|
||||
620
vllm/platforms/interface.py
Normal file
620
vllm/platforms/interface.py
Normal file
@@ -0,0 +1,620 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import enum
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
from platform import uname
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
else:
|
||||
ModelConfig = None
|
||||
VllmConfig = None
|
||||
LoRARequest = None
|
||||
PoolingParams = None
|
||||
SamplingParams = None
|
||||
FlexibleArgumentParser = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def in_wsl() -> bool:
    """Detect Windows Subsystem for Linux.

    WSL kernels embed "microsoft" in their uname fields.
    Reference: https://github.com/microsoft/WSL/issues/4071
    """
    uname_blob = " ".join(uname()).lower()
    return "microsoft" in uname_blob
|
||||
|
||||
|
||||
class _Backend(enum.Enum):
    """Enumeration of attention backend implementations.

    NOTE: member order is meaningful to nothing but must stay stable
    across releases, since names are matched against user selections.
    """
    FLASH_ATTN = enum.auto()
    TRITON_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    ROCM_AITER_MLA = enum.auto()  # Supported by V1
    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
    FLASHINFER_MLA = enum.auto()
    TRITON_MLA = enum.auto()  # Supported by V1
    CUTLASS_MLA = enum.auto()
    FLASHMLA = enum.auto()  # Supported by V1
    FLASH_ATTN_MLA = enum.auto()  # Supported by V1
    PALLAS = enum.auto()
    IPEX = enum.auto()
    DUAL_CHUNK_FLASH_ATTN = enum.auto()
    DIFFERENTIAL_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()
    FLEX_ATTENTION = enum.auto()
    TREE_ATTN = enum.auto()
    ROCM_ATTN = enum.auto()
|
||||
|
||||
|
||||
class PlatformEnum(enum.Enum):
    """Hardware platform families recognized by vLLM."""
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    OOT = enum.auto()  # out-of-tree (plugin-provided) platform
    UNSPECIFIED = enum.auto()


class CpuArchEnum(enum.Enum):
    """Host CPU architecture families (see Platform.get_cpu_architecture)."""
    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    S390X = enum.auto()
    RISCV = enum.auto()
    OTHER = enum.auto()  # recognized machine string, unknown family
    UNKNOWN = enum.auto()  # platform.machine() returned nothing
|
||||
|
||||
|
||||
class DeviceCapability(NamedTuple):
    """An accelerator compute capability as a ``(major, minor)`` pair."""

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Render as a dotted version string, e.g. ``"8.6"``."""
        return "{}.{}".format(self.major, self.minor)

    def to_int(self) -> int:
        """
        Express device capability as an integer `<major><minor>`.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return self.minor + self.major * 10
|
||||
|
||||
|
||||
class Platform:
|
||||
_enum: PlatformEnum
|
||||
device_name: str
|
||||
device_type: str
|
||||
|
||||
# available dispatch keys:
|
||||
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
|
||||
# use "CPU" as a fallback for platforms not registered in PyTorch
|
||||
dispatch_key: str = "CPU"
|
||||
|
||||
# available ray device keys:
|
||||
# https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
|
||||
# empty string means the device does not support ray
|
||||
ray_device_key: str = ""
|
||||
|
||||
# platform-agnostic way to specify the device control environment variable,
|
||||
# .e.g. CUDA_VISIBLE_DEVICES for CUDA.
|
||||
# hint: search for "get_visible_accelerator_ids_env_var" in
|
||||
# https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
|
||||
device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"
|
||||
|
||||
# The torch.compile backend for compiling simple and
|
||||
# standalone functions. The default value is "inductor" to keep
|
||||
# the same behavior as PyTorch.
|
||||
# NOTE: for the forward part of the model, vLLM has another separate
|
||||
# compilation strategy.
|
||||
simple_compile_backend: str = "inductor"
|
||||
|
||||
# The backend used for distributed communication.
|
||||
dist_backend: str = ""
|
||||
|
||||
supported_quantization: list[str] = []
|
||||
|
||||
additional_env_vars: list[str] = []
|
||||
|
||||
_global_graph_pool: Optional[Any] = None
|
||||
|
||||
    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Returns the supported dtypes for the current platform."""
        # Be careful with the order of the dtypes. The first dtype will
        # be used as the default dtype fallback for the current platform,
        # when encountering unsupported dtypes in "auto" dtype.
        return [torch.bfloat16, torch.float16, torch.float32]

    def is_cuda(self) -> bool:
        """True if this is the CUDA platform."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """True if this is the ROCm platform."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """True if this is the TPU platform."""
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        """True if this is the XPU platform."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """True if this is the CPU platform."""
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        """True if this platform comes from a plugin (out-of-tree)."""
        return self._enum == PlatformEnum.OOT

    def get_max_output_tokens(self, prompt_len: int) -> int:
        """Upper bound on output tokens; unbounded by default."""
        return sys.maxsize

    def is_cuda_alike(self) -> bool:
        """Stateless version of [torch.cuda.is_available][]."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_sleep_mode_available(self) -> bool:
        """Sleep mode is only available on CUDA."""
        return self._enum == PlatformEnum.CUDA
|
||||
|
||||
@classmethod
|
||||
def device_id_to_physical_device_id(cls, device_id: int):
|
||||
# Treat empty device control env var as unset. This is a valid
|
||||
# configuration in Ray setups where the engine is launched in
|
||||
# a CPU-only placement group located on a GPU node.
|
||||
if cls.device_control_env_var in os.environ and os.environ[
|
||||
cls.device_control_env_var] != "":
|
||||
device_ids = os.environ[cls.device_control_env_var].split(",")
|
||||
physical_device_id = device_ids[device_id]
|
||||
return int(physical_device_id)
|
||||
else:
|
||||
return device_id
|
||||
|
||||
    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> _Backend:
        """Attention backend for vision-transformer layers; SDPA by
        default, overridden per platform."""
        return _Backend.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool, use_sparse: bool) -> str:
        """Get the attention backend class of a device.

        Returns a dotted import path; the empty string means "none"
        (platforms override this).
        """
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Stateless version of [torch.cuda.get_device_capability][]."""
        # None means "unknown" on platforms without a capability notion.
        return None
|
||||
|
||||
@classmethod
|
||||
def has_device_capability(
|
||||
cls,
|
||||
capability: Union[tuple[int, int], int],
|
||||
device_id: int = 0,
|
||||
) -> bool:
|
||||
"""
|
||||
Test whether this platform is compatible with a device capability.
|
||||
|
||||
The `capability` argument can either be:
|
||||
|
||||
- A tuple `(major, minor)`.
|
||||
- An integer `<major><minor>`. (See
|
||||
[`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
|
||||
"""
|
||||
current_capability = cls.get_device_capability(device_id=device_id)
|
||||
if current_capability is None:
|
||||
return False
|
||||
|
||||
if isinstance(capability, tuple):
|
||||
return current_capability >= capability
|
||||
|
||||
return current_capability.to_int() >= capability
|
||||
|
||||
@classmethod
|
||||
def is_device_capability(
|
||||
cls,
|
||||
capability: Union[tuple[int, int], int],
|
||||
device_id: int = 0,
|
||||
) -> bool:
|
||||
"""
|
||||
Test whether this platform has exactly the specified device capability.
|
||||
|
||||
The `capability` argument can either be:
|
||||
|
||||
- A tuple `(major, minor)`.
|
||||
- An integer `<major><minor>`. (See
|
||||
[`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
|
||||
"""
|
||||
current_capability = cls.get_device_capability(device_id=device_id)
|
||||
if current_capability is None:
|
||||
return False
|
||||
|
||||
if isinstance(capability, tuple):
|
||||
return current_capability == capability
|
||||
|
||||
return current_capability.to_int() == capability
|
||||
|
||||
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        # Abstract hook: concrete platforms must override.
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)
|
||||
|
||||
@classmethod
|
||||
def seed_everything(cls, seed: Optional[int] = None) -> None:
|
||||
"""
|
||||
Set the seed of each random module.
|
||||
`torch.manual_seed` will set seed on all devices.
|
||||
|
||||
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
|
||||
"""
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        # Abstract hook: concrete platforms must override.
        raise NotImplementedError

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.
        """
        # No-op by default; plugins override as needed.
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        # No-op by default; platforms override as needed.
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
          the current platform.
        - By default all models are considered supported.
        """
        pass
|
||||
|
||||
@classmethod
|
||||
def verify_quantization(cls, quant: str) -> None:
|
||||
"""
|
||||
Verify whether the quantization is supported by the current platform.
|
||||
"""
|
||||
if cls.supported_quantization and \
|
||||
quant not in cls.supported_quantization:
|
||||
raise ValueError(
|
||||
f"{quant} quantization is currently not supported in "
|
||||
f"{cls.device_name}.")
|
||||
|
||||
@classmethod
|
||||
def get_cpu_architecture(cls) -> CpuArchEnum:
|
||||
"""
|
||||
Determine the CPU architecture of the current system.
|
||||
Returns CpuArchEnum indicating the architecture type.
|
||||
"""
|
||||
machine = platform.machine().lower()
|
||||
|
||||
if machine in ("x86_64", "amd64", "i386", "i686"):
|
||||
return CpuArchEnum.X86
|
||||
elif machine.startswith("arm") or machine.startswith("aarch"):
|
||||
return CpuArchEnum.ARM
|
||||
elif machine.startswith("ppc"):
|
||||
return CpuArchEnum.POWERPC
|
||||
elif machine == "s390x":
|
||||
return CpuArchEnum.S390X
|
||||
elif machine.startswith("riscv"):
|
||||
return CpuArchEnum.RISCV
|
||||
|
||||
return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
|
||||
|
||||
    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform."""
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
            return False
        return True
|
||||
|
||||
    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """
        Return the memory usage in bytes.
        """
        # Abstract hook: concrete platforms must override.
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.
        """
        raise NotImplementedError

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """
        Return the platform specific values for (-inf, inf)
        """
        return float("-inf"), float("inf")

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels
        """
        return 256

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.
        """
        return False

    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Returns whether the current platform supports FP8 types.
        """
        return False

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Returns whether the preferred FP8 type is FNUZ on the current platform.

        There are two representations of FP8, OCP FP8 and FNUZ FP8.
        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
        hardware has converged on the OCP FP8 standard.
        """
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.

        See the documentation for is_fp8_fnuz for details.
        """
        return torch.float8_e4m3fn
|
||||
|
||||
@classmethod
def use_all_gather(cls) -> bool:
    """Whether LogitsProcessor should gather logits via all-gather."""
    # Imported lazily to avoid import cycles at module load time.
    import vllm.envs as envs
    from vllm.config import get_current_vllm_config

    executor_backend = (get_current_vllm_config().parallel_config
                        .distributed_executor_backend)
    return envs.VLLM_USE_V1 or executor_backend == "external_launcher"
|
||||
|
||||
@classmethod
def use_custom_allreduce(cls) -> bool:
    """Whether the custom allreduce kernel is supported here.

    Disabled by default; GPU platforms opt in per-architecture.
    """
    return False
|
||||
|
||||
@classmethod
def opaque_attention_op(cls) -> bool:
    """Whether attention is registered as one opaque custom op.

    ``True`` on platforms that wrap the whole attention computation
    into a single custom operator; ``False`` by default.
    """
    return False
|
||||
|
||||
@classmethod
def validate_request(
    cls,
    prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
    processed_inputs: ProcessorInputs,
) -> None:
    """Raise if the request cannot be served on this platform.

    The base implementation accepts every request; platforms with
    restrictions (e.g. unsupported sampling modes) override this.
    """
|
||||
|
||||
def __getattr__(self, key: str):
    """Proxy unknown attributes to the matching torch device module.

    E.g. for ``device_type == "cuda"`` this forwards to
    ``torch.cuda``; returns ``None`` (with a warning) when neither
    the module nor the attribute exists.
    """
    device_module = getattr(torch, self.device_type, None)
    if device_module is None or not hasattr(device_module, key):
        logger.warning(
            "Current platform %s does not have '%s' attribute.",
            self.device_type, key)
        return None
    return getattr(device_module, key)
|
||||
|
||||
def get_global_graph_pool(self) -> Any:
    """Return the process-wide graph pool, creating it lazily.

    The pool is cached on the platform class, so all instances of
    the same platform share one pool.
    """
    platform_cls = self.__class__
    pool = platform_cls._global_graph_pool
    if pool is None:
        pool = platform_cls._global_graph_pool = self.graph_pool_handle()
    return pool
|
||||
|
||||
@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
    """Total number of compute units (CUs) on a single GPU.

    Not available in the base class; device platforms must override.
    """
    raise NotImplementedError
|
||||
|
||||
@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
    """Qualified name of the static-graph wrapper class to use."""
    return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"
|
||||
|
||||
@classmethod
def stateless_init_device_torch_dist_pg(
    cls,
    backend: str,
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """Initialize a platform-specific torch distributed process group.

    The base class supports no backend; device platforms override
    this with their own process-group construction.
    """
    raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
|
||||
|
||||
@classmethod
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                model_config: "ModelConfig") -> bool:
    """Whether ``kv_cache_dtype`` is usable on this platform.

    Conservative default: nothing is supported until a platform
    opts in.
    """
    return False
|
||||
|
||||
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
    """Raise if ``torch_dtype`` is unsupported on this platform.

    Must be implemented by each device platform.
    """
    raise NotImplementedError
|
||||
|
||||
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Whether the hybrid KV cache is supported on this platform."""
    return False
|
||||
|
||||
@classmethod
def support_static_graph_mode(cls) -> bool:
    """Whether static graph (capture/replay) mode is supported."""
    return False
|
||||
|
||||
@classmethod
def use_sync_weight_loader(cls) -> bool:
    """Whether weight loading must be explicitly synchronized.

    ``False`` by default; lazy-execution platforms (e.g. XLA-backed
    ones) override this.
    """
    return False
|
||||
|
||||
@classmethod
def make_synced_weight_loader(cls, original_weight_loader):
    """Wrap a weight loader so device-side writes are synchronized.

    Returns ``original_weight_loader`` unchanged when the platform
    does not need synchronized loading.
    """
    if not cls.use_sync_weight_loader():
        return original_weight_loader

    def _synced_weight_loader(param, *args, **kwargs):
        result = original_weight_loader(param, *args, **kwargs)
        # torch._sync is an internal API; presumably forces pending
        # lazy device writes to materialize -- TODO confirm.
        if param.device != torch.device("cpu"):
            torch._sync(param)
        return result

    return _synced_weight_loader
|
||||
|
||||
@classmethod
def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
    """Map of device_type -> supported kv_buffer_device values for nixl.

    Empty by default: the base platform advertises no nixl support.
    """
    return {}
|
||||
|
||||
@classmethod
def get_nixl_memory_type(cls) -> Optional[str]:
    """The nixl memory type for this platform, or ``None`` if unset."""
    return None
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
device_type = ""
|
||||
497
vllm/platforms/rocm.py
Normal file
497
vllm/platforms/rocm.py
Normal file
@@ -0,0 +1,497 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from functools import cache, lru_cache, wraps
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
from torch.distributed.distributed_c10d import is_nccl_available
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
|
||||
amdsmi_get_processor_handles, amdsmi_init,
|
||||
amdsmi_shut_down, amdsmi_topo_get_link_type)
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from amdsmi with %r", e)
|
||||
|
||||
try:
|
||||
import vllm._C # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C with %r", e)
|
||||
|
||||
# import custom ops, trigger op registration
|
||||
try:
|
||||
import vllm._rocm_C # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._rocm_C with %r", e)
|
||||
|
||||
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
    # SWA-based decoder models share one reason string.
    "Qwen2ForCausalLM": _ROCM_SWA_REASON,
    "MistralForCausalLM": _ROCM_SWA_REASON,
    "MixtralForCausalLM": _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"),
}

# PCI device id -> canonical product name for AMD Instinct accelerators.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
}
|
||||
|
||||
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES``
if (hip_val := os.environ.get("HIP_VISIBLE_DEVICES")) is not None:
    cuda_val = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if cuda_val:
        # Both are set: they must agree, otherwise device selection
        # would be ambiguous.
        assert hip_val == cuda_val
    else:
        # Mirror the HIP setting so CUDA-oriented code sees it too.
        os.environ["CUDA_VISIBLE_DEVICES"] = hip_val
|
||||
|
||||
# AMDSMI utils
|
||||
# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
|
||||
# all the related functions work on real physical device ids.
|
||||
# the major benefit of using AMDSMI is that it will not initialize CUDA
|
||||
|
||||
|
||||
def with_amdsmi_context(fn):
    """Decorator running ``fn`` inside an amdsmi init/shutdown pair.

    Guarantees ``amdsmi_shut_down`` is called even if ``fn`` raises.
    """

    @wraps(fn)
    def _scoped(*args, **kwargs):
        amdsmi_init()
        try:
            return fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()

    return _scoped
|
||||
|
||||
|
||||
@cache
def on_gfx1x() -> bool:
    """True on gfx11/gfx12 (RDNA3/RDNA4-class) GPUs; cached per process."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx11" in arch_name or "gfx12" in arch_name
|
||||
|
||||
|
||||
@cache
def on_mi3xx() -> bool:
    """True on MI300/MI350-class (gfx942/gfx950) GPUs; cached per process."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx942" in arch_name or "gfx950" in arch_name
|
||||
|
||||
|
||||
@cache
def on_gfx9() -> bool:
    """True on any gfx9-family (gfx90a/gfx942/gfx950) GPU; cached."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(tag in arch_name for tag in ("gfx90a", "gfx942", "gfx950"))
|
||||
|
||||
|
||||
@cache
def on_gfx950() -> bool:
    """True on gfx950 GPUs specifically; cached per process."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx950" in arch_name
|
||||
|
||||
|
||||
@cache
def use_rocm_custom_paged_attention(
        qtype: torch.dtype,
        head_size: int,
        block_size: int,
        gqa_ratio: int,
        max_seq_len: int,
        sliding_window: int,
        kv_cache_dtype: str,
        alibi_slopes: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None) -> bool:
    """Decide whether ROCm's custom paged-attention kernel is usable.

    The decision depends on the GPU architecture family and on the
    query dtype / head size / block size / GQA ratio / sequence-length
    limits that the custom kernel supports. Cached, so it assumes a
    homogeneous set of GPUs within the process.
    """

    # Architecture family of device "cuda" (device 0) only.
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])

    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    if ON_GFX9:
        # NOTE(review): `sliding_window` is annotated as int, so the
        # `sliding_window == (-1, -1)` tuple comparison can never be
        # true for an int -- presumably a leftover from when callers
        # passed a (left, right) tuple; confirm against call sites.
        return ((not envs.VLLM_USE_V1 or sliding_window == 0
                 or sliding_window == (-1, -1))
                and (qtype == torch.half or qtype == torch.bfloat16)
                and (head_size == 64 or head_size == 128)
                and (block_size == 16 or block_size == 32)
                and (gqa_ratio >= 1 and gqa_ratio <= 16)
                and max_seq_len <= 128 * 1024
                and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
                and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
                         and envs.VLLM_ROCM_USE_AITER) and sinks is None)

    else:
        # gfx11/gfx12 path: narrower support (head 128, block 16,
        # gqa_ratio 3..16, no ALiBi, auto kv-cache dtype only).
        return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
                                    or sliding_window == (-1, -1))
                and (qtype == torch.half or qtype == torch.bfloat16)
                and head_size == 128 and block_size == 16
                and (gqa_ratio >= 3 and gqa_ratio <= 16)
                and max_seq_len <= 128 * 1024 and alibi_slopes is None
                and kv_cache_dtype == "auto"
                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
|
||||
|
||||
|
||||
class RocmPlatform(Platform):
    """Platform implementation for AMD ROCm GPUs.

    ROCm reuses torch's "cuda" device type, so most ``torch.cuda``
    APIs work transparently on HIP devices.
    """
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
        "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao"
    ]

    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> _Backend:
        """Pick the attention backend used by vision transformers."""
        if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
                and on_gfx9()):
            # Note: AITER FA is only supported for Qwen-VL models.
            # TODO: Add support for other VL models in their model class.
            return _Backend.ROCM_AITER_FA
        if on_gfx9():
            return _Backend.FLASH_ATTN
        return _Backend.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink, use_sparse) -> str:
        """Resolve the qualified class name of the attention backend."""
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on ROCm.")
        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them.")

            from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
                is_aiter_mla_enabled)

            if selected_backend is None:
                # AITER MLA only handles block size 1; otherwise Triton.
                selected_backend = (_Backend.ROCM_AITER_MLA if
                                    is_aiter_mla_enabled() or block_size == 1
                                    else _Backend.TRITON_MLA)

            if selected_backend == _Backend.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "triton_mla.TritonMLABackend")
                raise ValueError(
                    f" The selected backend, {selected_backend.name},"
                    f"does not support block size {block_size}.")
            if selected_backend == _Backend.ROCM_AITER_MLA:
                if block_size == 1:
                    logger.info("Using AITER MLA backend on V1 engine.")
                    return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                raise ValueError(
                    f" The selected backend, {selected_backend.name},"
                    f"does not support block size {block_size}."
                    "(currently only supports block size 1)")
            raise ValueError(
                f" The selected backend, {selected_backend.name},"
                f"is not MLA type while requested for MLA backend.")

        if envs.VLLM_USE_V1:
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
                and on_gfx9():
                logger.info("Using Flash Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "rocm_aiter_fa.AiterFlashAttentionBackend")
            elif (envs.VLLM_ROCM_USE_AITER and
                  envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \
                    envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \
                    selected_backend == _Backend.ROCM_ATTN:
                # rocm specific backend, with aiter and/or
                # triton prefix-prefill
                logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "rocm_attn.RocmAttentionBackend")
            else:
                # default case, using triton unified attention
                logger.info("Using Triton Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")
        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend.")

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the (major, minor) capability of ``device_id``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)
        """
        handles = [
            amdsmi_get_processor_handles()[i] for i in physical_device_ids
        ]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        link_type = amdsmi_topo_get_link_type(
                            handle, peer_handle)
                        # type is 2 for XGMI
                        if link_type["hops"] != 1 or link_type["type"] != 2:
                            return False
                    except AmdSmiException as error:
                        logger.error("AMD 1 hop XGMI detection failed.",
                                     exc_info=error)
                        return False
        return True

    @classmethod
    @with_amdsmi_context
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Human-readable device name, canonicalized for known Instinct ids."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total device memory in bytes."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply ROCm-specific defaults and overrides to ``vllm_config``."""
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Fix: compare the configured cudagraph mode against
        # CUDAGraphMode.NONE. The previous code compared the whole
        # CompilationConfig object to the enum, which is always False,
        # so eager execution was never detected.
        is_eager_execution = (compilation_config.cudagraph_mode
                              == CUDAGraphMode.NONE)

        use_v1 = envs.VLLM_USE_V1
        use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \
            envs.VLLM_ROCM_USE_AITER_RMSNORM

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            if vllm_config.speculative_config:
                if not use_v1:
                    raise NotImplementedError(
                        "Speculative decoding is not supported on vLLM V0.")
                parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
            else:
                if use_v1:
                    parallel_config.worker_cls = \
                        "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = "vllm.worker.worker.Worker"
        # Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (use_v1 and use_aiter_rms_norm and not is_eager_execution
                and "-rms_norm" not in compilation_config.custom_ops):
            compilation_config.custom_ops.append("+rms_norm")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partial support."""
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Extend base checks; force Triton AWQ path on ROCm."""
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
            envs.VLLM_USE_TRITON_AWQ = True

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Bytes currently in use on ``device`` (total - free)."""
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
            device)[0]

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        # MX types require gfx95-class hardware.
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx95"])

    @classmethod
    def supports_fp8(cls) -> bool:
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """FNUZ on MI300-class parts, OCP FP8 everywhere else."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ['gfx94', 'gfx95']
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """Number of compute units on ``device_id``."""
        return torch.cuda.get_device_properties(
            device_id).multi_processor_count

    @classmethod
    def is_navi(cls) -> bool:
        """True on Navi (gfx1x / RDNA) consumer GPUs."""
        return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build an NCCL (RCCL) process group without global state."""
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        return True

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise a descriptive error for bf16 on pre-8.0 hardware."""
        if torch_dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half.")

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True
|
||||
233
vllm/platforms/tpu.py
Normal file
233
vllm/platforms/tpu.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union, cast
|
||||
|
||||
import torch
|
||||
from tpu_info import device
|
||||
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
|
||||
from .interface import Platform, PlatformEnum, _Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import BlockSize, ModelConfig, VllmConfig
|
||||
from vllm.pooling_params import PoolingParams
|
||||
else:
|
||||
BlockSize = None
|
||||
ModelConfig = None
|
||||
VllmConfig = None
|
||||
PoolingParams = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
USE_TPU_COMMONS = False
|
||||
|
||||
|
||||
class TpuPlatform(Platform):
    """Platform implementation for Google TPUs via torch-XLA (Pallas)."""
    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = [
        "fp8", "tpu_int8", "compressed-tensors"
    ]

    additional_env_vars: list[str] = [
        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink, use_sparse) -> str:
        """TPU always uses the Pallas backend (V1 engine only)."""
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on TPU.")
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if not use_v1:
            raise ValueError("TPU backend only supports V1.")
        logger.info("Using Pallas V1 backend.")
        return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.tpu.set_device(device)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Name of the local TPU chip type (e.g. "TPU v5e")."""
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        # XLA kernels cannot represent IEEE infinities; use dtype extremes.
        return torch.finfo(dtype).min, torch.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls) -> bool:
        # XLA tensors are immutable; updates always create new buffers.
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        return 1

    @classmethod
    def inference_mode(cls):
        # torch.inference_mode is unsupported under XLA tracing.
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Apply TPU-specific defaults and restrictions to ``vllm_config``."""
        from vllm.config import CompilationLevel, CUDAGraphMode

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)
        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and "
                        "disabling cudagraph.")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if compilation_config.cudagraph_mode is None or \
            compilation_config.cudagraph_mode.max_cudagraph_mode() \
            != CUDAGraphMode.NONE:
            logger.info("[TPU] CUDA graph is not supported on TPU, "
                        "disabling cudagraphs.")
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Single authoritative check; a duplicate assertion further
        # down was removed.
        assert vllm_config.speculative_config is None, \
            "TPU does not support speculative decoding"

        model_config = vllm_config.model_config
        if model_config is not None and model_config.dtype in (torch.float16,
                                                               torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", model_config.dtype)
            model_config.dtype = torch.bfloat16

        # Guarded: cache_config may be None (see the block-size default
        # above, which already tolerates that).
        if cache_config is not None:
            from vllm.v1.attention.backends.pallas import (
                PallasAttentionBackend)
            cache_config.block_size = PallasAttentionBackend.get_page_size(
                vllm_config)  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

        if scheduler_config.is_multimodal_model and not \
            scheduler_config.disable_chunked_mm_input:
            logger.warning("TPU does not support running Multimodal models"
                           " without setting `--disable_chunked_mm_input`. "
                           "Forcing --disable_chunked_mm_input.")
            scheduler_config.disable_chunked_mm_input = True

        if model_config and model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls):
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform"""
        if (isinstance(params, SamplingParams)
                and params.sampling_type == SamplingType.RANDOM_SEED):
            raise ValueError("Torch XLA does not support per-request seed.")

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        return True

    @classmethod
    @torch.compile(backend="openxla")
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy selected CPU blocks into the device KV cache."""
        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(
            dst_cache.device)

    @classmethod
    @torch.compile(backend="openxla")
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """ tpu blocks to cpu blocks"""
        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        return True
|
||||
|
||||
|
||||
# Prefer the tpu_commons implementation when it is installed; fall
# back to vLLM's own TpuPlatform otherwise.
try:
    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
except ImportError:
    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
else:
    TpuPlatform = TpuCommonsPlatform  # type: ignore
    USE_TPU_COMMONS = True
|
||||
243
vllm/platforms/xpu.py
Normal file
243
vllm/platforms/xpu.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
else:
|
||||
ModelConfig = None
|
||||
VllmConfig = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUPlatform(Platform):
    """Platform implementation for Intel XPU (Intel GPU) devices.

    Backs vLLM's device abstraction with ``torch.xpu`` APIs and selects
    XPU-compatible attention backends, workers, and executor settings.
    """
    _enum = PlatformEnum.XPU
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
    dist_backend: str = "ccl"  # ccl | xccl
    # Env var Intel's runtime reads to restrict visible devices.
    device_control_env_var: str = "ZE_AFFINITY_MASK"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool, use_sparse: bool) -> str:
        """Return the dotted path of the attention backend class to use.

        Honors an explicit ``selected_backend`` of TRITON_ATTN or
        FLASH_ATTN; any other explicit choice raises ``ValueError``.
        Sparse attention is unsupported. Defaults to FlashAttention.
        """
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on XPU.")
        # NOTE(review): the use_v1 argument is deliberately overridden by
        # the environment setting; XPU only runs the V1 engine.
        use_v1 = envs.VLLM_USE_V1
        if not use_v1:
            raise ValueError("XPU backend only supports V1.")
        TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
        FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
        if selected_backend == _Backend.TRITON_ATTN:
            logger.info_once("Using Triton backend on V1 engine.")
            return TRITON_ATTN
        elif selected_backend == _Backend.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend on V1 engine.")
            return FLASH_ATTN
        elif selected_backend:
            # Any other explicitly requested backend is invalid on XPU.
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        # No explicit selection: fall back to FlashAttention.
        logger.info("Using Flash Attention backend on V1 engine.")
        return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """
        Check if the kv_cache_dtype is supported.

        XPU only supports fp8 kv cache when the Triton attention backend
        is explicitly selected via VLLM_ATTENTION_BACKEND.
        """
        if envs.is_set("VLLM_ATTENTION_BACKEND") and \
            envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN":
            return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]

        return False

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.xpu.set_device(device)

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Return None: XPU exposes no CUDA-style capability tuple."""
        # capacity format differs from cuda's and will cause unexpected
        # failure, so use None directly
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the marketing name of the given XPU device."""
        return torch.xpu.get_device_name(device_id)

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Dotted path of the LoRA Punica wrapper implementation for XPU."""
        return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total memory of the given XPU device, in bytes."""
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def inference_mode(cls):
        """Context manager disabling autograd for inference.

        Uses ``torch.no_grad()`` rather than ``torch.inference_mode()``;
        presumably inference_mode is not fully supported on XPU — confirm.
        """
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate and patch ``vllm_config`` in place for XPU execution.

        Sets the default block size, disables CUDA-graph/compilation
        features unsupported on XPU, picks the worker class and a
        compatible distributed executor backend, adjusts scheduling for
        MLA models, and forces the NHD kv-cache layout.
        """
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        # in V1(or with ipex chunked prefill) block_size is 64
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 64

        # lazy import to avoid circular import
        from vllm.config import CompilationLevel, CUDAGraphMode
        compilation_config = vllm_config.compilation_config
        if compilation_config.compile_sizes is None:
            compilation_config.compile_sizes = []

        assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, \
            "CUDA graph mode should be NONE on XPU"

        # LoRA is incompatible with compilation here, so turn it off.
        if vllm_config.lora_config is not None:
            compilation_config.level = CompilationLevel.NO_COMPILATION

        # check and update parallel config
        parallel_config = vllm_config.parallel_config
        parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"

        if parallel_config.distributed_executor_backend is None:
            # Default: ray for multi-process, uniprocess otherwise.
            if parallel_config.world_size > 1:
                parallel_config.distributed_executor_backend = "ray"
            else:
                parallel_config.distributed_executor_backend = "uni"
        elif parallel_config.distributed_executor_backend == "mp":
            # FIXME(kunshang):
            # spawn needs calling `if __name__ == '__main__':``
            # fork is not supported for xpu start new process.
            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
                logger.warning(
                    "Please use spawn as start method if you want to use mp.")
        elif (parallel_config.distributed_executor_backend != "ray"
              and parallel_config.distributed_executor_backend != "uni"
              and parallel_config.distributed_executor_backend
              != "external_launcher"):
            # Any other backend is unsupported: fall back to ray.
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.",
                parallel_config.distributed_executor_backend)
            parallel_config.distributed_executor_backend = "ray"

        if model_config and model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            # Without chunked prefill, allow a whole-model-length batch.
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout("NHD")
        logger.info("Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
                    "only NHD layout is supported by XPU attention kernels.")

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid kv-cache is supported on XPU."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Static graph (CUDA-graph-style) capture is not supported."""
        return False

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Pinned host memory is available on XPU."""
        return True

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the currently allocated device memory, in bytes.

        Resetting peak stats first makes the subsequent
        ``max_memory_allocated`` reflect the present allocation.
        """
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """The fp8 format used on XPU (e5m2)."""
        return torch.float8_e5m2

    @classmethod
    def is_data_center_gpu(cls) -> bool:
        """True if device 0 is an Intel Data Center GPU (by name match)."""
        device_name = cls.get_device_name().lower()
        return device_name.count("data center gpu") > 0

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Dotted path of the collective-communication helper for XPU."""
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa

    @classmethod
    def device_count(cls) -> int:
        """Number of visible XPU devices."""
        return torch.xpu.device_count()

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype) -> None:
        """Raise ``ValueError`` for dtype/device combos with known issues."""
        if torch_dtype == torch.bfloat16:  # noqa: SIM102
            device_name = cls.get_device_name().lower()
            # client gpu a770
            if device_name.count("a770") > 0:
                raise ValueError(
                    "Intel Arc A770 have bfloat16 accuracy known issue. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half.")

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Attention is treated as an opaque op on XPU."""
        return True

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU.

        Blocks are indexed along dim 1 here (unlike the TPU variant,
        which indexes dim 0).
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU).

        Blocks are indexed along dim 1; the device-to-host transfer is
        performed by ``.cpu()`` on the gathered blocks.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()
|
||||
Reference in New Issue
Block a user