Support FA3 backend for gpt-oss (#9028)
This commit is contained in:
@@ -58,7 +58,7 @@ runtime_common = [
|
|||||||
|
|
||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"sgl-kernel==0.3.4",
|
"sgl-kernel==0.3.4.post1",
|
||||||
"torch==2.8.0",
|
"torch==2.8.0",
|
||||||
"torchaudio==2.8.0",
|
"torchaudio==2.8.0",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
|
|||||||
@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
# For multi-head latent attention
|
# For multi-head latent attention
|
||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
k_rope: Optional[torch.Tensor] = None,
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
|
sinks: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
forward_batch.forward_mode.is_target_verify() and self.topk > 1
|
forward_batch.forward_mode.is_target_verify() and self.topk > 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
||||||
|
kwargs = {}
|
||||||
|
if sinks is not None:
|
||||||
|
kwargs["sinks"] = sinks
|
||||||
|
|
||||||
# Get the appropriate page table based on whether we're using local attention
|
# Get the appropriate page table based on whether we're using local attention
|
||||||
if use_local_attn:
|
if use_local_attn:
|
||||||
local_metadata = metadata.local_attn_metadata
|
local_metadata = metadata.local_attn_metadata
|
||||||
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=use_cascade_attn,
|
return_softmax_lse=use_cascade_attn,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_cascade_attn:
|
if use_cascade_attn:
|
||||||
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=True,
|
return_softmax_lse=True,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
o, _ = merge_state_v2_wrapper(
|
o, _ = merge_state_v2_wrapper(
|
||||||
o,
|
o,
|
||||||
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
# For multi-head latent attention
|
# For multi-head latent attention
|
||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
k_rope: Optional[torch.Tensor] = None,
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
|
sinks: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
causal = not layer.is_cross_attention
|
causal = not layer.is_cross_attention
|
||||||
|
|
||||||
|
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
||||||
|
kwargs = {}
|
||||||
|
if sinks is not None:
|
||||||
|
kwargs["sinks"] = sinks
|
||||||
|
|
||||||
k_descale, v_descale = None, None
|
k_descale, v_descale = None, None
|
||||||
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
|
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
|
||||||
# has corresponding quantization method so that layer.k_scale is not None,
|
# has corresponding quantization method so that layer.k_scale is not None,
|
||||||
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif use_local_attn:
|
elif use_local_attn:
|
||||||
# Use chunked (local) attention batching for self-attention
|
# Use chunked (local) attention batching for self-attention
|
||||||
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
page_table = metadata.page_table
|
page_table = metadata.page_table
|
||||||
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=use_cascade_attn,
|
return_softmax_lse=use_cascade_attn,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
if use_cascade_attn:
|
if use_cascade_attn:
|
||||||
o, softmax_lse, *rest = result
|
o, softmax_lse, *rest = result
|
||||||
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=True,
|
return_softmax_lse=True,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
o, _ = merge_state_v2(
|
o, _ = merge_state_v2(
|
||||||
|
|||||||
@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.sinks = nn.Parameter(
|
self.sinks = nn.Parameter(
|
||||||
torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
|
torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
|
|||||||
@@ -2106,10 +2106,10 @@ class ServerArgs:
|
|||||||
if model_arch in ["GptOssForCausalLM"]:
|
if model_arch in ["GptOssForCausalLM"]:
|
||||||
if self.attention_backend is None:
|
if self.attention_backend is None:
|
||||||
self.attention_backend = "triton"
|
self.attention_backend = "triton"
|
||||||
assert self.attention_backend in [
|
supported_backends = ["triton", "trtllm_mha", "fa3"]
|
||||||
"triton",
|
assert (
|
||||||
"trtllm_mha",
|
self.attention_backend in supported_backends
|
||||||
], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
|
), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
|
||||||
quantization_config = getattr(hf_config, "quantization_config", None)
|
quantization_config = getattr(hf_config, "quantization_config", None)
|
||||||
is_mxfp4_quant_format = (
|
is_mxfp4_quant_format = (
|
||||||
quantization_config is not None
|
quantization_config is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user