diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 98ef0775e..6691cf944 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -247,7 +247,7 @@ class GptOssAttention(nn.Module): ) self.sinks = nn.Parameter( - torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False + torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False ) self.o_proj = RowParallelLinear( @@ -301,7 +301,7 @@ class GptOssAttention(nn.Module): hidden_states, forward_batch, inner_state = intermediate_state if inner_state is None: return hidden_states - attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32)) + attn_output = self.attn(*inner_state, sinks=self.sinks) output, _ = self.o_proj(attn_output) return output