[gpt-oss] Add gpt-oss mxfp4 support

This commit is contained in:
2025-08-25 15:31:09 +08:00
parent db7f48eeac
commit 7a35b2f32d
32 changed files with 4835 additions and 1190 deletions

View File

@@ -179,12 +179,14 @@ class CudaPlatformBase(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str:
if use_mla:
# TODO(lucas): refactor to be more concise
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
if selected_backend == _Backend.CUTLASS_MLA or (
cls.is_device_capability(100) and selected_backend is None
and block_size == 128):
if use_v1:
logger.info_once("Using Cutlass MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
@@ -223,37 +225,101 @@ class CudaPlatformBase(Platform):
return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
if use_v1:
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
if selected_backend == _Backend.FLEX_ATTENTION:
logger.info("Using FlexAttenion backend on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
if cls.has_device_capability(100):
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)
set_kv_cache_layout("HND")
return FLASHINFER_V1
elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
return TRITON_ATTN_VLLM_V1
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.")
return FLASH_ATTN_V1
# elif selected_backend == _Backend.TREE_ATTN:
# logger.info_once("Using Tree Attention backend on V1 engine.")
# return TREE_ATTN_V1
# elif selected_backend == _Backend.XFORMERS_VLLM_V1:
# logger.info_once("Using XFormers backend on V1 engine.")
# return XFORMERS_V1
from vllm.attention.selector import is_attn_backend_supported
# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100):
# Prefer FlashInfer for V1 on Blackwell GPUs if installed
try:
import flashinfer # noqa: F401
logger.info_once(
"Using FlashInfer backend on V1 engine by default for "
"Blackwell (SM 10.0) GPUs.")
return ("vllm.v1.attention.backends."
"flashinfer.FlashInferBackend")
except ImportError:
if is_default_backend_supported := is_attn_backend_supported(
FLASHINFER_V1, head_size, dtype):
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)
logger.info_once(
"Using FlashInfer backend with HND KV cache layout on "
"V1 engine by default for Blackwell (SM 10.0) GPUs.")
set_kv_cache_layout("HND")
return FLASHINFER_V1
if not is_default_backend_supported.can_import:
logger.warning_once(
"FlashInfer failed to import for V1 engine on "
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance.")
pass
# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80):
logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
if has_sink and not cls.is_device_capability(90):
logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1
if is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype,
allow_import_error=False):
logger.info_once("Using Flash Attention backend on "
"V1 engine.")
return FLASH_ATTN_V1
# FlexAttention is the default for older GPUs
else:
logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1
assert not is_default_backend_supported
use_flex_attention_reason = {}
if not is_default_backend_supported.head_size:
use_flex_attention_reason["head_size"] = head_size
if not is_default_backend_supported.dtype:
use_flex_attention_reason["dtype"] = dtype
logger.info_once(
"Using FlexAttention backend for %s on V1 engine.",
", ".join(f"{k}={v}"
for k, v in use_flex_attention_reason.items()),
)
return FLEX_ATTENTION_V1
# Backends for V0 engine
if selected_backend == _Backend.FLASHINFER:
logger.info("Using FlashInfer backend.")
if cls.has_device_capability(100):
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)
logger.info_once(
"Using HND KV cache layout on V1 engine by default for "
"Blackwell (SM 10.0) GPUs.")
set_kv_cache_layout("HND")
return "vllm.attention.backends.flashinfer.FlashInferBackend"
elif selected_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
@@ -262,6 +328,10 @@ class CudaPlatformBase(Platform):
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
logger.info("Using DifferentialFlashAttention backend.")
return ("vllm.attention.backends.differential_flash_attn."
"DifferentialFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:
@@ -291,7 +361,7 @@ class CudaPlatformBase(Platform):
# installed.
if target_backend == _Backend.FLASH_ATTN:
try:
import flash_attn # noqa: F401
import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend, flash_attn_supports_fp8)