Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -1,94 +1,145 @@
import enum
from functools import lru_cache
from typing import Type
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cache
from typing import NamedTuple, cast, get_args
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
MambaAttentionBackendEnum,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip, is_musa
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
class AttentionSelectorConfig(NamedTuple):
head_size: int
dtype: torch.dtype
kv_cache_dtype: CacheDType | None
block_size: int | None
use_mla: bool = False
has_sink: bool = False
use_sparse: bool = False
use_mm_prefix: bool = False
attn_type: str = AttentionType.DECODER
def __repr__(self):
return (
f"AttentionSelectorConfig(head_size={self.head_size}, "
f"dtype={self.dtype}, "
f"kv_cache_dtype={self.kv_cache_dtype}, "
f"block_size={self.block_size}, "
f"use_mla={self.use_mla}, "
f"has_sink={self.has_sink}, "
f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, "
f"attn_type={self.attn_type})"
)
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
backend = _which_attn_to_use(dtype)
if backend == _Backend.FLASH_ATTN:
logger.info("Using FlashAttention-2 backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
elif backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401
XFormersBackend)
return XFormersBackend
elif backend == _Backend.ROCM_FLASH:
logger.info("Using ROCmFlashAttention backend.")
from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is enforced for the Flashinfer backend. ")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
else:
raise ValueError("Invalid attention backend.")
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
if kv_cache_dtype is not None:
valid_cache_dtypes = get_args(CacheDType)
assert kv_cache_dtype in valid_cache_dtypes, (
f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
f"Valid values are: {valid_cache_dtypes}"
)
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
block_size=block_size,
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
attn_type=attn_type or AttentionType.DECODER,
)
return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config=attn_selector_config,
)
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
"""Returns which flash attention backend to use."""
if is_cpu():
return _Backend.TORCH_SDPA
if is_musa():
return _Backend.FLASH_ATTN
@cache
def _cached_get_attn_backend(
backend,
attn_selector_config: AttentionSelectorConfig,
) -> type[AttentionBackend]:
from vllm.platforms import current_platform
if is_hip():
# AMD GPUs.
if torch.cuda.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_atten is not supported on NAVI GPUs.")
return _Backend.ROCM_FLASH
attention_cls = current_platform.get_attn_backend_cls(
backend,
attn_selector_config=attn_selector_config,
)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"
)
backend = resolve_obj_by_qualname(attention_cls)
# NVIDIA GPUs.
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
return _Backend.XFORMERS
# Adjust kv cache layout if the selected backend requires a specific one
required_layout = backend.get_required_kv_cache_layout()
if required_layout is not None:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
if dtype not in (torch.float16, torch.bfloat16):
logger.info("Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
return _Backend.XFORMERS
try:
import flash_attn # noqa: F401
except ImportError:
set_kv_cache_layout(required_layout)
logger.info(
"Cannot use FlashAttention-2 backend because the flash_attn "
"package is not found. Please install it for better performance.")
return _Backend.XFORMERS
"Using %s KV cache layout for %s backend.",
required_layout,
backend.get_name(),
)
backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
return _Backend[backend_by_env_var]
return backend
# Default case.
return _Backend.FLASH_ATTN
def get_mamba_attn_backend(
mamba_type: str,
) -> type[AttentionBackend]:
"""Select which mamba attention backend to use and lazily import it."""
return _cached_get_mamba_attn_backend(mamba_type)
@cache
def _cached_get_mamba_attn_backend(
mamba_type: str,
) -> type[AttentionBackend]:
assert mamba_type and isinstance(mamba_type, str)
selected_backend = None
try:
backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
selected_backend = MambaAttentionBackendEnum[backend_name]
except KeyError as e:
raise ValueError(
f"Invalid mamba attention backend type: '{backend_name}'. Valid "
f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
) from e
mamba_attn_backend = selected_backend.get_class()
return mamba_attn_backend