[gpt-oss] Add gpt-oss mxfp4 support

2025-08-25 17:41:34 +08:00
parent db7f48eeac
commit ce688181e6
33 changed files with 4835 additions and 1192 deletions
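For context (not part of the diff below): MXFP4 is the OCP Microscaling format that packs FP4 (E2M1) elements into 32-element blocks sharing one power-of-two E8M0 scale. A minimal dequantization sketch, assuming uint8 inputs holding one 4-bit code per element; dequantize_mxfp4 is an illustrative name, not an API from this commit:

import torch

# Magnitudes representable by FP4 E2M1 (1 sign, 2 exponent, 1 mantissa bit),
# indexed by the low three bits of each code; the high bit is the sign.
_E2M1_MAGNITUDES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def dequantize_mxfp4(codes: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # codes: uint8 [num_blocks, 32] FP4 codes; scales: uint8 [num_blocks] E8M0.
    sign = torch.where(codes & 0x8 != 0, -1.0, 1.0)
    magnitude = _E2M1_MAGNITUDES[(codes & 0x7).long()]
    # E8M0 encodes 2**(e - 127); e == 255 (NaN) is not handled in this sketch.
    scale = torch.pow(2.0, scales.float() - 127.0).unsqueeze(-1)
    return sign * magnitude * scale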


@@ -37,6 +37,10 @@ class FlashInferBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [64, 128, 256]
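A hypothetical usage sketch of the two capability hooks added above; the import path is assumed, not confirmed by this diff:

import torch

from vllm.v1.attention.backends.flashinfer import FlashInferBackend  # assumed path

def flashinfer_supports(dtype: torch.dtype, head_size: int) -> bool:
    # Combine the new hooks into a single eligibility check.
    return (dtype in FlashInferBackend.get_supported_dtypes()
            and head_size in FlashInferBackend.get_supported_head_sizes())

assert flashinfer_supports(torch.bfloat16, 128)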
@@ -510,10 +514,10 @@ class FlashInferImpl(AttentionImpl):
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
        use_irope: bool = False,
    ) -> None:
        if use_irope:
@@ -543,6 +547,19 @@ class FlashInferImpl(AttentionImpl):
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
self.sinks: Optional[torch.Tensor] = None
if sinks is not None:
if sinks.shape[0] != num_heads:
raise ValueError(
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Expected {num_heads}, but got "
f"{sinks.shape[0]}."
)
# Cast sinks to float32 if needed (FlashInfer requirement)
if sinks.dtype != torch.float32:
sinks = sinks.to(torch.float32)
self.sinks = sinks
def forward(
self,
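For context, the sinks tensor holds one learned logit per head, in the style of gpt-oss attention sinks: the logit competes in the softmax but contributes no value, soaking up attention mass. A minimal unfused reference sketch, assuming [heads, len, dim] layouts (FlashInfer fuses this in-kernel; the function name is illustrative):

import torch

def sdpa_with_sinks(q, k, v, sinks):
    # q: [heads, q_len, d]; k, v: [heads, kv_len, d]; sinks: [heads].
    logits = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    # Append the sink logit as an extra column every query can attend to;
    # the float32 cast mirrors the requirement enforced in __init__ above.
    sink = sinks.float().view(-1, 1, 1).expand(-1, logits.shape[1], 1)
    probs = torch.softmax(torch.cat([logits.float(), sink], dim=-1), dim=-1)
    # Drop the sink column: it only absorbs probability mass.
    return (probs[..., :-1] @ v.float()).to(v.dtype)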
@@ -553,6 +570,7 @@ class FlashInferImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.
@@ -567,6 +585,11 @@ class FlashInferImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")
if attn_metadata is None:
# Profiling run.
return output
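For reference, output_scale is the factor a fused kernel would apply to quantize the attention output in its epilogue (e.g., to FP8) rather than in a separate pass; FlashInferImpl rejects it for now. A sketch of the unfused equivalent such a fusion would replace, assuming the quantized = value / scale convention and FP8 E4M3 output (both assumptions, not confirmed by this diff):

import torch

def quantize_output_unfused(out: torch.Tensor,
                            output_scale: torch.Tensor) -> torch.Tensor:
    # Divide by the scale and saturate to the E4M3 range before casting,
    # which is what a fused attention epilogue would do in-kernel.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    return ((out.float() / output_scale)
            .clamp(-fp8_max, fp8_max)
            .to(torch.float8_e4m3fn))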