Support FA3 backend for gpt-oss (#9028)
@@ -58,7 +58,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.4",
+    "sgl-kernel==0.3.4.post1",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )

+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
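The conditional keyword-argument pattern above is worth calling out: the sinks field is only added to the call when it is actually provided, so kernel builds whose flash-attention interface predates the new argument keep working. A minimal standalone sketch of the idea (the attn_func name and the inspect-based guard are illustrative assumptions, not code from this commit):

import inspect
from typing import Optional

import torch


def call_attention(attn_func, q, k_cache, v_cache, sinks: Optional[torch.Tensor] = None):
    # Collect optional, version-dependent fields in a dict rather than passing them
    # unconditionally, so an older kernel that lacks `sinks` is still callable.
    kwargs = {}
    if sinks is not None and "sinks" in inspect.signature(attn_func).parameters:
        kwargs["sinks"] = sinks
    return attn_func(q, k_cache, v_cache, **kwargs)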
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )

             if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                **kwargs,
             )
             o, _ = merge_state_v2_wrapper(
                 o,
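For context, the cascade path requests return_softmax_lse=True because merging two partial attention results needs each piece's log-sum-exp. A rough eager-mode sketch of that merge (my own illustration of the standard technique, not the merge_state_v2_wrapper implementation):

import torch


def merge_attention_states(o1, lse1, o2, lse2):
    # o1, o2: [tokens, heads, head_dim] partial outputs over disjoint key sets
    # lse1, lse2: [tokens, heads] log-sum-exp of the corresponding attention logits
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)
    w2 = torch.exp(lse2 - max_lse)
    merged = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / (w1 + w2).unsqueeze(-1)
    return merged, max_lse + torch.log(w1 + w2)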
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention

+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             )
             o, _ = merge_state_v2(
@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module):
         )

         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
         )

         self.o_proj = RowParallelLinear(
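The sinks tensor holds one learnable logit per attention head. The fused FA3 kernel consumes it directly, but the intended math can be sketched in eager PyTorch as follows (an illustration of attention sinks in general, not the kernel's code): each head's sink joins the softmax denominator without contributing a value, so it only absorbs probability mass.

import torch


def attention_with_sinks(q, k, v, sinks):
    # q, k, v: [heads, seq, head_dim]; sinks: [heads]
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("hqd,hkd->hqk", q, k) * scale
    # Append the per-head sink as an extra column of logits, softmax, then drop it.
    sink_col = sinks.view(-1, 1, 1).expand(-1, scores.shape[1], 1)
    probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)[..., :-1]
    return torch.einsum("hqk,hkd->hqd", probs, v)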
@@ -2106,10 +2106,10 @@ class ServerArgs:
         if model_arch in ["GptOssForCausalLM"]:
             if self.attention_backend is None:
                 self.attention_backend = "triton"
-            assert self.attention_backend in [
-                "triton",
-                "trtllm_mha",
-            ], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
+            supported_backends = ["triton", "trtllm_mha", "fa3"]
+            assert (
+                self.attention_backend in supported_backends
+            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
             quantization_config = getattr(hf_config, "quantization_config", None)
             is_mxfp4_quant_format = (
                 quantization_config is not None
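With the check relaxed, a gpt-oss model should accept fa3 at launch time, for example by passing --attention-backend fa3 to python -m sglang.launch_server together with the model path (the exact launch invocation is a usage assumption; adjust to your deployment).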