[gpt-oss] Add gpt-oss mxfp4 support
@@ -37,6 +37,10 @@ class FlashInferBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [64, 128, 256]
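These two queries advertise the dtypes and head sizes the FlashInfer backend accepts. A minimal sketch of how a caller might consult those values before selecting the backend; the backend_supports helper below is an assumption for illustration only, not vLLM API, and the supported values are copied from the hunk above:

import torch

def backend_supports(dtype: torch.dtype, head_size: int) -> bool:
    # Values mirror get_supported_dtypes / get_supported_head_sizes above.
    supported_dtypes = [torch.float16, torch.bfloat16]
    supported_head_sizes = [64, 128, 256]
    return dtype in supported_dtypes and head_size in supported_head_sizes

assert backend_supports(torch.bfloat16, 128)
assert not backend_supports(torch.float32, 96)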
@@ -510,10 +514,10 @@ class FlashInferImpl(AttentionImpl):
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
        use_irope: bool = False,
    ) -> None:
        if use_irope:
@@ -543,6 +547,19 @@ class FlashInferImpl(AttentionImpl):
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferImpl")

        self.sinks: Optional[torch.Tensor] = None
        if sinks is not None:
            if sinks.shape[0] != num_heads:
                raise ValueError(
                    "Sinks must have the same number of heads as the number of "
                    f"heads in the layer. Expected {num_heads}, but got "
                    f"{sinks.shape[0]}."
                )
            # Cast sinks to float32 if needed (FlashInfer requirement)
            if sinks.dtype != torch.float32:
                sinks = sinks.to(torch.float32)
            self.sinks = sinks

    def forward(
        self,
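The tensor validated above holds one learned "attention sink" logit per head: it is appended to the attention scores before the softmax, absorbs part of the probability mass, and contributes no value. A reference sketch of that computation in plain PyTorch, assuming the gpt-oss convention of dropping the sink column after normalization; this is only an illustration of the math, not the FlashInfer kernel:

import torch

def sdpa_with_sink(q, k, v, sink):
    # q: [heads, q_len, d]; k, v: [heads, kv_len, d]; sink: [heads], float32
    scale = q.shape[-1] ** -0.5
    logits = torch.einsum("hqd,hkd->hqk", q, k) * scale
    # Append the per-head sink logit as one extra column of attention scores.
    sink_col = sink.to(logits.dtype).view(-1, 1, 1).expand(-1, logits.shape[1], 1)
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)
    # The sink column has no associated value vector; drop it after softmax.
    return torch.einsum("hqk,hkd->hqd", probs[..., :-1], v)

heads, q_len, kv_len, head_dim = 8, 4, 16, 64
q = torch.randn(heads, q_len, head_dim)
k = torch.randn(heads, kv_len, head_dim)
v = torch.randn(heads, kv_len, head_dim)
out = sdpa_with_sink(q, k, v, sink=torch.zeros(heads, dtype=torch.float32))
assert out.shape == (heads, q_len, head_dim)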
@@ -553,6 +570,7 @@ class FlashInferImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.
@@ -567,6 +585,11 @@ class FlashInferImpl(AttentionImpl):
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashInferImpl")

        if attn_metadata is None:
            # Profiling run.
            return output
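Taken together with accept_output_buffer: bool = True from the first hunk, these guards mean the caller always passes a pre-allocated output buffer, fused output quantization (output_scale) is rejected for now, and a profiling run with no attention metadata simply returns that buffer untouched. A condensed sketch of just this control flow, written outside of vLLM with illustrative names:

from typing import Any, Optional
import torch

def forward_entry_guards(output: Optional[torch.Tensor],
                         output_scale: Optional[torch.Tensor],
                         attn_metadata: Optional[Any]) -> Optional[torch.Tensor]:
    assert output is not None, "Output tensor must be provided."
    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported")
    if attn_metadata is None:
        # Profiling run: nothing to compute, return the caller's buffer as-is.
        return output
    return None  # fall through to the actual attention computation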