[gpt-oss] Add gpt-oss mxfp4 support

This commit is contained in:
2025-08-25 17:41:34 +08:00
parent db7f48eeac
commit ce688181e6
33 changed files with 4835 additions and 1192 deletions

View File

@@ -44,10 +44,25 @@ class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_VLLM_V1"
@@ -657,15 +672,12 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
sinks: Optional[torch.Tensor] = None,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -675,6 +687,8 @@ class FlashAttentionImpl(AttentionImpl):
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
elif attn_type == AttentionType.ENCODER_ONLY:
self.sliding_window = (sliding_window - 1, sliding_window - 1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
@@ -684,28 +698,33 @@ class FlashAttentionImpl(AttentionImpl):
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
FlashAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/decoder cross-attention "
"is not implemented for "
"FlashAttentionImpl")
self.use_irope = use_irope
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device.")
self.sinks = sinks
if self.sinks is not None:
assert self.vllm_flash_attn_version == 3, (
"Sinks are only supported in FlashAttention 3")
assert self.sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
"heads in the layer")
def forward(
self,
layer: torch.nn.Module,
@@ -715,6 +734,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -732,6 +752,11 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
if attn_metadata is None:
# Profiling run.
return output

View File

@@ -37,6 +37,10 @@ class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [64, 128, 256]
@@ -510,10 +514,10 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
use_irope: bool = False,
) -> None:
if use_irope:
@@ -543,6 +547,19 @@ class FlashInferImpl(AttentionImpl):
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
self.sinks: Optional[torch.Tensor] = None
if sinks is not None:
if sinks.shape[0] != num_heads:
raise ValueError(
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Expected {num_heads}, but got "
f"{sinks.shape[0]}."
)
# Cast sinks to float32 if needed (FlashInfer requirement)
if sinks.dtype != torch.float32:
sinks = sinks.to(torch.float32)
self.sinks = sinks
def forward(
self,
@@ -553,6 +570,7 @@ class FlashInferImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashInfer.
@@ -567,6 +585,11 @@ class FlashInferImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")
if attn_metadata is None:
# Profiling run.
return output

View File

@@ -44,6 +44,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [16, 32, 64, 96, 128, 160, 192, 224, 256]
@@ -346,15 +350,10 @@ class FlexAttentionImpl(AttentionImpl):
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if blocksparse_params is not None:
# TODO we should support this :think
raise ValueError(
"FlashAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -410,6 +409,7 @@ class FlexAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlexAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FLexAttention.
@@ -423,6 +423,11 @@ class FlexAttentionImpl(AttentionImpl):
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlexAttentionImpl")
enable_gqa = self.num_kv_heads != self.num_heads
if attn_metadata is None:

View File

@@ -110,7 +110,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
@@ -120,9 +119,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
logger.warning_once(
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context.")
if blocksparse_params is not None:
raise ValueError("Paged attention Pallas kernel does "
"not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -139,8 +135,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("Alibi slopes is not supported.")
if kv_cache_dtype != "auto":
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
@@ -161,6 +155,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
@@ -173,6 +168,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionBackendImpl")
# For determine_available_memory case.
if kv_cache.numel() == 0:
if output is None:

View File

@@ -3,6 +3,10 @@
"""Attention layer with PagedAttention and Triton prefix prefill."""
from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass
from functools import cache
from typing import ClassVar, Optional
import torch
from vllm import _custom_ops as ops
@@ -12,7 +16,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
@@ -38,10 +42,25 @@ class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "TRITON_ATTN_VLLM_V1"
@@ -74,6 +93,15 @@ class TritonAttentionBackend(AttentionBackend):
return TritonAttentionMetadataBuilder
@cache
def use_aiter_unified_attention() -> bool:
"""Check if aiter unified attention should be used."""
# VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set
# to 1 as default
return envs.VLLM_ROCM_USE_AITER \
and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
class TritonAttentionImpl(AttentionImpl):
def __init__(
@@ -85,16 +113,11 @@ class TritonAttentionImpl(AttentionImpl):
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
sinks: Optional[torch.Tensor] = None,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"TritonAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -113,16 +136,9 @@ class TritonAttentionImpl(AttentionImpl):
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.use_irope = use_irope
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by TritonAttention. "
f"Supported head sizes are: {support_head_sizes}.")
TritonAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
@@ -133,7 +149,23 @@ class TritonAttentionImpl(AttentionImpl):
self.fp8_dtype = current_platform.fp8_dtype()
self.force_prefill_decode_attn = \
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
if not self.force_prefill_decode_attn:
# If not using prefill decode attention, we use the Triton
# unified attention implementation.
if use_aiter_unified_attention():
logger.info_once(
"Using aiter unified attention for TritonAttentionImpl")
from aiter.ops.triton.unified_attention import (
unified_attention)
self.unified_attention = unified_attention
else:
logger.info_once(
"Using vllm unified attention for TritonAttentionImpl")
from vllm.attention.ops.triton_unified_attention import (
unified_attention)
self.unified_attention = unified_attention
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
@@ -150,6 +182,7 @@ class TritonAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -164,6 +197,11 @@ class TritonAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TritonAttentionImpl")
if attn_metadata is None:
# Profiling run.
return output
@@ -227,51 +265,41 @@ class TritonAttentionImpl(AttentionImpl):
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
query = query.reshape((num_tokens, num_heads, head_size))
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
if use_prefill_decode_attn:
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
sinks=self.sinks)
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
sinks=self.sinks,
)
else:
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
unified_attention(
self.unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,