Fix redundant kernel in sink dtype conversion (#8966)

Author:    fzyzcjy
Date:      2025-08-09 15:34:49 +08:00
Committed: GitHub
Parent:    442534aa44
Commit:    d3e67deb1b


@@ -247,7 +247,7 @@ class GptOssAttention(nn.Module):
         )
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
         )
         self.o_proj = RowParallelLinear(
self.o_proj = RowParallelLinear(
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
+        attn_output = self.attn(*inner_state, sinks=self.sinks)
         output, _ = self.o_proj(attn_output)
         return output
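
Why this removes a redundant kernel: the sinks parameter was previously allocated in params_dtype and converted with .to(torch.float32) inside every forward call, which launches a separate cast kernel per step. Allocating the parameter in torch.float32 once at construction time lets the attention call pass it through unchanged. The snippet below is a minimal standalone sketch of the before/after pattern, not the SGLang code; num_heads and the bfloat16 model dtype are illustrative assumptions.

import torch
import torch.nn as nn

num_heads = 64  # illustrative; the real value comes from the model config

# Before: parameter stored in the model dtype (e.g. bfloat16). Every forward
# pass then runs .to(torch.float32), launching an extra cast kernel per step.
sinks_before = nn.Parameter(
    torch.empty(num_heads, dtype=torch.bfloat16), requires_grad=False
)
per_step_copy = sinks_before.to(torch.float32)  # redundant kernel each call

# After: allocate the parameter in float32 once at construction time, so the
# attention call can consume it directly (sinks=self.sinks).
sinks_after = nn.Parameter(
    torch.empty(num_heads, dtype=torch.float32), requires_grad=False
)
assert sinks_after.dtype == torch.float32  # no per-forward conversion needed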