Upgrade the CoreX v4.1 overlay to vLLM 0.17.0
@@ -15,13 +15,11 @@ logger = init_logger(__name__)
 _ROCM_FLASH_ATTN_AVAILABLE = False
 
 if current_platform.is_cuda():
-    from vllm._custom_ops import reshape_and_cache_flash
-    # from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
-    #     flash_attn_varlen_func,
-    #     get_scheduler_metadata,
-    # )
-    from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+    from vllm import _custom_ops as ops
+
+    reshape_and_cache_flash = ops.reshape_and_cache_flash
+    from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache, flash_attn_varlen_int8_func
 
 elif current_platform.is_xpu():
     from vllm import _custom_ops as ops
     from vllm._xpu_ops import xpu_ops
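Note on the import swap above: the overlay routes the cache op through the
_custom_ops module object and pulls all attention entry points, including the
new int8 varlen kernel, from ixformer. A minimal sketch of that pattern,
assuming only the module paths shown in this hunk (the availability flag and
fallback branch are illustrative, not part of the commit):

# Sketch: prefer the Iluvatar CoreX kernels shipped with ixformer.
try:
    from ixformer.contrib.vllm_flash_attn import (
        flash_attn_varlen_func,
        flash_attn_varlen_int8_func,
        flash_attn_with_kvcache,
    )
    _IXFORMER_FA_AVAILABLE = True  # illustrative flag, not in the diff
except ImportError:
    _IXFORMER_FA_AVAILABLE = False

from vllm import _custom_ops as ops

# Rebinding at module scope keeps existing call sites of
# reshape_and_cache_flash working without a direct import.
reshape_and_cache_flash = ops.reshape_and_cache_flash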
@@ -53,67 +51,93 @@ elif current_platform.is_rocm():
     reshape_and_cache_flash = ops.reshape_and_cache_flash
 
 
-def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
-    # import here to avoid circular dependencies
-    from vllm.platforms import current_platform
-
-    return 3
+def get_flash_attn_version(
+    requires_alibi: bool = False, head_size: int | None = None
+) -> int | None:
+    if current_platform.is_xpu():
+        return 2
+    if current_platform.is_rocm():
+        # ROCm doesn't use vllm_flash_attn; return None to skip the fa_version arg
+        return None
-    try:
-        from vllm.vllm_flash_attn.flash_attn_interface import (
-            fa_version_unsupported_reason,
-            is_fa_version_supported,
-        )
+    return None
+    # try:
+    #     from vllm.vllm_flash_attn.flash_attn_interface import (
+    #         fa_version_unsupported_reason,
+    #         is_fa_version_supported,
+    #     )
-        device_capability = current_platform.get_device_capability()
+    #     device_capability = current_platform.get_device_capability()
-        assert device_capability is not None
+    #     assert device_capability is not None
 
-        # 1. default version depending on platform
-        fa_version = (
-            3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
-        )
+    #     # 1. default version depending on platform
+    #     if device_capability.major == 9 and is_fa_version_supported(3):
+    #         # Hopper (SM90): prefer FA3
+    #         fa_version = 3
+    #     elif device_capability.major == 10 and is_fa_version_supported(4):
+    #         # Blackwell (SM100+, restrict to SM100 for now): prefer FA4
+    #         fa_version = 4
+    #     else:
+    #         # Fallback to FA2
+    #         fa_version = 2
 
-        # 2. override if passed by environment or config
-        from vllm.config import get_current_vllm_config_or_none
+    #     # 2. override if passed by environment or config
+    #     from vllm.config import get_current_vllm_config_or_none
 
-        vllm_config = get_current_vllm_config_or_none()
-        if (
-            vllm_config is not None
-            and vllm_config.attention_config.flash_attn_version is not None
-        ):
-            fa_version = vllm_config.attention_config.flash_attn_version
+    #     vllm_config = get_current_vllm_config_or_none()
+    #     if (
+    #         vllm_config is not None
+    #         and vllm_config.attention_config.flash_attn_version is not None
+    #     ):
+    #         fa_version = vllm_config.attention_config.flash_attn_version
 
-        # 3. fallback for unsupported combinations
-        if device_capability.major == 10 and fa_version == 3:
-            logger.warning_once(
-                "Cannot use FA version 3 on Blackwell platform, "
-                "defaulting to FA version 2."
-            )
-            fa_version = 2
+    #     # 3. fallback for unsupported combinations
+    #     if device_capability.major >= 10 and fa_version == 3:
+    #         logger.warning_once(
+    #             "Cannot use FA version 3 on Blackwell platform, "
+    #             "defaulting to FA version 4 if supported, otherwise FA2."
+    #         )
+    #         fa_version = 4 if is_fa_version_supported(4) else 2
 
-        if requires_alibi and fa_version == 3:
-            logger.warning_once(
-                "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
-            )
-            fa_version = 2
+    #     if requires_alibi and fa_version == 3:
+    #         logger.warning_once(
+    #             "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
+    #         )
+    #         fa_version = 2
 
-        if not is_fa_version_supported(fa_version):
-            logger.error(
-                "Cannot use FA version %d: not supported due to %s",
-                fa_version,
-                fa_version_unsupported_reason(fa_version),
-            )
+    #     if requires_alibi and fa_version == 4:
+    #         logger.warning_once(
+    #             "Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
+    #         )
+    #         fa_version = 2
 
-        assert is_fa_version_supported(fa_version)
-        return fa_version
-    except (ImportError, AssertionError):
-        return None
+    #     # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
+    #     # supported head dimensions.
+    #     # See: https://github.com/Dao-AILab/flash-attention/issues/1959
+    #     if (
+    #         fa_version == 4
+    #         and device_capability.major >= 10
+    #         and head_size is not None
+    #         and head_size > 128
+    #     ):
+    #         logger.warning_once(
+    #             "FA4 on Blackwell does not support head_size=%d due to TMEM "
+    #             "capacity limits, defaulting to FA version 2.",
+    #             head_size,
+    #         )
+    #         fa_version = 2
+
+    #     if not is_fa_version_supported(fa_version):
+    #         logger.error(
+    #             "Cannot use FA version %d: not supported due to %s",
+    #             fa_version,
+    #             fa_version_unsupported_reason(fa_version),
+    #         )
+
+    #     assert is_fa_version_supported(fa_version)
+    #     return fa_version
+    # except (ImportError, AssertionError):
+    #     return None
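The early returns above are the crux of this hunk: XPU reports FA2, while ROCm
and the CoreX/ixformer path report None, and upstream's version-probing logic
is kept only as commented-out reference. A hypothetical caller sketch of why
None matters downstream (the kwargs plumbing is illustrative; only
get_flash_attn_version itself comes from this diff):

fa_version = get_flash_attn_version(requires_alibi=False)

attn_kwargs = {}
if fa_version is not None:
    # Upstream vllm_flash_attn kernels accept an explicit version selector.
    attn_kwargs["fa_version"] = fa_version
# On CoreX the probe returns None, so no version argument is forwarded and
# the ixformer kernels dispatch internally.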
@@ -124,10 +148,7 @@ def flash_attn_supports_fp8() -> bool:
 
 
 def flash_attn_supports_sinks() -> bool:
-    if current_platform.is_xpu():
-        return True
-    else:
-        return get_flash_attn_version() == 3
+    return True
 
 
 def flash_attn_supports_mla():
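Hardcoding True above follows from the previous hunk: get_flash_attn_version()
now returns None on the CoreX path, so the old "== 3" check would report no
sink support even though the ixformer kernels handle attention sinks. A short
illustration of the failure mode the change avoids (assumed values, for
exposition only):

fa_version = None                       # what the probe now returns on CoreX
old_supports_sinks = fa_version == 3    # False: sinks wrongly disabled
new_supports_sinks = True               # overlay asserts kernel-level support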
@@ -142,6 +163,10 @@ def flash_attn_supports_mla():
         return is_fa_version_supported(
             3
         ) and current_platform.is_device_capability_family(90)
+
+        # NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
+        # head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
+
     except (ImportError, AssertionError):
         pass
     return False
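The retained MLA gate ties support to FA3 on the SM90 family; the new NOTE
records why FA4 is excluded. A condensed, self-contained sketch of the same
check, assuming the helpers used in this file are importable (the wrapper
function name is illustrative):

def _mla_supported_sketch() -> bool:
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import is_fa_version_supported

        # MLA needs FA3 on Hopper (SM90). FA4's CuteDSL path cannot handle
        # MLA's non-standard head dims (576 qk / 512 v) due to TMEM limits.
        return is_fa_version_supported(3) and current_platform.is_device_capability_family(90)
    except (ImportError, AssertionError):
        return False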