bi_150-vllm/vllm/v1/attention/backends/fa_utils.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Track whether upstream flash-attn is available on ROCm.
# Set during module initialization and never modified afterwards.
# This module-level flag avoids repeated import attempts and ensures
# consistent behavior (similar to IS_AITER_FOUND in _aiter_ops.py).
_ROCM_FLASH_ATTN_AVAILABLE = False

if current_platform.is_cuda():
    from vllm import _custom_ops as ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash

    from ixformer.contrib.vllm_flash_attn import (
        flash_attn_varlen_func,
        flash_attn_varlen_int8_func,
        flash_attn_with_kvcache,
    )
elif current_platform.is_xpu():
    from vllm import _custom_ops as ops
    from vllm._xpu_ops import xpu_ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash
    flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func  # type: ignore[assignment]
    get_scheduler_metadata = xpu_ops.get_scheduler_metadata  # type: ignore[assignment]
elif current_platform.is_rocm():
    try:
        from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]

        # Mark that upstream flash-attn is available on ROCm
        _ROCM_FLASH_ATTN_AVAILABLE = True
    except ImportError:

        def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any:  # type: ignore[no-redef,misc]
            raise ImportError(
                "ROCm platform requires upstream flash-attn "
                "to be installed. Please install flash-attn first."
            )

    # ROCm doesn't use scheduler metadata (FA3 feature), provide stub
    def get_scheduler_metadata(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
        return None

    # ROCm uses the C++ custom op for reshape_and_cache
    from vllm import _custom_ops as ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash
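
# Illustrative sketch: whichever platform branch ran above, callers see a single
# module-level flash_attn_varlen_func. The call below follows the upstream
# flash-attn varlen convention; the tensors (q, k, v, cu_seqlens_*, max_seqlen_*)
# are placeholders and the exact keyword set may differ per backend.
#
#     out = flash_attn_varlen_func(
#         q, k, v,
#         cu_seqlens_q=cu_seqlens_q,
#         cu_seqlens_k=cu_seqlens_k,
#         max_seqlen_q=max_seqlen_q,
#         max_seqlen_k=max_seqlen_k,
#         causal=True,
#     )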


def get_flash_attn_version(
    requires_alibi: bool = False, head_size: int | None = None
) -> int | None:
    if current_platform.is_xpu():
        return 2
    if current_platform.is_rocm():
        # ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
        return None
    return None
# try:
# from vllm.vllm_flash_attn.flash_attn_interface import (
# fa_version_unsupported_reason,
# is_fa_version_supported,
# )
# device_capability = current_platform.get_device_capability()
# assert device_capability is not None
# # 1. default version depending on platform
# if device_capability.major == 9 and is_fa_version_supported(3):
# # Hopper (SM90): prefer FA3
# fa_version = 3
# elif device_capability.major == 10 and is_fa_version_supported(4):
# # Blackwell (SM100+, restrict to SM100 for now): prefer FA4
# fa_version = 4
# else:
# # Fallback to FA2
# fa_version = 2
# # 2. override if passed by environment or config
# from vllm.config import get_current_vllm_config_or_none
# vllm_config = get_current_vllm_config_or_none()
# if (
# vllm_config is not None
# and vllm_config.attention_config.flash_attn_version is not None
# ):
# fa_version = vllm_config.attention_config.flash_attn_version
# # 3. fallback for unsupported combinations
# if device_capability.major >= 10 and fa_version == 3:
# logger.warning_once(
# "Cannot use FA version 3 on Blackwell platform, "
# "defaulting to FA version 4 if supported, otherwise FA2."
# )
# fa_version = 4 if is_fa_version_supported(4) else 2
# if requires_alibi and fa_version == 3:
# logger.warning_once(
# "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
# )
# fa_version = 2
# if requires_alibi and fa_version == 4:
# logger.warning_once(
# "Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
# )
# fa_version = 2
# # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# # supported head dimensions.
# # See: https://github.com/Dao-AILab/flash-attention/issues/1959
# if (
# fa_version == 4
# and device_capability.major >= 10
# and head_size is not None
# and head_size > 128
# ):
# logger.warning_once(
# "FA4 on Blackwell does not support head_size=%d due to TMEM "
# "capacity limits, defaulting to FA version 2.",
# head_size,
# )
# fa_version = 2
# if not is_fa_version_supported(fa_version):
# logger.error(
# "Cannot use FA version %d is not supported due to %s",
# fa_version,
# fa_version_unsupported_reason(fa_version),
# )
# assert is_fa_version_supported(fa_version)
# return fa_version
# except (ImportError, AssertionError):
# return None
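
# Illustrative sketch: because the ROCm and fallback paths above return None,
# callers typically forward fa_version only when a concrete version is reported.
# The `attn_kwargs` dict below is a hypothetical set of keyword arguments being
# assembled for an attention call.
#
#     fa_version = get_flash_attn_version()
#     attn_kwargs = {}
#     if fa_version is not None:
#         attn_kwargs["fa_version"] = fa_version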


def flash_attn_supports_fp8() -> bool:
    return (
        get_flash_attn_version() == 3
        and current_platform.is_device_capability_family(90)
    )
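
# Illustrative sketch: a hypothetical backend could consult this capability
# helper when picking the query/key dtype for the attention kernel; the dtype
# choice below is an assumption, not something this module prescribes.
#
#     import torch
#     q_dtype = torch.float8_e4m3fn if flash_attn_supports_fp8() else torch.bfloat16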


def flash_attn_supports_sinks() -> bool:
    return True


def flash_attn_supports_mla() -> bool:
    from vllm.platforms import current_platform

    if current_platform.is_cuda():
        try:
            from vllm.vllm_flash_attn.flash_attn_interface import (
                is_fa_version_supported,
            )

            return is_fa_version_supported(
                3
            ) and current_platform.is_device_capability_family(90)
        # NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
        # head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
        except (ImportError, AssertionError):
            pass
    return False
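
# Illustrative sketch: a hypothetical backend selector might gate an MLA-capable
# FlashAttention path on this helper and otherwise fall back to another MLA
# implementation; the variable name below is only an example.
#
#     use_flash_attn_for_mla = flash_attn_supports_mla()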


def is_flash_attn_varlen_func_available() -> bool:
    """Check if flash_attn_varlen_func is available.

    This function determines whether the flash_attn_varlen_func imported at module
    level is a working implementation or a stub.

    Platform-specific sources:
    - CUDA: ixformer.contrib.vllm_flash_attn.flash_attn_varlen_func
    - XPU: xpu_ops.flash_attn_varlen_func
    - ROCm: upstream flash_attn.flash_attn_varlen_func (if available)

    Note: This is separate from the AITER flash attention backend (rocm_aiter_fa.py),
    which uses rocm_aiter_ops.flash_attn_varlen_func. The condition to use AITER is
    handled separately via _aiter_ops.is_aiter_found_and_supported().

    Returns:
        bool: True if a working flash_attn_varlen_func implementation is available.
    """
    if current_platform.is_cuda() or current_platform.is_xpu():
        # CUDA and XPU always have flash_attn_varlen_func available
        return True
    if current_platform.is_rocm():
        # Use the flag set during module import to check if
        # upstream flash-attn was successfully imported
        return _ROCM_FLASH_ATTN_AVAILABLE
    return False
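
# Illustrative sketch: a caller can use this check to fail fast (or fall back to
# another attention backend) before any varlen attention call is attempted. The
# error message below is an example, not one raised by this module.
#
#     if not is_flash_attn_varlen_func_available():
#         raise RuntimeError(
#             "flash_attn_varlen_func is unavailable on this platform; install "
#             "upstream flash-attn on ROCm or select a different attention backend."
#         )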