Fix broken trtllm_mha attn backend with gpt-oss (#9161)
@@ -293,8 +293,12 @@ class GptOssAttention(nn.Module):
             prefix=add_prefix("qkv_proj", prefix),
         )
 
+        # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
+        # others can use bfloat16
+        attn_backend = global_server_args_dict.get("attention_backend")
+        sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
+            torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
         )
 
         self.o_proj = RowParallelLinear(
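The dtype selection above can be read in isolation as the minimal sketch below. It is not the repository's code: the module wiring is simplified, and the plain dict here is a hypothetical stand-in for sglang's global_server_args_dict, which in the real server is populated from the launch arguments.

import torch
import torch.nn as nn

# Hypothetical stand-in for sglang's global_server_args_dict.
global_server_args_dict = {"attention_backend": "trtllm_mha"}

def make_sinks(num_heads: int) -> nn.Parameter:
    # trtllm_mha consumes the attention sinks as float32; other backends accept bfloat16.
    attn_backend = global_server_args_dict.get("attention_backend")
    sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
    return nn.Parameter(torch.empty(num_heads, dtype=sinks_dtype), requires_grad=False)

sinks = make_sinks(64)
print(sinks.dtype)  # torch.float32 when the trtllm_mha backend is selected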