Fix broken trtllm_mha attn backend with gpt-oss (#9161)

This commit is contained in:
Nicolas Castet
2025-08-13 18:40:55 -05:00
committed by GitHub
parent a027a9b4b3
commit 6b7c24712c

View File

@@ -293,8 +293,12 @@ class GptOssAttention(nn.Module):
prefix=add_prefix("qkv_proj", prefix),
)
# Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
# others can use bfloat16
attn_backend = global_server_args_dict.get("attention_backend")
sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
self.sinks = nn.Parameter(
torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
)
self.o_proj = RowParallelLinear(