Fix redundant kernel in sink dtype conversion (#8966)
This commit is contained in:
@@ -247,7 +247,7 @@ class GptOssAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.sinks = nn.Parameter(
|
self.sinks = nn.Parameter(
|
||||||
torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
|
torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
|
|||||||
hidden_states, forward_batch, inner_state = intermediate_state
|
hidden_states, forward_batch, inner_state = intermediate_state
|
||||||
if inner_state is None:
|
if inner_state is None:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
|
attn_output = self.attn(*inner_state, sinks=self.sinks)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user