[gpt-oss] Add gpt-oss mxfp4 support

This commit is contained in:
2025-08-25 17:41:34 +08:00
parent db7f48eeac
commit ce688181e6
33 changed files with 4835 additions and 1192 deletions

View File

@@ -44,10 +44,25 @@ class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_VLLM_V1"
@@ -657,15 +672,12 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
sinks: Optional[torch.Tensor] = None,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -675,6 +687,8 @@ class FlashAttentionImpl(AttentionImpl):
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
elif attn_type == AttentionType.ENCODER_ONLY:
self.sliding_window = (sliding_window - 1, sliding_window - 1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
@@ -684,28 +698,33 @@ class FlashAttentionImpl(AttentionImpl):
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
FlashAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/decoder cross-attention "
"is not implemented for "
"FlashAttentionImpl")
self.use_irope = use_irope
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device.")
self.sinks = sinks
if self.sinks is not None:
assert self.vllm_flash_attn_version == 3, (
"Sinks are only supported in FlashAttention 3")
assert self.sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
"heads in the layer")
def forward(
self,
layer: torch.nn.Module,
@@ -715,6 +734,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -732,6 +752,11 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
if attn_metadata is None:
# Profiling run.
return output