Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -15,13 +15,11 @@ logger = init_logger(__name__)
_ROCM_FLASH_ATTN_AVAILABLE = False
if current_platform.is_cuda():
from vllm._custom_ops import reshape_and_cache_flash
# from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
# flash_attn_varlen_func,
# get_scheduler_metadata,
# )
from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache, flash_attn_varlen_int8_func
elif current_platform.is_xpu():
from vllm import _custom_ops as ops
from vllm._xpu_ops import xpu_ops
@@ -53,67 +51,93 @@ elif current_platform.is_rocm():
reshape_and_cache_flash = ops.reshape_and_cache_flash
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
return 3
def get_flash_attn_version(
requires_alibi: bool = False, head_size: int | None = None
) -> int | None:
if current_platform.is_xpu():
return 2
if current_platform.is_rocm():
# ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
return None
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason,
is_fa_version_supported,
)
return None
# try:
# from vllm.vllm_flash_attn.flash_attn_interface import (
# fa_version_unsupported_reason,
# is_fa_version_supported,
# )
device_capability = current_platform.get_device_capability()
# device_capability = current_platform.get_device_capability()
assert device_capability is not None
# assert device_capability is not None
# 1. default version depending on platform
fa_version = (
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)
# # 1. default version depending on platform
# if device_capability.major == 9 and is_fa_version_supported(3):
# # Hopper (SM90): prefer FA3
# fa_version = 3
# elif device_capability.major == 10 and is_fa_version_supported(4):
# # Blackwell (SM100+, restrict to SM100 for now): prefer FA4
# fa_version = 4
# else:
# # Fallback to FA2
# fa_version = 2
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config_or_none
# # 2. override if passed by environment or config
# from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.flash_attn_version is not None
):
fa_version = vllm_config.attention_config.flash_attn_version
# vllm_config = get_current_vllm_config_or_none()
# if (
# vllm_config is not None
# and vllm_config.attention_config.flash_attn_version is not None
# ):
# fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform, "
"defaulting to FA version 2."
)
fa_version = 2
# # 3. fallback for unsupported combinations
# if device_capability.major >= 10 and fa_version == 3:
# logger.warning_once(
# "Cannot use FA version 3 on Blackwell platform, "
# "defaulting to FA version 4 if supported, otherwise FA2."
# )
# fa_version = 4 if is_fa_version_supported(4) else 2
if requires_alibi and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
)
fa_version = 2
# if requires_alibi and fa_version == 3:
# logger.warning_once(
# "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
# )
# fa_version = 2
if not is_fa_version_supported(fa_version):
logger.error(
"Cannot use FA version %d is not supported due to %s",
fa_version,
fa_version_unsupported_reason(fa_version),
)
# if requires_alibi and fa_version == 4:
# logger.warning_once(
# "Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
# )
# fa_version = 2
assert is_fa_version_supported(fa_version)
return fa_version
except (ImportError, AssertionError):
return None
# # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# # supported head dimensions.
# # See: https://github.com/Dao-AILab/flash-attention/issues/1959
# if (
# fa_version == 4
# and device_capability.major >= 10
# and head_size is not None
# and head_size > 128
# ):
# logger.warning_once(
# "FA4 on Blackwell does not support head_size=%d due to TMEM "
# "capacity limits, defaulting to FA version 2.",
# head_size,
# )
# fa_version = 2
# if not is_fa_version_supported(fa_version):
# logger.error(
# "Cannot use FA version %d is not supported due to %s",
# fa_version,
# fa_version_unsupported_reason(fa_version),
# )
# assert is_fa_version_supported(fa_version)
# return fa_version
# except (ImportError, AssertionError):
# return None
def flash_attn_supports_fp8() -> bool:
@@ -124,10 +148,7 @@ def flash_attn_supports_fp8() -> bool:
def flash_attn_supports_sinks() -> bool:
if current_platform.is_xpu():
return True
else:
return get_flash_attn_version() == 3
return True
def flash_attn_supports_mla():
@@ -142,6 +163,10 @@ def flash_attn_supports_mla():
return is_fa_version_supported(
3
) and current_platform.is_device_capability_family(90)
# NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
# head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
except (ImportError, AssertionError):
pass
return False