Sync from v0.13
This commit is contained in:
277
vllm/platforms/__init__.py
Normal file
277
vllm/platforms/__init__.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
import traceback
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm import envs
|
||||
from vllm.plugins import PLATFORM_PLUGINS_GROUP, load_plugins_by_group
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import supports_xccl
|
||||
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
|
||||
# Module logger for platform-detection diagnostics.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def vllm_version_matches_substr(substr: str) -> bool:
    """
    Check to see if the vLLM version matches a substring.

    Used to distinguish build variants (e.g. "cpu" builds) by inspecting the
    installed package's version string.

    Raises:
        PackageNotFoundError: if the vLLM package metadata cannot be found.
    """
    from importlib.metadata import PackageNotFoundError, version

    try:
        vllm_version = version("vllm")
    except PackageNotFoundError:
        logger.warning(
            "The vLLM package was not found, so its version could not be "
            "inspected. This may cause platform detection to fail."
        )
        # Bare `raise` preserves the original exception and traceback
        # without appending an extra frame (unlike `raise e`).
        raise
    return substr in vllm_version
|
||||
|
||||
|
||||
def tpu_platform_plugin() -> str | None:
    """Detect TPU availability and return the platform class qualname."""
    logger.debug("Checking if TPU platform is available.")

    # A Pathways proxy implies TPU access regardless of local libtpu.
    if envs.VLLM_TPU_USING_PATHWAYS:
        logger.debug("Confirmed TPU platform is available via Pathways proxy.")
        return "tpu_inference.platforms.tpu_platform.TpuPlatform"

    try:
        # While it's technically possible to install libtpu on a
        # non-TPU machine, this is a very uncommon scenario. Therefore,
        # we assume that libtpu is installed only if the machine
        # has TPUs.
        import libtpu  # noqa: F401
    except Exception as e:
        logger.debug("TPU platform is not available because: %s", str(e))
        return None

    logger.debug("Confirmed TPU platform is available.")
    return "vllm.platforms.tpu.TpuPlatform"
|
||||
|
||||
|
||||
def cuda_platform_plugin() -> str | None:
    """Detect NVIDIA CUDA availability and return the platform class qualname.

    Returns "vllm.platforms.cuda.CudaPlatform" when a CUDA device is usable
    (via NVML, or on Jetson where NVML may be absent), otherwise None.
    """
    is_cuda = False
    logger.debug("Checking if CUDA platform is available.")
    try:
        from vllm.utils.import_utils import import_pynvml

        pynvml = import_pynvml()
        pynvml.nvmlInit()
        try:
            # NOTE: Edge case: vllm cpu build on a GPU machine.
            # Third-party pynvml can be imported in cpu build,
            # we need to check if vllm is built with cpu too.
            # Otherwise, vllm will always activate cuda plugin
            # on a GPU machine, even if in a cpu build.
            # Query NVML and the package metadata once each (the original
            # code repeated both calls for the log messages below).
            device_count = pynvml.nvmlDeviceGetCount()
            is_cpu_build = vllm_version_matches_substr("cpu")
            is_cuda = device_count > 0 and not is_cpu_build
            if device_count <= 0:
                logger.debug("CUDA platform is not available because no GPU is found.")
            if is_cpu_build:
                logger.debug(
                    "CUDA platform is not available because vLLM is built with CPU."
                )
            if is_cuda:
                logger.debug("Confirmed CUDA platform is available.")
        finally:
            pynvml.nvmlShutdown()
    except Exception as e:
        logger.debug("Exception happens when checking CUDA platform: %s", str(e))
        if "nvml" not in e.__class__.__name__.lower():
            # If the error is not related to NVML, re-raise it.
            raise

        # CUDA is supported on Jetson, but NVML may not be.
        import os

        def cuda_is_jetson() -> bool:
            # Jetson boards expose Tegra markers in the filesystem.
            return os.path.isfile("/etc/nv_tegra_release") or os.path.exists(
                "/sys/class/tegra-firmware"
            )

        if cuda_is_jetson():
            logger.debug("Confirmed CUDA platform is available on Jetson.")
            is_cuda = True
        else:
            logger.debug("CUDA platform is not available because: %s", str(e))

    return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None
|
||||
|
||||
|
||||
def rocm_platform_plugin() -> str | None:
    """Detect AMD ROCm availability and return the platform class qualname."""
    logger.debug("Checking if ROCm platform is available.")
    found_gpu = False
    try:
        import amdsmi

        amdsmi.amdsmi_init()
        try:
            handles = amdsmi.amdsmi_get_processor_handles()
            if len(handles) > 0:
                found_gpu = True
                logger.debug("Confirmed ROCm platform is available.")
            else:
                logger.debug("ROCm platform is not available because no GPU is found.")
        finally:
            # Always release the SMI handle, even on failure.
            amdsmi.amdsmi_shut_down()
    except Exception as e:
        logger.debug("ROCm platform is not available because: %s", str(e))

    return "vllm.platforms.rocm.RocmPlatform" if found_gpu else None
|
||||
|
||||
|
||||
def xpu_platform_plugin() -> str | None:
    """Detect Intel XPU availability and return the platform class qualname."""
    xpu_available = False
    logger.debug("Checking if XPU platform is available.")
    try:
        # installed IPEX if the machine has XPUs.
        import intel_extension_for_pytorch  # noqa: F401
        import torch

        if supports_xccl():
            dist_backend = "xccl"
        else:
            dist_backend = "ccl"
            import oneccl_bindings_for_pytorch  # noqa: F401

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            xpu_available = True
            from vllm.platforms.xpu import XPUPlatform

            # Record which collective backend the platform should use.
            XPUPlatform.dist_backend = dist_backend
            logger.debug("Confirmed %s backend is available.", XPUPlatform.dist_backend)
            logger.debug("Confirmed XPU platform is available.")
    except Exception as e:
        logger.debug("XPU platform is not available because: %s", str(e))

    return "vllm.platforms.xpu.XPUPlatform" if xpu_available else None
|
||||
|
||||
|
||||
def cpu_platform_plugin() -> str | None:
    """Detect whether the CPU platform should be used."""
    detected = False
    logger.debug("Checking if CPU platform is available.")
    try:
        # A "cpu" build of vLLM always targets the CPU platform.
        detected = vllm_version_matches_substr("cpu")
        if detected:
            logger.debug(
                "Confirmed CPU platform is available because vLLM is built with CPU."
            )
        else:
            import sys

            # macOS has no GPU build; fall back to CPU there.
            detected = sys.platform.startswith("darwin")
            if detected:
                logger.debug(
                    "Confirmed CPU platform is available because the machine is MacOS."
                )

    except Exception as e:
        logger.debug("CPU platform is not available because: %s", str(e))

    return "vllm.platforms.cpu.CpuPlatform" if detected else None
|
||||
|
||||
|
||||
# Registry of built-in platform detectors. Each callable returns the
# qualified name of the Platform subclass when its hardware/build is
# detected, or None otherwise.
builtin_platform_plugins = {
    "tpu": tpu_platform_plugin,
    "cuda": cuda_platform_plugin,
    "rocm": rocm_platform_plugin,
    "xpu": xpu_platform_plugin,
    "cpu": cpu_platform_plugin,
}
|
||||
|
||||
|
||||
def resolve_current_platform_cls_qualname() -> str:
    """Resolve the qualified class name of the active Platform.

    Runs every built-in and out-of-tree (OOT) platform detector. OOT plugins
    take precedence over built-in ones; at most one plugin may activate in
    each category.

    Returns:
        The qualified name of the detected Platform subclass, or
        "vllm.platforms.interface.UnspecifiedPlatform" when nothing activates.

    Raises:
        RuntimeError: if more than one plugin of the same category activates.
    """
    platform_plugins = load_plugins_by_group(PLATFORM_PLUGINS_GROUP)

    # name -> resolved qualname for every detector that activated.
    # Caching the result avoids invoking each detector a second time below
    # (detectors may probe hardware and emit logs on every call).
    activated_qualnames: dict[str, str] = {}

    for name, func in chain(builtin_platform_plugins.items(), platform_plugins.items()):
        try:
            assert callable(func)
            qualname = func()
            if qualname is not None:
                activated_qualnames[name] = qualname
        except Exception:
            pass

    activated_plugins = list(activated_qualnames.keys())
    activated_builtin_plugins = list(
        set(activated_plugins) & set(builtin_platform_plugins.keys())
    )
    activated_oot_plugins = list(set(activated_plugins) & set(platform_plugins.keys()))

    if len(activated_oot_plugins) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{activated_oot_plugins}"
        )
    elif len(activated_oot_plugins) == 1:
        platform_cls_qualname = activated_qualnames[activated_oot_plugins[0]]
        logger.info("Platform plugin %s is activated", activated_oot_plugins[0])
    elif len(activated_builtin_plugins) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{activated_builtin_plugins}"
        )
    elif len(activated_builtin_plugins) == 1:
        platform_cls_qualname = activated_qualnames[activated_builtin_plugins[0]]
        logger.debug(
            "Automatically detected platform %s.", activated_builtin_plugins[0]
        )
    else:
        platform_cls_qualname = "vllm.platforms.interface.UnspecifiedPlatform"
        logger.debug("No platform detected, vLLM is running on UnspecifiedPlatform")
    return platform_cls_qualname
|
||||
|
||||
|
||||
# Lazily-initialized Platform instance; populated on first access of
# `current_platform` by the module-level __getattr__ below.
_current_platform = None
# Stack trace captured when `current_platform` was first resolved; useful
# for debugging platform-initialization ordering issues.
_init_trace: str = ""

