Fix broken trtllm_mha attn backend with gpt-oss (#9161)
This commit is contained in:
@@ -293,8 +293,12 @@ class GptOssAttention(nn.Module):
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
# Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
|
||||
# others can use bfloat16
|
||||
attn_backend = global_server_args_dict.get("attention_backend")
|
||||
sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
|
||||
self.sinks = nn.Parameter(
|
||||
torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
|
||||
torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
|
||||
Reference in New Issue
Block a user