Support FA3 backend for gpt-oss (#9028)

Author: Ke Bao
Date: 2025-08-14 01:41:50 +08:00
Committed by: GitHub
Parent: 4a16a71c36
Commit: 0ff6d1fce1
4 changed files with 24 additions and 6 deletions
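The change has four parts: bump the sgl-kernel pin to 0.3.4.post1, thread a new optional `sinks` tensor through FlashAttentionBackend's extend and decode paths as a conditional keyword argument, store GptOssAttention's learned sinks in bfloat16, and accept "fa3" as a supported attention backend for GptOssForCausalLM.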

python/pyproject.toml

@@ -58,7 +58,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.4",
+    "sgl-kernel==0.3.4.post1",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",

python/sglang/srt/layers/attention/flashattention_backend.py

@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
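The new `sinks` argument carries gpt-oss's learned per-head attention-sink logits. A minimal sketch of the semantics, assuming one learned sink logit per head that joins the softmax but contributes no value vector (my reading of the gpt-oss design, not the FA3 kernel code):

import torch

def attention_with_sinks(q, k, v, sinks, scale):
    # q: [heads, q_len, d]; k, v: [heads, kv_len, d]; sinks: [heads]
    logits = torch.einsum("hqd,hkd->hqk", q, k) * scale
    # One virtual sink logit per head: it absorbs probability mass in the
    # softmax denominator but is dropped before the value reduction.
    sink = sinks.view(-1, 1, 1).expand(-1, q.shape[1], 1)
    probs = torch.softmax(torch.cat([logits, sink], dim=-1), dim=-1)
    return torch.einsum("hqk,hkd->hqd", probs[..., :-1], v)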
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
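Packing `sinks` into a conditional `kwargs` dict keeps the call signature-compatible with older sgl-kernel builds: the keyword is only passed when a model actually provides sinks. A stricter variant could probe the installed kernel up front; `supports_kwarg` below is a hypothetical helper, not part of the PR:

import inspect

def supports_kwarg(fn, name: str) -> bool:
    # True if fn accepts `name` directly or via **kwargs.
    params = inspect.signature(fn).parameters
    return name in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )

Asserting supports_kwarg(flash_attn_with_kvcache, "sinks") at backend init would fail fast instead of at the first gpt-oss forward pass; the PR's approach fails only if a sinks-bearing model is actually run against an old kernel.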
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
                     o,
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                         k_descale=k_descale,
                         v_descale=v_descale,
                         return_softmax_lse=True,
+                        **kwargs,
                     )
                 )
                 o, _ = merge_state_v2(
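In both cascade paths the second call now also forwards `**kwargs` and returns the softmax log-sum-exp so the two partial results can be combined. A numerically stable sketch of the merge that merge_state_v2 performs, reconstructed from the standard log-sum-exp identity rather than the kernel source:

import torch

def merge_states(o1, lse1, o2, lse2):
    # o1, o2: [tokens, heads, d]; lse1, lse2: [tokens, heads]
    m = torch.maximum(lse1, lse2)          # guard against exp overflow
    w1, w2 = torch.exp(lse1 - m), torch.exp(lse2 - m)
    scale = (w1 + w2).unsqueeze(-1)
    o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / scale
    return o, m + torch.log(w1 + w2)       # merged output and merged lse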

python/sglang/srt/models/gpt_oss.py

@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module):
         )
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
         )
         self.o_proj = RowParallelLinear(
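The sinks parameter switches from float32 to bfloat16, presumably to match the dtype the FA3 kernel expects for this tensor. Downstream, the module would hand the parameter to its attention layer roughly like this (a hypothetical call site, not verbatim from the PR):

# GptOssAttention.forward, sketched: the learned sinks ride along with
# the regular attention inputs and reach the backend shown above.
attn_output = self.attn(q, k, v, forward_batch, sinks=self.sinks)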

python/sglang/srt/server_args.py

@@ -2106,10 +2106,10 @@ class ServerArgs:
         if model_arch in ["GptOssForCausalLM"]:
             if self.attention_backend is None:
                 self.attention_backend = "triton"
-            assert self.attention_backend in [
-                "triton",
-                "trtllm_mha",
-            ], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
+            supported_backends = ["triton", "trtllm_mha", "fa3"]
+            assert (
+                self.attention_backend in supported_backends
+            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
             quantization_config = getattr(hf_config, "quantization_config", None)
             is_mxfp4_quant_format = (