From 6b7c24712cdab4c9ea332bd073a16d23f6bfc384 Mon Sep 17 00:00:00 2001
From: Nicolas Castet <26874160+nvcastet@users.noreply.github.com>
Date: Wed, 13 Aug 2025 18:40:55 -0500
Subject: [PATCH] Fix broken trtllm_mha attn backend with gpt-oss (#9161)

---
 python/sglang/srt/models/gpt_oss.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index d04434600..e15bc5dc2 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -293,8 +293,12 @@ class GptOssAttention(nn.Module):
             prefix=add_prefix("qkv_proj", prefix),
         )
 
+        # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
+        # others can use bfloat16
+        attn_backend = global_server_args_dict.get("attention_backend")
+        sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
+            torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
         )
 
         self.o_proj = RowParallelLinear(