[gpt-oss] Add gpt-oss mxfp4 support
@@ -179,12 +179,14 @@ class CudaPlatformBase(Platform):
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
-                             kv_cache_dtype, block_size, use_v1,
-                             use_mla) -> str:
+                             kv_cache_dtype, block_size, use_v1, use_mla,
+                             has_sink) -> str:
         if use_mla:
             # TODO(lucas): refactor to be more concise
             #  we should probably consider factoring out V1 here
-            if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
+            if selected_backend == _Backend.CUTLASS_MLA or (
+                    cls.is_device_capability(100) and selected_backend is None
+                    and block_size == 128):
                 if use_v1:
                     logger.info_once("Using Cutlass MLA backend on V1 engine.")
                     return ("vllm.v1.attention.backends.mla."
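Note: besides threading the new has_sink flag through the selector signature, the reworked MLA condition above now defaults to Cutlass MLA on Blackwell when no backend was explicitly requested and the KV-cache block size is 128. A minimal standalone sketch of that gating (DeviceInfo, choose_mla_backend, and the string names are illustrative, not vLLM's API):

# Illustrative sketch only: capability/block-size gating as in the hunk above.
from dataclasses import dataclass
from typing import Optional

@dataclass
class DeviceInfo:
    major: int  # compute capability major; 10 means Blackwell (SM 10.0)
    minor: int

def choose_mla_backend(selected: Optional[str], dev: DeviceInfo,
                       block_size: int) -> str:
    # An explicit selection wins; otherwise prefer Cutlass MLA when no backend
    # was requested, the GPU reports SM 10.0, and block_size is 128.
    if selected == "CUTLASS_MLA" or (
            selected is None and (dev.major, dev.minor) == (10, 0)
            and block_size == 128):
        return "cutlass_mla"
    return "flashmla"

print(choose_mla_backend(None, DeviceInfo(10, 0), 128))  # cutlass_mla
print(choose_mla_backend(None, DeviceInfo(9, 0), 128))   # flashmla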
@@ -223,37 +225,101 @@ class CudaPlatformBase(Platform):
                 return ("vllm.attention.backends."
                         "flashmla.FlashMLABackend")
         if use_v1:
+            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
+            FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
+            TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
+            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501
+
             if selected_backend == _Backend.FLASHINFER:
                 logger.info_once("Using FlashInfer backend on V1 engine.")
-                return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
-            if selected_backend == _Backend.FLEX_ATTENTION:
-                logger.info("Using FlexAttenion backend on V1 engine.")
-                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+                if cls.has_device_capability(100):
+                    from vllm.v1.attention.backends.utils import (
+                        set_kv_cache_layout)
+                    set_kv_cache_layout("HND")
+                return FLASHINFER_V1
+            elif selected_backend == _Backend.FLEX_ATTENTION:
+                logger.info_once("Using FlexAttention backend on V1 engine.")
+                return FLEX_ATTENTION_V1
+            elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                 logger.info_once("Using Triton backend on V1 engine.")
-                return ("vllm.v1.attention.backends."
-                        "triton_attn.TritonAttentionBackend")
+                return TRITON_ATTN_VLLM_V1
+            elif selected_backend == _Backend.FLASH_ATTN:
+                logger.info_once("Using Flash Attention backend on V1 engine.")
+                return FLASH_ATTN_V1
+            # elif selected_backend == _Backend.TREE_ATTN:
+            #     logger.info_once("Using Tree Attention backend on V1 engine.")
+            #     return TREE_ATTN_V1
+            # elif selected_backend == _Backend.XFORMERS_VLLM_V1:
+            #     logger.info_once("Using XFormers backend on V1 engine.")
+            #     return XFORMERS_V1
+
+            from vllm.attention.selector import is_attn_backend_supported
+
+            # Default backends for V1 engine
+            # Prefer FlashInfer for Blackwell GPUs if installed
             if cls.is_device_capability(100):
-                # Prefer FlashInfer for V1 on Blackwell GPUs if installed
-                try:
-                    import flashinfer  # noqa: F401
-                    logger.info_once(
-                        "Using FlashInfer backend on V1 engine by default for "
-                        "Blackwell (SM 10.0) GPUs.")
-                    return ("vllm.v1.attention.backends."
-                            "flashinfer.FlashInferBackend")
-                except ImportError:
+                if is_default_backend_supported := is_attn_backend_supported(
+                        FLASHINFER_V1, head_size, dtype):
+                    from vllm.v1.attention.backends.utils import (
+                        set_kv_cache_layout)
+
                     logger.info_once(
+                        "Using FlashInfer backend with HND KV cache layout on "
+                        "V1 engine by default for Blackwell (SM 10.0) GPUs.")
+                    set_kv_cache_layout("HND")
+
+                    return FLASHINFER_V1
+
+                if not is_default_backend_supported.can_import:
+                    logger.warning_once(
                         "FlashInfer failed to import for V1 engine on "
                         "Blackwell (SM 10.0) GPUs; it is recommended to "
                         "install FlashInfer for better performance.")
-                    pass
 
             # FlashAttention is the default for SM 8.0+ GPUs
             if cls.has_device_capability(80):
-                logger.info_once("Using Flash Attention backend on V1 engine.")
-                return ("vllm.v1.attention.backends."
-                        "flash_attn.FlashAttentionBackend")
+                if has_sink and not cls.is_device_capability(90):
+                    logger.info_once("Using Triton backend on V1 engine.")
+                    return TRITON_ATTN_VLLM_V1
+                if is_default_backend_supported := is_attn_backend_supported(
+                        FLASH_ATTN_V1, head_size, dtype,
+                        allow_import_error=False):
+                    logger.info_once("Using Flash Attention backend on "
+                                     "V1 engine.")
+                    return FLASH_ATTN_V1
+
+            # FlexAttention is the default for older GPUs
+            else:
+                logger.info_once("Using FlexAttention backend on V1 engine.")
+                return FLEX_ATTENTION_V1
+
+            assert not is_default_backend_supported
+
+            use_flex_attention_reason = {}
+            if not is_default_backend_supported.head_size:
+                use_flex_attention_reason["head_size"] = head_size
+            if not is_default_backend_supported.dtype:
+                use_flex_attention_reason["dtype"] = dtype
+
+            logger.info_once(
+                "Using FlexAttention backend for %s on V1 engine.",
+                ", ".join(f"{k}={v}"
+                          for k, v in use_flex_attention_reason.items()),
+            )
+            return FLEX_ATTENTION_V1
+
         # Backends for V0 engine
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
+            if cls.has_device_capability(100):
+                from vllm.v1.attention.backends.utils import (
+                    set_kv_cache_layout)
+                logger.info_once(
+                    "Using HND KV cache layout on V1 engine by default for "
+                    "Blackwell (SM 10.0) GPUs.")
+                set_kv_cache_layout("HND")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"
         elif selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
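Note: the V1 selection above falls back in a fixed order: an explicitly requested backend, then FlashInfer on Blackwell (if it imports and supports the head size and dtype), then FlashAttention on SM 8.0+, with models that use attention sinks (has_sink) routed to the Triton backend except on SM 9.0 devices, and FlexAttention as the catch-all. A condensed standalone sketch of that ladder (function and string names are illustrative, not vLLM's API):

# Illustrative sketch only: the V1 fallback order described above.
from typing import Optional

def select_v1_backend(selected: Optional[str], capability: int,
                      has_sink: bool, flashinfer_ok: bool,
                      flash_attn_ok: bool) -> str:
    if selected is not None:           # explicit request always wins
        return selected
    if capability >= 100 and flashinfer_ok:
        return "FLASHINFER"            # preferred default on Blackwell
    if capability >= 80:
        if has_sink and capability != 90:
            return "TRITON_ATTN"       # sinks fall back to Triton off SM 9.0
        if flash_attn_ok:
            return "FLASH_ATTN"
    return "FLEX_ATTENTION"            # universal fallback

assert select_v1_backend(None, 80, True, False, True) == "TRITON_ATTN"
assert select_v1_backend(None, 90, True, False, True) == "FLASH_ATTN"
assert select_v1_backend(None, 100, False, True, True) == "FLASHINFER"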
@@ -262,6 +328,10 @@ class CudaPlatformBase(Platform):
             logger.info("Using DualChunkFlashAttention backend.")
             return ("vllm.attention.backends.dual_chunk_flash_attn."
                     "DualChunkFlashAttentionBackend")
+        elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
+            logger.info("Using DifferentialFlashAttention backend.")
+            return ("vllm.attention.backends.differential_flash_attn."
+                    "DifferentialFlashAttentionBackend")
         elif selected_backend == _Backend.FLASH_ATTN:
             pass
         elif selected_backend:
@@ -291,7 +361,7 @@ class CudaPlatformBase(Platform):
         # installed.
         if target_backend == _Backend.FLASH_ATTN:
             try:
                 import flash_attn  # noqa: F401
                 import vllm.vllm_flash_attn  # noqa: F401
                 from vllm.attention.backends.flash_attn import (  # noqa: F401
                     FlashAttentionBackend, flash_attn_supports_fp8)
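Note: the final hunk touches the availability probe that decides whether the FlashAttention backend can actually be loaded before it is chosen. The general pattern is a guarded import; a small generic sketch (the module and backend names here are only examples, not vLLM's API):

# Illustrative sketch only: probing whether an optional backend package imports.
import importlib

def backend_importable(module_name: str) -> bool:
    try:
        importlib.import_module(module_name)
        return True
    except ImportError:
        return False

# Fall back when the optional kernel package is missing.
backend = "flash_attn" if backend_importable("flash_attn") else "torch_sdpa"
print(backend)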