56 lines
2.0 KiB
Python
56 lines
2.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from typing import Optional
|
|
|
|
from vllm import envs
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
    """Select the FlashAttention (FA) version to use on the current platform.

    Resolution order:
      1. Default to FA3 when the device's compute-capability major is 9 and
         ``is_fa_version_supported(3)`` reports FA3 as available; else FA2.
      2. Override with ``envs.VLLM_FLASH_ATTN_VERSION`` when it is set.
      3. Fall back to FA2 when FA3 was selected on a capability-major-10
         (Blackwell) device, or when ALiBi is required.

    Args:
        requires_alibi: Whether the caller needs ALiBi support, which forces
            FA2 because FA3 cannot be used with ALiBi.

    Returns:
        The selected version (2 or 3), or ``None`` when FlashAttention cannot
        be used: the ``vllm_flash_attn`` interface fails to import, the device
        capability is unknown, the environment override is not 2 or 3, or the
        selected version is reported as unsupported.
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason, is_fa_version_supported)

        device_capability = current_platform.get_device_capability()
        # Explicit checks rather than `assert`: asserts are stripped under
        # `python -O`, which would let a None capability or an invalid /
        # unsupported version slip through instead of returning None.
        if device_capability is None:
            return None

        # 1. default version depending on platform
        fa_version = 3 if (device_capability.major == 9
                           and is_fa_version_supported(3)) else 2

        # 2. override if passed by environment
        env_version = envs.VLLM_FLASH_ATTN_VERSION
        if env_version is not None:
            if env_version not in (2, 3):
                logger.warning(
                    "Invalid VLLM_FLASH_ATTN_VERSION=%s; expected 2 or 3.",
                    env_version)
                return None
            fa_version = env_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform "
                "defaulting to FA version 2.")
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once("Cannot use FA version 3 with ALiBi, "
                                "defaulting to FA version 2.")
            fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error(
                "Cannot use FA version %d as it is not supported due to %s",
                fa_version, fa_version_unsupported_reason(fa_version))
            return None

        return fa_version

    except (ImportError, AssertionError):
        # AssertionError is still caught for backward compatibility, in case
        # a callee signals failure via `assert`.
        return None
|
|
|
|
|
|
def flash_attn_supports_fp8() -> bool:
    """Report whether FP8 FlashAttention can be used.

    FP8 is reported as usable only when ``get_flash_attn_version()`` selects
    FA3 and the device's compute-capability major is 9.
    """
    from vllm.platforms import current_platform

    # Short-circuit first: the capability is only queried once FA3 is chosen.
    if get_flash_attn_version() != 3:
        return False
    capability = current_platform.get_device_capability()
    return capability.major == 9
|