if TYPE_CHECKING:
    # Static type of the lazily-resolved module attribute.
    current_platform: Platform
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Module attribute hook implementing lazy resolution of `current_platform`."""
    if name == "current_platform":
        # lazy init current_platform.
        # 1. out-of-tree platform plugins need `from vllm.platforms import
        #    Platform` so that they can inherit `Platform` class. Therefore,
        #    we cannot resolve `current_platform` during the import of
        #    `vllm.platforms`.
        # 2. when users use out-of-tree platform plugins, they might run
        #    `import vllm`, some vllm internal code might access
        #    `current_platform` during the import, and we need to make sure
        #    `current_platform` is only resolved after the plugins are loaded
        #    (we have tests for this, if any developer violate this, they will
        #    see the test failures).
        global _current_platform
        if _current_platform is None:
            platform_cls_qualname = resolve_current_platform_cls_qualname()
            _current_platform = resolve_obj_by_qualname(platform_cls_qualname)()
            # Remember where the first resolution happened, for debugging.
            global _init_trace
            _init_trace = "".join(traceback.format_stack())
        return _current_platform
    elif name in globals():
        return globals()[name]
    else:
        raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
||||
|
||||
|
||||
def __setattr__(name: str, value):
    """Module attribute hook so tests/plugins can override `current_platform`."""
    if name == "current_platform":
        # Keep the lazy-init backing global in sync with external assignment.
        global _current_platform
        _current_platform = value
    elif name in globals():
        globals()[name] = value
    else:
        raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
||||
|
||||
|
||||
__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"]
|
||||
421
vllm/platforms/cpu.py
Normal file
421
vllm/platforms/cpu.py
Normal file
@@ -0,0 +1,421 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import psutil
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
|
||||
logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.attention.selector import AttentionSelectorConfig
    from vllm.config import VllmConfig
else:
    # Runtime placeholder: keeps the `VllmConfig` name resolvable in
    # annotations without importing vllm.config at module import time.
    VllmConfig = None
|
||||
|
||||
|
||||
def get_max_threads(pid=0):
    """Return the number of CPUs usable by the given process (default: self)."""
    # Linux exposes the scheduler affinity mask; macOS has no affinity API,
    # so fall back to the raw CPU count there.
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(pid))
    if platform.system() == "Darwin":
        return os.cpu_count()
    raise NotImplementedError("Unsupported OS")
|
||||
|
||||
|
||||
@dataclass
class LogicalCPUInfo:
    """Topology info for one logical CPU as reported by `lscpu -J`."""

    # Logical CPU id (-1 when unknown/unparsable)
    id: int = -1
    # Physical core the logical CPU belongs to
    physical_core: int = -1
    # NUMA node the logical CPU belongs to
    numa_node: int = -1

    @staticmethod
    def _int(value: str) -> int:
        """Parse `value` as int, returning -1 for anything unparsable."""
        # staticmethod: the original classmethod never used `cls`.
        try:
            return int(value)
        except Exception:
            return -1

    @staticmethod
    def json_decoder(obj_dict: dict):
        """`json.loads` object hook converting lscpu entries to LogicalCPUInfo.

        Dicts lacking any of the "cpu"/"core"/"node" keys are returned
        unchanged so that non-CPU JSON objects pass through untouched.
        """
        # Avoid shadowing the `id` builtin with a local name.
        cpu_id = obj_dict.get("cpu")
        core_id = obj_dict.get("core")
        node_id = obj_dict.get("node")

        if cpu_id is None or core_id is None or node_id is None:
            return obj_dict
        return LogicalCPUInfo(
            id=LogicalCPUInfo._int(cpu_id),
            physical_core=LogicalCPUInfo._int(core_id),
            numa_node=LogicalCPUInfo._int(node_id),
        )
|
||||
|
||||
|
||||
class CpuPlatform(Platform):
    """vLLM Platform implementation for running on CPUs (x86/ARM/PPC/RISC-V)."""

    _enum = PlatformEnum.CPU
    # Device identity strings used by the rest of vLLM.
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"
    # Default torch.distributed backend on CPU.
    dist_backend: str = "gloo"
    # Env var restricting which NUMA memory nodes are visible.
    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Model dtypes supported on the detected CPU architecture."""
        if self.get_cpu_architecture() == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]
        elif self.get_cpu_architecture() == CpuArchEnum.ARM and sys.platform.startswith(
            "darwin"
        ):
            # Apple Silicon: bf16 only when the CPU advertises FEAT_BF16.
            if (
                subprocess.check_output(
                    ["sysctl -n hw.optional.arm.FEAT_BF16"], shell=True
                ).strip()
                == b"1"
            ):
                return [torch.bfloat16, torch.float16, torch.float32]
            return [torch.float16, torch.float32]
        elif self.get_cpu_architecture() == CpuArchEnum.RISCV:
            # Workaround for Issue #25655: RISC-V scheduler bug with float16
            #
            # Background:
            # - RISC-V currently uses scalar code path
            # - There is a latent bug in the vLLM scheduler that provides
            #   invalid physical_block_idx values under certain conditions
            # - This bug causes segmentation faults when using float16
            #   dtype on RISC-V
            # - Testing shows that forcing float32 successfully bypasses
            #   this issue
            #
            # Technical details:
            # - The bug manifests as out-of-bounds physical_block_idx in
            #   block_tables
            # - Only occurs on RISC-V hardware (tested on Sophgo SG2044)
            # - Does not reproduce on x86 or other architectures
            # - Root cause is in Python-level scheduling logic,
            #   not C++ kernels
            #
            # This is a temporary workaround until the scheduler bug is fixed.
            # See: https://github.com/vllm-project/vllm/issues/25655
            return [torch.float32]
        # x86/aarch64 CPU has supported both bf16 and fp16 natively.
        return [torch.bfloat16, torch.float16, torch.float32]

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the fixed device name; CPU devices are not enumerated."""
        return "cpu"

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the attention backend path; CPU only supports CPU_ATTN.

        Raises:
            NotImplementedError: when MLA or sparse attention is requested.
        """
        if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if attn_selector_config.use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if attn_selector_config.use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on CPU.")
        return AttentionBackendEnum.CPU_ATTN.get_path()

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the KV-cache budget in bytes.

        Uses VLLM_CPU_KVCACHE_SPACE (GiB) when set; otherwise defaults to
        half of the per-NUMA-node system memory.
        """
        from vllm.utils.mem_constants import GiB_bytes

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
        node_dir = "/sys/devices/system/node"
        if kv_cache_space is None:
            # Count NUMA nodes via sysfs; assume 1 when unavailable.
            nodes = (
                [d for d in os.listdir(node_dir) if d.startswith("node")]
                if os.path.exists(node_dir)
                else []
            )
            num_numa_nodes = len(nodes) or 1
            free_cpu_memory = psutil.virtual_memory().total // num_numa_nodes
            DEFAULT_CPU_MEM_UTILIZATION = 0.5
            kv_cache_space = int(free_cpu_memory * DEFAULT_CPU_MEM_UTILIZATION)
            kv_cache_space_gib = kv_cache_space / GiB_bytes
            logger.warning_once(
                "VLLM_CPU_KVCACHE_SPACE not set. Using "
                f"{kv_cache_space_gib:.2f} GiB for KV cache."
            )
        else:
            # Env var is specified in GiB; convert to bytes.
            kv_cache_space *= GiB_bytes

        return kv_cache_space

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cpu.set_device(device)

    @classmethod
    def inference_mode(cls):
        """Return the no-grad context used for inference on CPU."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate and mutate `vllm_config` (and process env vars) for CPU."""
        model_config = vllm_config.model_config

        if model_config is not None:
            model_config.disable_cascade_attn = True

        cache_config = vllm_config.cache_config

        if cache_config.block_size is None:
            cache_config.block_size = 128

        if cache_config.block_size % 32 != 0:
            logger.warning(
                "CPU backend prefers block_size is multiples of 32, "
                "otherwise the performance is not optimized."
            )

        scheduler_config = vllm_config.scheduler_config
        if (
            scheduler_config.enable_chunked_prefill
            or cache_config.enable_prefix_caching
        ) and cache_config.cache_dtype != "auto":
            raise RuntimeError(
                "Chunked-prefill and prefix-cache on the CPU "
                "backend is not compatible with FP8 KV cache."
            )

        if cache_config.cache_dtype != "auto":
            logger.warning(
                "CPU backend doesn't support KV cache quantization fallback to auto."
            )
            cache_config.cache_dtype = "auto"

        cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()

        parallel_config = vllm_config.parallel_config
        if (
            parallel_config.world_size > 1
            and parallel_config.distributed_executor_backend is not None
            and parallel_config.distributed_executor_backend != "mp"
        ):
            logger.warning(
                (
                    "%s is not supported on CPU, fallback to mp "
                    "distributed executor backend."
                ),
                parallel_config.distributed_executor_backend,
            )
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
        # Disable DBO
        if parallel_config.enable_dbo:
            logger.warning("Dual-Batch Overlap is not supported on CPU, disabled.")
            parallel_config.enable_dbo = False

        # Note: workaround for v1 gpu_model_runner
        from vllm.config import CompilationMode

        # No CUDA graphs on CPU.
        vllm_config.compilation_config.cudagraph_capture_sizes = []

        compilation_config = vllm_config.compilation_config
        if vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
                backend = "eager"
            else:
                backend = "inductor"

            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
            compilation_config.backend = backend
            compilation_config.inductor_compile_config.update(
                {
                    "dce": True,
                    "size_asserts": False,
                    "nan_asserts": False,
                    "epilogue_fusion": True,
                }
            )

        if vllm_config.lora_config is not None:
            # LoRA on CPU runs uncompiled.
            compilation_config.mode = CompilationMode.NONE

        assert vllm_config.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        # variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        if envs.VLLM_CPU_OMP_THREADS_BIND != "nobind":
            # Set default threads num for OpenMP parallel
            os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
        else:
            # In this case, setting the OpenMP configuration via
            # OMP_NUM_THREADS is up to the user.
            logger.info("Disabling binding processes to CPU cores...")

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Disable multi-stream for shared experts as no Stream on CPU
        os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"

        # Intel OpenMP setting
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ["KMP_BLOCKTIME"] = "1"
            # Prevents the CPU to run into low performance state
            os.environ["KMP_TPAUSE"] = "0"
            # Provides fine granularity parallelism
            os.environ["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist"

        if (
            platform.system() == "Linux"
            and Platform.get_cpu_architecture()
            in (CpuArchEnum.ARM, CpuArchEnum.POWERPC)
            and not ("libomp" in ld_preload_str or "libgomp" in ld_preload_str)
        ):
            # We need to LD_PRELOAD PyTorch's libgomp, otherwise only
            # one core will be properly utilized when we thread-bind
            # See: https://github.com/vllm-project/vllm/issues/27369
            # TODO: Remove once:
            # https://github.com/pytorch/pytorch/issues/166087 is fixed

            # We need to find the location of PyTorch's libgomp
            torch_pkg = os.path.dirname(torch.__file__)
            site_root = os.path.dirname(torch_pkg)
            # Search both torch.libs and torch/lib - See:
            # https://github.com/vllm-project/vllm/issues/30470
            torch_libs_paths = [
                os.path.join(site_root, "torch.libs"),
                os.path.join(torch_pkg, "lib"),
            ]
            pytorch_libgomp_so_candidates = []
            for torch_libs in torch_libs_paths:
                pytorch_libgomp_so_candidates.extend(
                    glob.glob(os.path.join(torch_libs, "libgomp*.so*"))
                )
            if pytorch_libgomp_so_candidates:
                pytorch_libgomp_so = pytorch_libgomp_so_candidates[0]
                if ld_preload_str:
                    ld_preload_str += ":"
                ld_preload_str += pytorch_libgomp_so
                os.environ["LD_PRELOAD"] = ld_preload_str

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            vllm_config.parallel_config.tensor_parallel_size
        )

        if model_config is not None and model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.model_config.max_model_len,
                vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]:
        """Return (allowed NUMA node ids, allowed logical CPUs) via lscpu."""
        assert platform.system() == "Linux"

        # Init LogicalCPUInfo from lscpu
        lscpu_output = subprocess.check_output(
            "lscpu -J -e=CPU,CORE,NODE", shell=True, text=True
        )
        # lscpu emits "-" for the node on single-node systems; map it to 0.
        lscpu_output = re.sub(r'"node":\s*-\s*(,|\n)', r'"node": 0\1', lscpu_output)
        logical_cpu_list: list[LogicalCPUInfo] = json.loads(
            lscpu_output, object_hook=LogicalCPUInfo.json_decoder
        )["cpus"]

        # Filter CPUs with invalid attributes
        logical_cpu_list = [
            x
            for x in logical_cpu_list
            if -1 not in (x.id, x.physical_core, x.numa_node)
        ]

        # Filter allowed CPUs
        if hasattr(os, "sched_getaffinity"):
            allowed_cpu_id_list = os.sched_getaffinity(0)
        else:
            raise NotImplementedError("Unsupported OS")
        logical_cpu_list = [x for x in logical_cpu_list if x.id in allowed_cpu_id_list]

        # Get allowed NUMA nodes
        allowed_numa_nodes = set()
        for x in logical_cpu_list:
            allowed_numa_nodes.add(x.numa_node)  # type: ignore
        allowed_numa_nodes_list = sorted(allowed_numa_nodes)

        env_key = CpuPlatform.device_control_env_var
        if env_key in os.environ and os.environ[env_key] != "":
            visible_nodes = [int(s) for s in os.environ[env_key].split(",")]
            # NOTE(review): this filters NUMA-node ids against *CPU* ids
            # (allowed_cpu_id_list); verify whether the intended set is
            # the allowed NUMA nodes instead.
            allowed_numa_nodes_list = [
                x for x in visible_nodes if x in allowed_cpu_id_list
            ]

        return allowed_numa_nodes_list, logical_cpu_list

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Pinned memory is meaningless on CPU-only execution."""
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the qualname of the CPU LoRA Punica wrapper."""
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa

    @classmethod
    def supports_structured_output(cls) -> bool:
        """Structured (guided) output is supported on CPU."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Attention is treated as an opaque op on this platform."""
        return True

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache layouts are supported on CPU."""
        return True
|
||||
618
vllm/platforms/cuda.py
Normal file
618
vllm/platforms/cuda.py
Normal file
@@ -0,0 +1,618 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Code inside this file can safely assume cuda platform, e.g. importing
|
||||
pynvml. However, it should not initialize cuda context.
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import cache, wraps
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
# import custom ops, trigger op registration
|
||||
import vllm._C # noqa
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import import_pynvml
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.selector import AttentionSelectorConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
else:
|
||||
VllmConfig = None
|
||||
CacheDType = None
|
||||
|
||||
logger = init_logger(__name__)

# ParamSpec/TypeVar pair used to type the with_nvml_context decorator below.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# Module-level NVML binding, imported once and shared by the helpers here.
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)
|
||||
|
||||
|
||||
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
    """Return attention backends in preference order for this configuration."""
    # Compute-capability major 10 gets its own preference order (the file's
    # other comments associate this family with Blackwell GPUs).
    major_is_10 = device_capability.major == 10
    if use_mla and major_is_10:
        return [
            AttentionBackendEnum.CUTLASS_MLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]
    if use_mla:
        return [
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]
    if major_is_10:
        return [
            AttentionBackendEnum.FLASHINFER,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLEX_ATTENTION,
        ]
    return [
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.FLASHINFER,
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.FLEX_ATTENTION,
    ]
|
||||
|
||||
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator that runs `fn` inside an nvmlInit()/nvmlShutdown() pair."""

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            # Always release the NVML handle, even when fn raises.
            pynvml.nvmlShutdown()

    return wrapper
