diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index d04434600..e15bc5dc2 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -293,8 +293,12 @@ class GptOssAttention(nn.Module): prefix=add_prefix("qkv_proj", prefix), ) + # Choose dtype of sinks based on attention backend: trtllm_mha requires float32, + # others can use bfloat16 + attn_backend = global_server_args_dict.get("attention_backend") + sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16 self.sinks = nn.Parameter( - torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False + torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False ) self.o_proj = RowParallelLinear(