Fix redundant kernel in sink dtype conversion (#8966)
@@ -247,7 +247,7 @@ class GptOssAttention(nn.Module):
         )
 
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
         )
 
         self.o_proj = RowParallelLinear(
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
+        attn_output = self.attn(*inner_state, sinks=self.sinks)
         output, _ = self.o_proj(attn_output)
         return output
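
The change allocates self.sinks in torch.float32 once at construction, so the per-call sinks=self.sinks.to(torch.float32) cast in the forward path, which launched an extra dtype-conversion kernel on every attention call, can be dropped. Below is a minimal sketch of the before/after pattern, assuming plain PyTorch; SinksBefore, SinksAfter, and the params_dtype=torch.bfloat16 default are hypothetical stand-ins for illustration, not the actual GptOssAttention code.

import torch
import torch.nn as nn

class SinksBefore(nn.Module):
    # Hypothetical illustration of the old pattern: the parameter is stored
    # in the model's parameter dtype and cast to float32 on every forward.
    def __init__(self, num_heads: int, params_dtype=torch.bfloat16):
        super().__init__()
        self.sinks = nn.Parameter(
            torch.zeros(num_heads, dtype=params_dtype), requires_grad=False
        )

    def forward(self) -> torch.Tensor:
        # .to(torch.float32) allocates a new tensor and runs a cast kernel
        # on each call, even though the stored values never change.
        return self.sinks.to(torch.float32)

class SinksAfter(nn.Module):
    # Hypothetical illustration of the new pattern: the parameter is created
    # in float32 up front, so no conversion happens in the forward path.
    def __init__(self, num_heads: int):
        super().__init__()
        self.sinks = nn.Parameter(
            torch.zeros(num_heads, dtype=torch.float32), requires_grad=False
        )

    def forward(self) -> torch.Tensor:
        return self.sinks  # no cast, no extra kernel launch

Since the tensor is tiny (one scalar per head) and frozen (requires_grad=False), paying the cast once at init instead of once per forward removes a redundant kernel launch from the hot path without changing numerics.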