|
||||
|
||||
|
||||
class CudaPlatformBase(Platform):
    """Shared base for the NVML- and non-NVML-backed CUDA platforms."""

    _enum = PlatformEnum.CUDA
    # Device identity strings used by the rest of vLLM.
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    # Resource key Ray uses to schedule on GPUs.
    ray_device_key: str = "GPU"
    # Default torch.distributed backend on CUDA.
    dist_backend: str = "nccl"
    # Env var restricting which devices are visible to the process.
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
|
||||
|
||||
@property
|
||||
def supported_dtypes(self) -> list[torch.dtype]:
|
||||
if self.has_device_capability(80):
|
||||
# Ampere and Hopper or later NVIDIA GPUs.
|
||||
return [torch.bfloat16, torch.float16, torch.float32]
|
||||
if self.has_device_capability(60):
|
||||
# Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
|
||||
return [torch.float16, torch.float32]
|
||||
# Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
|
||||
# though vLLM doesn't support these GPUs.
|
||||
return [torch.float32]
|
||||
|
||||
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)
|
||||
|
||||
    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Abstract: concrete CUDA platforms query capability their own way."""
        raise NotImplementedError
|
||||
|
||||
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Abstract: concrete CUDA platforms resolve the device name."""
        raise NotImplementedError
|
||||
|
||||
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Abstract: concrete CUDA platforms report total device memory."""
        raise NotImplementedError
|
||||
|
||||
    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Abstract: concrete CUDA platforms check inter-GPU connectivity."""
        raise NotImplementedError
