diff --git a/python/pyproject.toml b/python/pyproject.toml
index 08b38d629..ad3a696ae 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -58,7 +58,7 @@ runtime_common = [
 
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.4",
+    "sgl-kernel==0.3.4.post1",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 785cbf1d8..2d4e4b263 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )
 
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
 
             if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
                     o,
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention
 
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                         k_descale=k_descale,
                         v_descale=v_descale,
                         return_softmax_lse=True,
+                        **kwargs,
                     )
                 )
                 o, _ = merge_state_v2(
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index e63f745ba..d04434600 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module):
         )
 
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
         )
 
         self.o_proj = RowParallelLinear(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 0c76e7d1c..b7e053fd9 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -2106,10 +2106,10 @@ class ServerArgs:
         if model_arch in ["GptOssForCausalLM"]:
             if self.attention_backend is None:
                 self.attention_backend = "triton"
-            assert self.attention_backend in [
-                "triton",
-                "trtllm_mha",
-            ], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
+            supported_backends = ["triton", "trtllm_mha", "fa3"]
+            assert (
+                self.attention_backend in supported_backends
+            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
             quantization_config = getattr(hf_config, "quantization_config", None)
             is_mxfp4_quant_format = (
                 quantization_config is not None
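
For context on the recurring `**kwargs` change: the backend builds a conditional keyword-argument dict so that `sinks` is only forwarded when a tensor is actually provided, which keeps the FA3 calls compatible with sgl-kernel builds whose `flash_attn_with_kvcache` does not yet accept the new argument. A minimal sketch of that pattern follows; the helper name and the tensor setup are illustrative assumptions, not part of the patch.

# Illustrative sketch only; `build_fa3_kwargs` and the shapes below are hypothetical.
from typing import Optional

import torch


def build_fa3_kwargs(sinks: Optional[torch.Tensor]) -> dict:
    # Mirror the diff: add the new field only when a sinks tensor exists, so an
    # older FA3 interface that lacks the `sinks` keyword is never passed it.
    kwargs = {}
    if sinks is not None:
        kwargs["sinks"] = sinks
    return kwargs


# Per the gpt_oss.py hunk, the per-head sinks parameter is stored as bfloat16.
sinks = torch.empty(64, dtype=torch.bfloat16)
assert "sinks" in build_fa3_kwargs(sinks)
assert build_fa3_kwargs(None) == {}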