|
||||
|
||||
    @classmethod
    def log_warnings(cls):
        """Hook for subclasses to emit platform-specific warnings; no-op here."""
        pass
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 16
|
||||
|
||||
# TODO(lucas): handle this more gracefully
|
||||
# Note: model_config may be None during testing
|
||||
# Note: block_size is initialized in
|
||||
# HybridAttentionMambaModelConfig.verify_and_update_config
|
||||
# for models with both attention and mamba,
|
||||
# and doesn't need to be reinitialized here
|
||||
if (
|
||||
model_config is not None
|
||||
and model_config.use_mla
|
||||
and cache_config.block_size is not None
|
||||
):
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
|
||||
# If `--attention-config.backend` is not set and we are using MLA,
|
||||
# then we default to FlashMLA backend for non-blackwell GPUs,
|
||||
# else we default to CutlassMLA. For each case, we force the
|
||||
# required block_size.
|
||||
use_flashmla = False
|
||||
use_cutlass_mla = False
|
||||
use_flashinfer_mla = False
|
||||
|
||||
if vllm_config.attention_config.backend is None:
|
||||
# Default case
|
||||
if cls.is_device_capability_family(100) and not use_sparse:
|
||||
# Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2).
|
||||
use_cutlass_mla = True
|
||||
# Set the backend in AttentionConfig so it's used during
|
||||
# backend selection
|
||||
vllm_config.attention_config.backend = (
|
||||
AttentionBackendEnum.CUTLASS_MLA
|
||||
)
|
||||
else:
|
||||
# Not Blackwell
|
||||
use_flashmla = True
|
||||
else:
|
||||
# Forced case
|
||||
backend = vllm_config.attention_config.backend
|
||||
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
|
||||
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
|
||||
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
|
||||
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
|
||||
if (
|
||||
use_flashmla
|
||||
and is_flashmla_dense_supported()[0]
|
||||
and cache_config.block_size % 64 != 0
|
||||
):
|
||||
cache_config.block_size = 64
|
||||
logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")
|
||||
|
||||
if use_cutlass_mla and cache_config.block_size % 128 != 0:
|
||||
cache_config.block_size = 128
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 128 for CUTLASS_MLA backend."
|
||||
)
|
||||
|
||||
if (
|
||||
use_flashinfer_mla
|
||||
and cache_config.block_size != 32
|
||||
and cache_config.block_size % 64 != 0
|
||||
):
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashInferMLA backend."
|
||||
)
|
||||
|
||||
# TODO(Chen): remove this hacky code
|
||||
if use_sparse and cache_config.block_size != 64:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLASparse backend."
|
||||
)
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
# Note: model_config may be None during testing
|
||||
if (
|
||||
model_config is not None
|
||||
and model_config.is_mm_prefix_lm
|
||||
and scheduler_config.is_multimodal_model
|
||||
and not scheduler_config.disable_chunked_mm_input
|
||||
):
|
||||
logger.warning(
|
||||
"Forcing --disable_chunked_mm_input for models "
|
||||
"with multimodal-bidirectional attention."
|
||||
)
|
||||
scheduler_config.disable_chunked_mm_input = True
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(
|
||||
cls, device: torch.types.Device | None = None
|
||||
) -> float:
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
return torch.cuda.max_memory_allocated(device)
|
||||
|
||||
@classmethod
|
||||
def get_valid_backends(
|
||||
cls,
|
||||
device_capability: DeviceCapability,
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
) -> tuple[
|
||||
list[tuple["AttentionBackendEnum", int]],
|
||||
dict["AttentionBackendEnum", list[str]],
|
||||
]:
|
||||
valid_backends_priorities = []
|
||||
invalid_reasons = {}
|
||||
|
||||
backend_priorities = _get_backend_priorities(
|
||||
attn_selector_config.use_mla, device_capability
|
||||
)
|
||||
for priority, backend in enumerate(backend_priorities):
|
||||
try:
|
||||
backend_class = backend.get_class()
|
||||
invalid_reasons_i = backend_class.validate_configuration(
|
||||
device_capability=device_capability,
|
||||
**attn_selector_config._asdict(),
|
||||
)
|
||||
except ImportError:
|
||||
invalid_reasons_i = ["ImportError"]
|
||||
if invalid_reasons_i:
|
||||
invalid_reasons[backend] = invalid_reasons_i
|
||||
else:
|
||||
valid_backends_priorities.append((backend, priority))
|
||||
|
||||
return valid_backends_priorities, invalid_reasons
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
) -> str:
|
||||
device_capability = cls.get_device_capability()
|
||||
assert device_capability is not None
|
||||
|
||||
attn_selector_config = attn_selector_config._replace(block_size=None)
|
||||
# First try checking just the selected backend, if there is one.
|
||||
if selected_backend is not None:
|
||||
try:
|
||||
backend_class = selected_backend.get_class()
|
||||
invalid_reasons = backend_class.validate_configuration(
|
||||
device_capability=device_capability,
|
||||
**attn_selector_config._asdict(),
|
||||
)
|
||||
except ImportError:
|
||||
invalid_reasons = ["ImportError"]
|
||||
if invalid_reasons:
|
||||
raise ValueError(
|
||||
f"Selected backend {selected_backend} is not valid for "
|
||||
f"this configuration. Reason: {invalid_reasons}"
|
||||
)
|
||||
else:
|
||||
logger.info("Using %s backend.", selected_backend)
|
||||
return selected_backend.get_path()
|
||||
|
||||
# No selected backend or the selected backend is invalid,
|
||||
# so we try finding a valid backend.
|
||||
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
|
||||
device_capability=device_capability,
|
||||
attn_selector_config=attn_selector_config,
|
||||
)
|
||||
reasons_str = (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
f"{backend.name}: [{', '.join(reasons)}]"
|
||||
for backend, reasons in invalid_reasons.items()
|
||||
)
|
||||
+ "}"
|
||||
)
|
||||
config_str = attn_selector_config.__repr__()
|
||||
logger.debug_once(
|
||||
f"Some attention backends are not valid for {cls.device_name} with "
|
||||
f"{config_str}. Reasons: {reasons_str}."
|
||||
)
|
||||
if len(valid_backends_priorities) == 0:
|
||||
raise ValueError(
|
||||
f"No valid attention backend found for {cls.device_name} "
|
||||
f"with {config_str}. Reasons: {reasons_str}."
|
||||
)
|
||||
|
||||
# We have found some valid backends. Select the one with the
|
||||
# highest priority.
|
||||
sorted_indices = sorted(
|
||||
range(len(valid_backends_priorities)),
|
||||
key=lambda i: valid_backends_priorities[i][1],
|
||||
)
|
||||
selected_index = sorted_indices[0]
|
||||
selected_backend = valid_backends_priorities[selected_index][0]
|
||||
logger.info_once(
|
||||
"Using %s attention backend out of potential backends: %s",
|
||||
selected_backend.name,
|
||||
tuple(b[0].name for b in valid_backends_priorities),
|
||||
scope="local",
|
||||
)
|
||||
|
||||
return selected_backend.get_path()
|
||||
|
||||
@classmethod
|
||||
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
|
||||
return [
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
backend: Optional["AttentionBackendEnum"] = None,
|
||||
) -> "AttentionBackendEnum":
|
||||
if backend is not None:
|
||||
assert backend in cls.get_supported_vit_attn_backends(), (
|
||||
f"Backend {backend} is not supported for vit attention. "
|
||||
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
|
||||
)
|
||||
logger.info_once(f"Using backend {backend} for vit attention")
|
||||
return backend
|
||||
|
||||
# Try FlashAttention first
|
||||
if (cc := cls.get_device_capability()) and cc.major >= 8:
|
||||
try:
|
||||
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
|
||||
if backend_class.supports_head_size(
|
||||
head_size
|
||||
) and backend_class.supports_dtype(dtype):
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
|
||||
|
||||
@classmethod
|
||||
def get_device_communicator_cls(cls) -> str:
|
||||
return (
|
||||
"vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def supports_fp8(cls) -> bool:
|
||||
return cls.has_device_capability(89)
|
||||
|
||||
@classmethod
|
||||
def use_custom_allreduce(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def opaque_attention_op(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_static_graph_wrapper_cls(cls) -> str:
|
||||
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
||||
|
||||
@classmethod
|
||||
def device_count(cls) -> int:
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||
if dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not cls.has_device_capability(80):
|
||||
capability = cls.get_device_capability()
|
||||
gpu_name = cls.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs "
|
||||
"with compute capability of at least 8.0. "
|
||||
f"Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_blocks_to_device(
|
||||
cls,
|
||||
src_cache: torch.Tensor,
|
||||
dst_cache: torch.Tensor,
|
||||
src_block_indices: torch.Tensor,
|
||||
dst_block_indices: torch.Tensor,
|
||||
) -> None:
|
||||
"""Copy blocks from src_cache to dst_cache on GPU."""
|
||||
_src_cache = src_cache[:, src_block_indices]
|
||||
dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
|
||||
|
||||
@classmethod
|
||||
def swap_out_blocks_to_host(
|
||||
cls,
|
||||
src_cache: torch.Tensor,
|
||||
dst_cache: torch.Tensor,
|
||||
src_block_indices: torch.Tensor,
|
||||
dst_block_indices: torch.Tensor,
|
||||
) -> None:
|
||||
"""Copy blocks from GPU to host (CPU)."""
|
||||
_src_cache = src_cache[:, src_block_indices]
|
||||
dst_cache[:, dst_block_indices] = _src_cache.cpu()
|
||||
|
||||
@classmethod
|
||||
def support_hybrid_kv_cache(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def support_static_graph_mode(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that queries device properties through NVML.

    Logical (CUDA_VISIBLE_DEVICES-relative) ids are translated to physical
    ids before every NVML call.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Compute capability of logical `device_id`, or None on NVML failure."""
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Like the base check, but treats NVML RuntimeError as "no"."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Marketing name of logical `device_id` via NVML."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """NVML UUID of logical `device_id`."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total memory of logical `device_id` in bytes."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Check every unordered pair once (i < j).
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle,
                            peer_handle,
                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception(
                            "NVLink detection failed. This is normal if"
                            " your machine has no NVLink equipped."
                        )
                        return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Internal helper: `device_id` here is a PHYSICAL id; callers must
        # be inside an NVML context.
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when mixed GPU models are visible without PCI_BUS_ID ordering."""
        device_ids: int = pynvml.nvmlDeviceGetCount()
        if device_ids > 1:
            device_names = [cls._get_physical_device_name(i) for i in range(device_ids)]
            if (
                len(set(device_names)) > 1
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"
            ):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )
|
||||
|
||||
|
||||
class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform fallback used when NVML is unavailable (e.g. Jetson).

    Device properties are queried through the `torch.cuda` runtime API
    instead of NVML.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Compute capability of `device_id` via torch.cuda."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Marketing name of `device_id` via torch.cuda."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total memory of `device_id` in bytes via torch.cuda."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """NVLink topology cannot be probed without NVML; assume none."""
        # Fix: `logger.exception` is only meaningful inside an `except`
        # block; called here it appended a bogus "NoneType: None" traceback
        # to the log. `logger.error` keeps the severity without that noise.
        logger.error(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False
|
||||
|
||||
|
||||
# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
def _probe_nvml() -> bool:
    """Return True iff pynvml can be initialized on this machine."""
    try:
        pynvml.nvmlInit()
    except Exception:
        # On Jetson, NVML is not supported.
        return False
    pynvml.nvmlShutdown()
    return True


nvml_available = _probe_nvml()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()
|
||||
695
vllm/platforms/interface.py
Normal file
695
vllm/platforms/interface.py
Normal file
@@ -0,0 +1,695 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import enum
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
|
||||
from vllm.attention.selector import AttentionSelectorConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
else:
|
||||
FlexibleArgumentParser = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def in_wsl() -> bool:
    """Detect whether we are running under Windows Subsystem for Linux."""
    # Reference: https://github.com/microsoft/WSL/issues/4071
    uname_text = " ".join(platform.uname()).lower()
    return "microsoft" in uname_text
|
||||
|
||||
|
||||
class PlatformEnum(enum.Enum):
    """Identifies the hardware platform vLLM is running on."""

    # NOTE: member order matters -- `auto()` values depend on it.
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    # Out-of-tree (plugin-provided) platform.
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
|
||||
|
||||
|
||||
class CpuArchEnum(enum.Enum):
    """CPU instruction-set architecture of the host system."""

    # NOTE: member order matters -- `auto()` values depend on it.
    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    S390X = enum.auto()
    RISCV = enum.auto()
    # Machine string was non-empty but not recognized.
    OTHER = enum.auto()
    # Machine string was empty / unavailable.
    UNKNOWN = enum.auto()
|
||||
|
||||
|
||||
class DeviceCapability(NamedTuple):
    """A `(major, minor)` device compute capability.

    Comparison operators only accept other `DeviceCapability` instances;
    comparing against a plain tuple still works through the reflected
    tuple operators, since this class is a tuple subclass.
    """

    major: int
    minor: int

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, DeviceCapability):
            return NotImplemented
        return (self.major, self.minor) < (other.major, other.minor)

    def __le__(self, other: Any) -> bool:
        if not isinstance(other, DeviceCapability):
            return NotImplemented
        return (self.major, self.minor) <= (other.major, other.minor)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, DeviceCapability):
            return NotImplemented
        return (self.major, self.minor) == (other.major, other.minor)

    def __ge__(self, other: Any) -> bool:
        if not isinstance(other, DeviceCapability):
            return NotImplemented
        return (self.major, self.minor) >= (other.major, other.minor)

    def __gt__(self, other: Any) -> bool:
        if not isinstance(other, DeviceCapability):
            return NotImplemented
        return (self.major, self.minor) > (other.major, other.minor)

    # Fix: defining `__eq__` implicitly sets `__hash__` to None, which made
    # DeviceCapability unhashable even though it is a tuple subclass.
    # Restore tuple hashing so instances keep working as dict/set keys.
    __hash__ = tuple.__hash__

    def as_version_str(self) -> str:
        """Return the capability as a `"<major>.<minor>"` string."""
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer `<major><minor>`.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor
|
||||
|
||||
|
||||
class Platform:
    """Abstract base describing a hardware platform backend for vLLM."""

    _enum: PlatformEnum
    device_name: str
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # platform-agnostic way to specify the device control environment variable,
    # .e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # The backend used for distributed communication.
    dist_backend: str = ""

    # Quantization methods this platform supports; an empty list means
    # no restriction (see verify_quantization).
    supported_quantization: list[str] = []

    # NOTE(review): presumably extra env vars the platform wants propagated
    # to workers -- not exercised in this file, confirm against callers.
    additional_env_vars: list[str] = []

    # NOTE(review): looks like a lazily-created shared graph pool -- not
    # used in this view, confirm semantics where it is populated.
    _global_graph_pool: Any | None = None

    @property
    def pass_key(self) -> str:
        """Inductor config key for the PassManager custom pass"""
        return "post_grad_custom_post_pass"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Returns the supported dtypes for the current platform."""
        # Be careful with the order of the dtypes. The first dtype will
        # be used as the default dtype fallback for the current platform,
        # when encountering unsupported dtypes in "auto" dtype.
        return [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
    def is_cuda(self) -> bool:
        """True iff this platform is NVIDIA CUDA."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """True iff this platform is AMD ROCm."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """True iff this platform is a TPU."""
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        """True iff this platform is an Intel XPU."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """True iff this platform is plain CPU."""
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        """True iff this platform is provided by an out-of-tree plugin."""
        return self._enum == PlatformEnum.OOT

    def is_unspecified(self) -> bool:
        """True iff no platform has been detected."""
        return self._enum == PlatformEnum.UNSPECIFIED

    def get_max_output_tokens(self, prompt_len: int) -> int:
        """Upper bound on output tokens; effectively unlimited by default."""
        return sys.maxsize

    def is_cuda_alike(self) -> bool:
        """Stateless version of [torch.cuda.is_available][]."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_sleep_mode_available(self) -> bool:
        """True iff the platform supports vLLM sleep mode."""
        # TODO: Actually only mi3xx has the sleep mode support now
        # for ROCm, but currently we don't have a way to detect the
        # exact GPU model statelessly here. So we return True for
        # all ROCm platforms for now.
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
|
||||
|
||||
    @classmethod
    def get_pass_manager_cls(cls) -> str:
        """
        Get the pass manager class for this platform.
        It will be registered as a custom pass under the current_platform.pass_key.
        """
        return "vllm.compilation.pass_manager.PostGradPassManager"

    @classmethod
    def get_compile_backend(cls) -> str:
        """
        Get the custom compile backend for current platform.
        """
        # Defaults to the class attribute so subclasses can override either.
        return cls.simple_compile_backend
|
||||
|
||||
@classmethod
|
||||
def device_id_to_physical_device_id(cls, device_id: int):
|
||||
# Treat empty device control env var as unset. This is a valid
|
||||
# configuration in Ray setups where the engine is launched in
|
||||
# a CPU-only placement group located on a GPU node.
|
||||
if (
|
||||
cls.device_control_env_var in os.environ
|
||||
and os.environ[cls.device_control_env_var] != ""
|
||||
):
|
||||
device_ids = os.environ[cls.device_control_env_var].split(",")
|
||||
physical_device_id = device_ids[device_id]
|
||||
return int(physical_device_id)
|
||||
else:
|
||||
return device_id
|
||||
|
||||
    @classmethod
    def import_kernels(cls) -> None:
        """Import any platform-specific C kernels."""
        # Core kernels: warn (but keep going) if the extension is missing.
        try:
            import vllm._C  # noqa: F401
        except ImportError as e:
            logger.warning("Failed to import from vllm._C: %r", e)
        # MoE kernels are optional; silently skip when unavailable.
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401
|
||||
|
||||
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Get the attention backend class of a device."""
        # Base platform has no attention backend; subclasses override.
        return ""

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Backends usable for ViT attention; SDPA is the portable default."""
        return [
            AttentionBackendEnum.TORCH_SDPA,
        ]
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
backend: Optional["AttentionBackendEnum"] = None,
|
||||
) -> "AttentionBackendEnum":
|
||||
"""
|
||||
Get the vision attention backend class of a device.
|
||||
|
||||
NOTE: ViT Attention should be checked and override in the platform-specific
|
||||
implementation. we should not override this in any other places, like
|
||||
the model_executor/models/<model_name>.py.
|
||||
|
||||
We check if the backend is None or not:
|
||||
1. If not, check if the backend is supported by the platform.
|
||||
2. If None, continue to the default selection logic.
|
||||
"""
|
||||
if backend is not None:
|
||||
assert backend in cls.get_supported_vit_attn_backends(), (
|
||||
f"Backend {backend} is not supported for vit attention"
|
||||
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
|
||||
)
|
||||
logger.info_once(f"Using backend {backend} for vit attention")
|
||||
return backend
|
||||
|
||||
logger.info_once(
|
||||
f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention"
|
||||
)
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Stateless version of [torch.cuda.get_device_capability][]."""
        # Base platform cannot introspect devices; subclasses override.
        return None
|
||||
|
||||
@classmethod
|
||||
def has_device_capability(
|
||||
cls,
|
||||
capability: tuple[int, int] | int,
|
||||
device_id: int = 0,
|
||||
) -> bool:
|
||||
"""
|
||||
Test whether this platform is compatible with a device capability.
|
||||
|
||||
The `capability` argument can either be:
|
||||
|
||||
- A tuple `(major, minor)`.
|
||||
- An integer `<major><minor>`. (See
|
||||
[`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
|
||||
"""
|
||||
current_capability = cls.get_device_capability(device_id=device_id)
|
||||
if current_capability is None:
|
||||
return False
|
||||
|
||||
if isinstance(capability, tuple):
|
||||
return current_capability >= capability
|
||||
|
||||
return current_capability.to_int() >= capability
|
||||
|
||||
@classmethod
|
||||
def is_device_capability(
|
||||
cls,
|
||||
capability: tuple[int, int] | int,
|
||||
device_id: int = 0,
|
||||
) -> bool:
|
||||
"""
|
||||
Test whether this platform has exactly the specified device capability.
|
||||
|
||||
The `capability` argument can either be:
|
||||
|
||||
- A tuple `(major, minor)`.
|
||||
- An integer `<major><minor>`. (See
|
||||
[`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
|
||||
"""
|
||||
current_capability = cls.get_device_capability(device_id=device_id)
|
||||
if current_capability is None:
|
||||
return False
|
||||
|
||||
if isinstance(capability, tuple):
|
||||
return current_capability == capability
|
||||
|
||||
return current_capability.to_int() == capability
|
||||
|
||||
    @classmethod
    def is_device_capability_family(
        cls,
        capability: int,
        device_id: int = 0,
    ) -> bool:
        """
        Returns True if the device capability is any <major>.x.
        Mirrors CUDA 13 'family' architecture semantics (e.g. 10.x, 11.x, 12.x).
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False
        # Integer-divide by 10 to drop the minor digit, keeping the major.
        return (current_capability.to_int() // 10) == (capability // 10)
|
||||
|
||||
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError
|
||||
|
||||
    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)
|
||||
|
||||
@classmethod
|
||||
def seed_everything(cls, seed: int | None = None) -> None:
|
||||
"""
|
||||
Set the seed of each random module.
|
||||
`torch.manual_seed` will set seed on all devices.
|
||||
|
||||
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
|
||||
"""
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        raise NotImplementedError

    @classmethod
    def pre_register_and_update(
        cls, parser: FlexibleArgumentParser | None = None
    ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.
        """
        pass
|
||||
|
||||
    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass
|
||||
|
||||
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.
        """
        # An empty `supported_quantization` list means no restriction.
        if cls.supported_quantization and quant not in cls.supported_quantization:
            raise ValueError(
                f"{quant} quantization is currently not supported in {cls.device_name}."
            )
|
||||
|
||||
@classmethod
|
||||
def get_cpu_architecture(cls) -> CpuArchEnum:
|
||||
"""
|
||||
Determine the CPU architecture of the current system.
|
||||
Returns CpuArchEnum indicating the architecture type.
|
||||
"""
|
||||
machine = platform.machine().lower()
|
||||
|
||||
if machine in ("x86_64", "amd64", "i386", "i686"):
|
||||
return CpuArchEnum.X86
|
||||
elif machine.startswith("arm") or machine.startswith("aarch"):
|
||||
return CpuArchEnum.ARM
|
||||
elif machine.startswith("ppc"):
|
||||
return CpuArchEnum.POWERPC
|
||||
elif machine == "s390x":
|
||||
return CpuArchEnum.S390X
|
||||
elif machine.startswith("riscv"):
|
||||
return CpuArchEnum.RISCV
|
||||
|
||||
return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls) -> bool:
|
||||
"""Checks whether pin memory is available on the current platform."""
|
||||
if in_wsl():
|
||||
# Pinning memory in WSL is not supported.
|
||||
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
|
||||
logger.warning(
|
||||
"Using 'pin_memory=False' as WSL is detected. "
|
||||
"This may slow down the performance."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """
        Return the memory usage in bytes.

        Abstract in the base class; device platforms override with their
        runtime's accounting (e.g. CUDA/HIP memory info).
        """
        raise NotImplementedError
|
||||
|
||||
    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.

        Returns the fully-qualified class name as a string so the import is
        deferred until LoRA is actually used.
        """
        raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
|
||||
"""
|
||||
Return the platform specific values for (-inf, inf)
|
||||
"""
|
||||
return float("-inf"), float("inf")
|
||||
|
||||
    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates.

        Defaults to True; platforms with immutable tensors override.
        """
        return True
|
||||
|
||||
    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels.

        256 is the default padding used by the GPU kernels.
        """
        return 256
|
||||
|
||||
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.

        Returned as a fully-qualified name so the import happens lazily.
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa
|
||||
|
||||
    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.

        Defaults to False; hardware platforms with MX support override.
        """
        return False
|
||||
|
||||
    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Returns whether the current platform supports FP8 types.

        Defaults to False; hardware platforms with FP8 support override.
        """
        return False
|
||||
|
||||
    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Returns whether the preferred FP8 type is FNUZ on the current platform.

        There are two representations of FP8, OCP FP8 and FNUZ FP8.
        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
        hardware has converged on the OCP FP8 standard.
        """
        return False
|
||||
|
||||
    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.

        See the documentation for is_fp8_fnuz for details.
        """
        # OCP e4m3fn is the cross-vendor default.
        return torch.float8_e4m3fn
|
||||
|
||||
    @classmethod
    def use_all_gather(cls) -> bool:
        """
        Whether to use allgather in LogitsProcessor to gather the logits.
        """
        return True
|
||||
|
||||
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
        Returns if custom allreduce is supported on the current platform.

        Defaults to False; GPU platforms with a custom kernel override.
        """
        return False
|
||||
|
||||
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """
        Returns True if we register attention as one giant opaque custom op
        on the current platform.
        """
        return False
|
||||
|
||||
    @classmethod
    def validate_request(
        cls,
        prompt: "PromptType",
        params: "SamplingParams | PoolingParams",
        processed_inputs: "ProcessorInputs",
    ) -> None:
        """Raises if this request is unsupported on this platform.

        Base implementation accepts every request (no-op).
        """
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
device = getattr(torch, self.device_type, None)
|
||||
if device is not None and hasattr(device, key):
|
||||
return getattr(device, key)
|
||||
else:
|
||||
logger.warning(
|
||||
"Current platform %s does not have '%s' attribute.",
|
||||
self.device_type,
|
||||
key,
|
||||
)
|
||||
return None
|
||||
|
||||
    def get_global_graph_pool(self) -> Any:
        """
        Return the global graph pool for this platform.

        Lazily created on first call and cached on the class, so all
        instances of a platform share one pool.
        """
        cls = self.__class__
        if cls._global_graph_pool is None:
            cls._global_graph_pool = self.graph_pool_handle()
        return cls._global_graph_pool
|
||||
|
||||
    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """
        Get static graph wrapper class for static graph.

        Returned as a fully-qualified name so the import happens lazily.
        """
        return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"
|
||||
|
||||
    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: "PrefixStore",
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> "ProcessGroup":
        """
        Init platform-specific torch distributed process group.

        Abstract in the base class; platforms with a distributed backend
        override this.
        """
        raise NotImplementedError
|
||||
|
||||
    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """
        Check if the dtype is supported by the current platform.

        Platforms override and raise ValueError for unsupported dtypes.
        """
        raise NotImplementedError
|
||||
|
||||
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """
        Returns if the hybrid kv cache is supported by the current platform.
        """
        return False
|
||||
|
||||
    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """
        Returns if the graph mode is supported by the current platform.
        """
        return False
|
||||
|
||||
    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """
        Returns if the current platform needs to sync weight loader.

        See make_synced_weight_loader for how this flag is consumed.
        """
        return False
|
||||
|
||||
@classmethod
|
||||
def make_synced_weight_loader(cls, original_weight_loader):
|
||||
"""
|
||||
Wrap the original weight loader to make it synced.
|
||||
"""
|
||||
if not cls.use_sync_weight_loader():
|
||||
return original_weight_loader
|
||||
|
||||
def _synced_weight_loader(param, *args, **kwargs):
|
||||
out = original_weight_loader(param, *args, **kwargs)
|
||||
if param.device != torch.device("cpu"):
|
||||
torch._sync(param)
|
||||
return out
|
||||
|
||||
return _synced_weight_loader
|
||||
|
||||
    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        """
        Returns a mapping from device_type to a tuple of supported
        kv_buffer_device for nixl.

        Empty by default: the base platform advertises no nixl support.
        """
        return {}
|
||||
|
||||
    @classmethod
    def get_nixl_memory_type(cls) -> str | None:
        """
        Returns the nixl memory type for the current platform.

        None means the platform declares no nixl memory type.
        """
        return None
|
||||
|
||||
    @classmethod
    def check_max_model_len(cls, max_model_len: int) -> int:
        """
        Check max_model_len for the current platform.

        Base implementation accepts any value unchanged; platforms may clamp
        or reject.
        """
        return max_model_len
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
    # Fallback platform used before/if detection resolves a real platform.
    _enum = PlatformEnum.UNSPECIFIED
    # Empty device type: no torch device module to proxy to.
    device_type = ""
|
||||
564
vllm/platforms/rocm.py
Normal file
564
vllm/platforms/rocm.py
Normal file
@@ -0,0 +1,564 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from functools import cache, lru_cache, wraps
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.selector import AttentionSelectorConfig
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
from amdsmi import (
|
||||
AmdSmiException,
|
||||
amdsmi_get_gpu_asic_info,
|
||||
amdsmi_get_processor_handles,
|
||||
amdsmi_init,
|
||||
amdsmi_shut_down,
|
||||
amdsmi_topo_get_link_type,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from amdsmi with %r", e)
|
||||
|
||||
try:
|
||||
import vllm._C # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C with %r", e)
|
||||
|
||||
# import custom ops, trigger op registration
|
||||
try:
|
||||
import vllm._rocm_C # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._rocm_C with %r", e)
|
||||
|
||||
# Models not supported by ROCm.
|
||||
_ROCM_UNSUPPORTED_MODELS: list[str] = []
|
||||
|
||||
# Models partially supported by ROCm.
|
||||
# Architecture -> Reason.
|
||||
_ROCM_SWA_REASON = ()
|
||||
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}
|
||||
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
|
||||
"0x74a0": "AMD_Instinct_MI300A",
|
||||
"0x74a1": "AMD_Instinct_MI300X",
|
||||
"0x74b5": "AMD_Instinct_MI300X", # MI300X VF
|
||||
"0x74a2": "AMD_Instinct_MI308X",
|
||||
"0x74a5": "AMD_Instinct_MI325X",
|
||||
"0x74b9": "AMD_Instinct_MI325X", # MI325X VF
|
||||
"0x74a9": "AMD_Instinct_MI300X_HF",
|
||||
"0x74bd": "AMD_Instinct_MI300X_HF",
|
||||
"0x744c": "AMD_Radeon_RX7900XTX",
|
||||
}
|
||||
|
||||
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Both set: they must agree, otherwise device numbering is ambiguous.
        assert val == cuda_val
    else:
        # Mirror HIP_VISIBLE_DEVICES into CUDA_VISIBLE_DEVICES so code that
        # only checks the CUDA variable sees the same restriction.
        os.environ["CUDA_VISIBLE_DEVICES"] = val
|
||||
|
||||
# AMDSMI utils
|
||||
# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
|
||||
# all the related functions work on real physical device ids.
|
||||
# the major benefit of using AMDSMI is that it will not initialize CUDA
|
||||
|
||||
|
||||
def with_amdsmi_context(fn):
    """Decorator running *fn* inside an amdsmi session.

    Calls amdsmi_init() before the wrapped function and guarantees
    amdsmi_shut_down() afterwards, even on exception.
    """

    @wraps(fn)
    def _inner(*args, **kwargs):
        amdsmi_init()
        try:
            return fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()

    return _inner
|
||||
|
||||
|
||||
@cache
def on_gfx1x() -> bool:
    # True iff the visible GPU reports an RDNA3/RDNA4 (gfx11/gfx12) arch name.
    # Cached: assumes a homogeneous node.
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
||||
|
||||
|
||||
@cache
def on_mi3xx() -> bool:
    # True iff the visible GPU is an MI3xx-class part (gfx942/gfx950).
    # Cached: assumes a homogeneous node.
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])
|
||||
|
||||
|
||||
@cache
def on_gfx9() -> bool:
    # True iff the visible GPU is a CDNA (gfx90a/gfx942/gfx950) part.
    # Cached: assumes a homogeneous node.
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
|
||||
|
||||
@cache
def on_gfx950() -> bool:
    """True iff the visible GPU reports a gfx950 architecture name.

    Cached: assumes a homogeneous node.
    """
    # Single candidate arch, so a plain substring test suffices.
    return "gfx950" in torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
|
||||
|
||||
@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Decide whether the ROCm custom paged-attention kernel can be used.

    Two separate eligibility tables exist: one for CDNA (gfx9) parts and one
    for RDNA (gfx11/gfx12) parts; any other architecture returns False.
    NOTE: cached on the full argument tuple, so tensor args (alibi_slopes,
    sinks) are only ever expected to be None when this hits the cache.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])

    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    if ON_GFX9:
        return (
            (sliding_window == 0 or sliding_window == (-1, -1))
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and (gqa_ratio >= 1 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
            # AITER paged attention takes precedence when enabled.
            and not (rocm_aiter_ops.is_pa_attn_enabled())
            and sinks is None
        )

    else:
        # RDNA path: stricter head/block constraints, no ALiBi, auto kv dtype.
        return (
            ON_GFX11_GFX12
            and (sliding_window == 0 or sliding_window == (-1, -1))
            and (qtype == torch.half or qtype == torch.bfloat16)
            and head_size == 128
            and block_size == 16
            and (gqa_ratio >= 3 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and alibi_slopes is None
            and kv_cache_dtype == "auto"
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and sinks is None
        )
|
||||
|
||||
|
||||
class RocmPlatform(Platform):
    """Platform implementation for AMD ROCm GPUs."""

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    # ROCm rides on torch's "cuda" device type / dispatch key (HIP aliasing).
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    # RCCL is exposed through the "nccl" backend name.
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq",
        "gptq",
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
    ]
    # bitsandbytes not supported on gfx9 (warp size 64 limitation)
    if not on_gfx9():
        supported_quantization += ["bitsandbytes"]
|
||||
|
||||
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Resolve the attention backend class path for ROCm.

        Resolution order: sparse MLA -> MLA -> explicitly selected backend ->
        env-driven automatic selection -> Triton unified attention default.
        Raises ValueError/RuntimeError for unsupported combinations.
        """
        from vllm._aiter_ops import rocm_aiter_ops

        block_size = attn_selector_config.block_size
        kv_cache_dtype = attn_selector_config.kv_cache_dtype

        # Sparse MLA has exactly one backend on ROCm.
        if attn_selector_config.use_sparse:
            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            # Pick a default MLA backend when none was requested.
            if selected_backend is None:
                selected_backend = (
                    AttentionBackendEnum.ROCM_AITER_MLA
                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                    else AttentionBackendEnum.TRITON_MLA
                )
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                # Triton MLA cannot handle block size 1.
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                raise ValueError(
                    f" The selected backend, {selected_backend.name},"
                    f"does not support block size {block_size}."
                )
            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                logger.info("Using AITER MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
                logger.info("Using AITER TRITON MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()

            # A non-MLA backend was explicitly requested for an MLA model.
            raise ValueError(
                f" The selected backend, {selected_backend.name},"
                f"is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
            # AITER flash attention is CDNA-only.
            if on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
            else:
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    "is only supported on gfx9 architectures."
                )

        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            logger.info("Using Aiter Unified Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
                logger.info("Using Rocm Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags
            # This defaults to AITER FA only if MHA is not explicitly disabled
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        # An explicitly selected backend we do not recognize on ROCm.
        raise RuntimeError(
            f"Attention backend {selected_backend.name} is not supported on "
            "ROCm. Note that V0 attention backends have been removed."
        )
|
||||
|
||||
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Attention backends usable for vision-transformer layers on ROCm."""
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TORCH_SDPA,
        ]
|
||||
|
||||
    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Pick the attention backend for vision-transformer layers.

        An explicitly requested *backend* is validated and returned as-is;
        otherwise falls back to AITER FA -> flash_attn -> SDPA.
        head_size/dtype are currently unused on the automatic path.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_mha_enabled():
            # Note: AITER FA is only supported for Qwen-VL models.
            # TODO: Add support for other VL models in their model class.
            return AttentionBackendEnum.ROCM_AITER_FA

        # flash_attn requires both CDNA hardware and the package installed.
        if on_gfx9() and find_spec("flash_attn") is not None:
            return AttentionBackendEnum.FLASH_ATTN

        return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.

        ROCm reuses the CUDA device API (HIP aliasing).
        """
        torch.cuda.set_device(device)
|
||||
|
||||
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the (major, minor) compute capability of *device_id*,
        cached per device id."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)
|
||||
|
||||
    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop).

        Checks every unordered device pair; any non-XGMI or multi-hop link
        (or an amdsmi query failure) makes the whole set "not fully
        connected".
        """
        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:  # each unordered pair once
                    try:
                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                        # type is 2 for XGMI
                        if link_type["hops"] != 1 or link_type["type"] != 2:
                            return False
                    except AmdSmiException as error:
                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                        return False
        return True
|
||||
|
||||
    @classmethod
    @with_amdsmi_context
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a canonical device name for *device_id*.

        Prefers the PCI device-id -> name map (stable across driver
        versions); falls back to amdsmi's marketing name.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        # NOTE: asic_info["device_id"] is the PCI id string (e.g. "0x74a1").
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]
|
||||
|
||||
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return total HBM of *device_id* in bytes."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
compilation_config = vllm_config.compilation_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
is_eager_execution = compilation_config == CUDAGraphMode.NONE
|
||||
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
|
||||
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
|
||||
if compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
# decode context parallel does not support full cudagraphs
|
||||
if parallel_config.decode_context_parallel_size > 1:
|
||||
logger.warning_once(
|
||||
"Decode context parallel (DCP) is enabled, which is "
|
||||
"incompatible with full CUDA graphs. "
|
||||
"Overriding cudagraph_mode to PIECEWISE."
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
# prefill context parallel do not support full cudagraphs
|
||||
elif parallel_config.prefill_context_parallel_size > 1:
|
||||
logger.warning_once(
|
||||
"Prefill context parallel (PCP) is enabled, which is "
|
||||
"incompatible with full CUDA graphs. "
|
||||
"Overriding cudagraph_mode to PIECEWISE."
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
if cache_config and cache_config.block_size is None:
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
|
||||
# NOTE: This block has been deprecated
|
||||
# or get_env_variable_attn_backend()
|
||||
# == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
|
||||
# TODO: monitor https://github.com/vllm-project/vllm/pull/30396
|
||||
# to see how we can transition to the new way of selecting
|
||||
# attention backends
|
||||
):
|
||||
cache_config.block_size = 64
|
||||
logger.warning(
|
||||
"[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
|
||||
)
|
||||
else:
|
||||
cache_config.block_size = 16
|
||||
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
|
||||
# Aiter rms norm perform best when CUDA Graph capture is enabled.
|
||||
if (
|
||||
use_aiter_rms_norm
|
||||
and not is_eager_execution
|
||||
and "-rms_norm" not in compilation_config.custom_ops
|
||||
):
|
||||
compilation_config.custom_ops.append("+rms_norm")
|
||||
|
||||
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
|
||||
compilation_config.custom_ops.append("+quant_fp8")
|
||||
|
||||
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported and warn about partially supported
        architectures on ROCm (see module-level tables)."""
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially supported by ROCm: %s",
                model_arch,
                msg,
            )
|
||||
|
||||
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base allow-list check, then force the Triton AWQ path,
        which is the only AWQ implementation that works on ROCm."""
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
            # Mutates the process environment so downstream readers agree.
            os.environ["VLLM_USE_TRITON_AWQ"] = "1"
|
||||
|
||||
    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Fully-qualified name of the LoRA punica wrapper (GPU variant)."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
|
||||
|
||||
    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return currently used device memory in bytes (total - free).

        Side effect: resets the peak-memory counters for *device*.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # mem_get_info returns (free, total).
        return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(device)[0]
|
||||
|
||||
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """ROCm shares the CUDA communicator implementation."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
|
||||
|
||||
@classmethod
|
||||
def supports_mx(cls) -> bool:
|
||||
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
||||
return any(gfx in gcn_arch for gfx in ["gfx95"])
|
||||
|
||||
    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 requires MI300-class (gfx94), MI350-class (gfx95) or RDNA4
        (gfx12) hardware."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
|
||||
|
||||
    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """FNUZ FP8 is native only on gfx94 (MI300/MI325) parts."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
|
||||
|
||||
@classmethod
|
||||
def fp8_dtype(cls) -> torch.dtype:
|
||||
if cls.is_fp8_fnuz():
|
||||
return torch.float8_e4m3fnuz
|
||||
else:
|
||||
return torch.float8_e4m3fn
|
||||
|
||||
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # We only enable custom allreduce for MI300 series
        # (gfx94/gfx95 CDNA3+ parts); device 0 is taken as representative.
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ["gfx94", "gfx95"]
        return any(gfx in gcn_arch for gfx in supported_archs)
|
||||
|
||||
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Attention is registered as one opaque custom op on ROCm."""
        return True
|
||||
|
||||
    @classmethod
    def is_navi(cls) -> bool:
        """True iff device 0 is an RDNA ("Navi", gfx1*) consumer part."""
        return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
|
||||
|
||||
    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Fully-qualified name of the CUDA/HIP graph wrapper."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
||||
|
||||
    @classmethod
    def device_count(cls) -> int:
        """Visible GPU count, without initializing the CUDA context."""
        return cuda_device_count_stateless()
|
||||
|
||||
    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ValueError when *dtype* is unusable on this GPU.

        Only bfloat16 is gated (compute capability >= 8.0); all other dtypes
        pass through.
        """
        if dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )
|
||||
|
||||
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache is supported on ROCm."""
        return True
|
||||
|
||||
    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """HIP graph capture (static graph mode) is supported on ROCm."""
        return True
|
||||
295
vllm/platforms/tpu.py
Normal file
295
vllm/platforms/tpu.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
import torch
|
||||
from tpu_info import device
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
|
||||
from vllm.attention.selector import AttentionSelectorConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import BlockSize
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
ParamsType: TypeAlias = SamplingParams | PoolingParams
|
||||
else:
|
||||
BlockSize = None
|
||||
VllmConfig = None
|
||||
PoolingParams = None
|
||||
ParamsType = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
USE_TPU_INFERENCE = False
|
||||
|
||||
|
||||
class TpuPlatform(Platform):
    """Platform implementation for Google TPUs (torch/XLA)."""

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    # TPU tensors dispatch through the XLA key.
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    # Collective bootstrap uses gloo; actual device collectives go via XLA.
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]

    # Extra env vars that must be propagated to worker processes.
    additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]
|
||||
|
||||
    @classmethod
    def import_kernels(cls) -> None:
        """Trigger kernel op registration; best-effort on TPU."""
        # Do not import vllm._C
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401
|
||||
|
||||
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Resolve the attention backend for TPU.

        Pallas is the only supported backend; any other selection is logged
        and overridden. Sparse attention is rejected outright.
        """
        if attn_selector_config.use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on TPU.")
        if selected_backend != AttentionBackendEnum.PALLAS:
            # Informational only — Pallas is still used below.
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        logger.info("Using Pallas V1 backend.")
        return AttentionBackendEnum.PALLAS.get_path()
|
||||
|
||||
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Attention backends usable for vision-transformer layers on TPU."""
        return [
            AttentionBackendEnum.PALLAS,
        ]
|
||||
|
||||
    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Pick the vision-transformer attention backend (always Pallas
        unless an explicitly supported *backend* is requested).

        head_size/dtype are currently unused.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention"
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
            )
            logger.info_once(f"Using backend {backend} for vit attention.")
            return backend

        logger.info_once(
            f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
        )
        return AttentionBackendEnum.PALLAS
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
"""
|
||||
Set the device for the current platform.
|
||||
"""
|
||||
torch.tpu.set_device(device)
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
chip_type, _ = device.get_local_chips()
|
||||
return f"TPU {chip_type.name}"
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
|
||||
|
||||
@classmethod
|
||||
def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
|
||||
return torch.finfo(dtype).min, torch.finfo(dtype).max
|
||||
|
||||
@classmethod
|
||||
def can_update_inplace(cls):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_lora_vocab_padding_size(cls) -> int:
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
return torch.no_grad()
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
from vllm.config import CompilationMode, CUDAGraphMode
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
# For v0, the default block size is 16.
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = cast(BlockSize, 16)
|
||||
compilation_config = vllm_config.compilation_config
|
||||
|
||||
# TPU only supports DYNAMO_TRACE_ONCE compilation mode
|
||||
if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
|
||||
logger.info(
|
||||
"[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\
|
||||
disabling cudagraph."
|
||||
)
|
||||
compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
|
||||
|
||||
if (
|
||||
compilation_config.cudagraph_mode is None
|
||||
or compilation_config.cudagraph_mode.max_cudagraph_mode()
|
||||
!= CUDAGraphMode.NONE
|
||||
):
|
||||
logger.info(
|
||||
"[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
if compilation_config.backend == "":
|
||||
compilation_config.backend = "openxla"
|
||||
|
||||
assert vllm_config.speculative_config is None, (
|
||||
"TPU does not support speculative decoding"
|
||||
)
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
if model_config is not None and model_config.dtype in (
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
):
|
||||
logger.warning(
|
||||
"The TPU backend currently does not support %s. "
|
||||
"Using bfloat16 instead.",
|
||||
model_config.dtype,
|
||||
)
|
||||
model_config.dtype = torch.bfloat16
|
||||
|
||||
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
|
||||
|
||||
cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment]
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
|
||||
|
||||
assert not vllm_config.speculative_config, (
|
||||
"Speculative decoding is not yet supported for TPU backend"
|
||||
)
|
||||
|
||||
if (
|
||||
scheduler_config.is_multimodal_model
|
||||
and not scheduler_config.disable_chunked_mm_input
|
||||
):
|
||||
logger.warning(
|
||||
"TPU does not support running Multimodal models"
|
||||
" without setting `--disable_chunked_mm_input`. "
|
||||
"Forcing --disable_chunked_mm_input."
|
||||
)
|
||||
scheduler_config.disable_chunked_mm_input = True
|
||||
|
||||
if model_config and model_config.use_mla:
|
||||
logger.info(
|
||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
||||
"prefill and prefix caching to be disabled."
|
||||
)
|
||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = max(
|
||||
vllm_config.model_config.max_model_len,
|
||||
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
logger.warning("Pin memory is not supported on TPU.")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_device_communicator_cls(cls) -> str:
|
||||
return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa
|
||||
|
||||
@classmethod
|
||||
def validate_request(
|
||||
cls,
|
||||
prompt: PromptType,
|
||||
params: ParamsType,
|
||||
processed_inputs: ProcessorInputs,
|
||||
) -> None:
|
||||
"""Raises if this request is unsupported on this platform"""
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
|
||||
if (
|
||||
isinstance(params, SamplingParams)
|
||||
and params.sampling_type == SamplingType.RANDOM_SEED
|
||||
):
|
||||
raise ValueError("Torch XLA does not support per-request seed.")
|
||||
|
||||
@classmethod
|
||||
@torch.compile(backend="openxla")
|
||||
def insert_blocks_to_device(
|
||||
cls,
|
||||
src_cache: torch.Tensor,
|
||||
dst_cache: torch.Tensor,
|
||||
src_block_indices: torch.Tensor,
|
||||
dst_block_indices: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
|
||||
dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)
|
||||
|
||||
@classmethod
|
||||
@torch.compile(backend="openxla")
|
||||
def swap_out_blocks_to_host(
|
||||
cls,
|
||||
src_cache: torch.Tensor,
|
||||
dst_cache: torch.Tensor,
|
||||
src_block_indices: torch.Tensor,
|
||||
dst_block_indices: torch.Tensor,
|
||||
) -> None:
|
||||
"""tpu blocks to cpu blocks"""
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
|
||||
dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
|
||||
|
||||
@classmethod
|
||||
def use_sync_weight_loader(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def check_max_model_len(cls, max_model_len: int) -> int:
|
||||
"""
|
||||
Check max_model_len for the current platform.
|
||||
"""
|
||||
logger.warning(
|
||||
"--max-model-len is not specified, "
|
||||
"it's currently using model's default length %d, "
|
||||
"which might be too large."
|
||||
"Please input with --max-model-len based on your "
|
||||
"request input length and output length, to avoid "
|
||||
"unnecessary degradation.",
|
||||
max_model_len,
|
||||
)
|
||||
return max_model_len
|
||||
|
||||
|
||||
# Prefer the external `tpu_inference` implementation of TpuPlatform when that
# package is installed; otherwise keep vLLM's in-tree TpuPlatform defined
# above. USE_TPU_INFERENCE records which one is active.
try:
    from tpu_inference.platforms import (
        TpuPlatform as TpuInferencePlatform,
    )

    TpuPlatform = TpuInferencePlatform  # type: ignore
    USE_TPU_INFERENCE = True
except ImportError:
    # Removed a dead `pass` statement that followed this log call.
    logger.info("tpu_inference not found, using vLLM's TpuPlatform")
|
||||
277
vllm/platforms/xpu.py
Normal file
277
vllm/platforms/xpu.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.selector import AttentionSelectorConfig
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUPlatform(Platform):
    """Platform implementation for Intel XPU (Intel GPU) devices."""

    _enum = PlatformEnum.XPU
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
    dist_backend: str = "ccl"  # ccl | xccl
    device_control_env_var: str = "ZE_AFFINITY_MASK"

    @classmethod
    def import_kernels(cls) -> None:
        """Import the kernel extensions usable on XPU.

        vllm._C is CUDA-only, so only the MoE extension is attempted; a
        missing extension is tolerated.
        """
        # Do not import vllm._C
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the import path of the attention backend for XPU.

        Always forces the NHD KV-cache layout first; supports TRITON_ATTN
        and FLASH_ATTN, defaulting to FLASH_ATTN when no backend (or a
        falsy one) is selected.
        """
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout("NHD")
        logger.info(
            "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
            "only NHD layout is supported by XPU attention kernels."
        )

        if attn_selector_config.use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on XPU.")
        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info_once("Using Triton backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()
        elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend.")
            return AttentionBackendEnum.FLASH_ATTN.get_path()
        elif selected_backend:
            # Any other truthy selection is unsupported on XPU.
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_mla: {attn_selector_config.use_mla}"
            )

        # No explicit selection: default to Flash Attention.
        logger.info("Using Flash Attention backend.")
        return AttentionBackendEnum.FLASH_ATTN.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        # XPU only supports FLASH_ATTN for vision attention.
        return [
            AttentionBackendEnum.FLASH_ATTN,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Return the ViT attention backend, validating an explicit choice.

        `head_size` and `dtype` are accepted for interface compatibility but
        do not influence the choice on XPU.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: "
                f"{cls.get_supported_vit_attn_backends()}."
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        logger.info_once(
            f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
        )
        return AttentionBackendEnum.FLASH_ATTN

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.xpu.set_device(device)

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        # capacity format differs from cuda's and will cause unexpected
        # failure, so use None directly
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the marketing name of the given XPU device."""
        return torch.xpu.get_device_name(device_id)

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the Punica (LoRA) wrapper for XPU.

        XPU_USE_TRITON_KERNEL=1 opts into the Triton-based GPU wrapper.
        """
        xpu_use_triton_kernel = os.getenv("XPU_USE_TRITON_KERNEL", "0") == "1"
        if not xpu_use_triton_kernel:
            return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
        else:
            return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of the given XPU device in bytes."""
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def inference_mode(cls):
        """Use no_grad as the inference context on XPU."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate and adjust the engine config for XPU execution.

        Sets the default block size, disables CUDA graphs, selects the XPU
        worker, and normalizes the distributed executor backend.
        """
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        # in V1(or with ipex chunked prefill) block_size is 64
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 64

        # lazy import to avoid circular import
        from vllm.config import CompilationMode, CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if compilation_config.compile_sizes is None:
            compilation_config.compile_sizes = []

        assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, (
            "CUDA graph mode should be NONE on XPU"
        )

        # NOTE(review): compilation is disabled entirely when LoRA is
        # configured — presumably a compatibility limitation; confirm.
        if vllm_config.lora_config is not None:
            compilation_config.mode = CompilationMode.NONE

        # check and update parallel config
        parallel_config = vllm_config.parallel_config
        parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
        if vllm_config.kv_transfer_config is not None:
            vllm_config.kv_transfer_config.enable_permute_local_kv = True

        if parallel_config.distributed_executor_backend is None:
            # Default executor: ray for multi-device, uniprocess otherwise.
            if parallel_config.world_size > 1:
                parallel_config.distributed_executor_backend = "ray"
            else:
                parallel_config.distributed_executor_backend = "uni"
        elif parallel_config.distributed_executor_backend == "mp":
            # FIXME(kunshang):
            # spawn needs calling `if __name__ == '__main__':`
            # fork is not supported for xpu start new process.
            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
                logger.warning(
                    "Please use spawn as start method if you want to use mp."
                )
        elif (
            parallel_config.distributed_executor_backend != "ray"
            and parallel_config.distributed_executor_backend != "uni"
            and parallel_config.distributed_executor_backend != "external_launcher"
        ):
            # Anything else is unsupported on XPU: fall back to ray.
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.",
                parallel_config.distributed_executor_backend,
            )
            parallel_config.distributed_executor_backend = "ray"

        if model_config and model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.model_config.max_model_len,
                vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache is supported on XPU."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Static (captured) graph mode is not supported on XPU."""
        return False

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is available on XPU."""
        return True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return peak allocated bytes, resetting peak stats first."""
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """XPU uses the e5m2 fp8 variant."""
        return torch.float8_e5m2

    @classmethod
    def is_data_center_gpu(cls) -> bool:
        """True when the device name identifies an Intel Data Center GPU."""
        device_name = cls.get_device_name().lower()
        return device_name.count("data center gpu") > 0

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the XPU device communicator."""
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa

    @classmethod
    def device_count(cls) -> int:
        """Return the number of visible XPU devices."""
        return torch.xpu.device_count()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise if `dtype` is known to be broken on the current device."""
        if dtype == torch.bfloat16:  # noqa: SIM102
            device_name = cls.get_device_name().lower()
            # client gpu a770
            if device_name.count("a770") > 0:
                raise ValueError(
                    "Intel Arc A770 have bfloat16 accuracy known issue. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Attention is registered as an opaque op on XPU."""
        return True

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU."""
        # Blocks are indexed along dim 1 here (dim 0 on the TPU platform) —
        # presumably the XPU KV-cache layout; confirm against the cache code.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU)."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()
|
||||
Reference in New Issue
Block a user