Fix one wasted kernel in DeepSeek and minor refactor (#6316)
@@ -1336,28 +1336,16 @@ class DeepseekV2DecoderLayer(nn.Module):
             )
 
         if self.attn_tp_size != 1:
-            if self.input_is_scattered:
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-            else:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                residual = hidden_states
-                if hidden_states.shape[0] != 0:
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-        else:
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
+            tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+            hidden_states = tensor_list[self.attn_tp_rank]
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
+            if not self.input_is_scattered:
+                residual = residual.tensor_split(self.attn_tp_size)[self.attn_tp_rank]
+
+        if hidden_states.shape[0] != 0:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
         if not (
            self._enable_moe_dense_fully_dp()
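The hunk above collapses the two attention-TP branches in DeepseekV2DecoderLayer into one. Previously, when the input was not scattered, rank 0 folded the residual into hidden_states with a standalone elementwise add before the reduce-scatter (the "wasted kernel"; presumably only rank 0 added it so the summing reduce-scatter would count the residual exactly once). Now every path reduce-scatters first, takes the local chunk of the residual, and lets the fused add inside post_attention_layernorm absorb the addition. Hoisting the shape check out of the branch also covers the attn_tp_size == 1 case, which is why the old else branch disappears. Below is a minimal single-process sketch of why the two formulations agree; fused_add_norm is a stand-in for the fused add + RMSNorm behavior of post_attention_layernorm, and the reduce-scatter collective is elided (tensor_split merely mimics each rank ending up with its slice):

import torch

def fused_add_norm(x, residual=None):
    # Stand-in for a fused add + RMSNorm: when a residual is passed,
    # the addition happens inside the (single) norm kernel.
    if residual is not None:
        residual = x + residual
        x = residual
    normed = x / (x.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-6)
    return (normed, residual) if residual is not None else normed

tp_size, tp_rank = 4, 1
hidden_states = torch.randn(8, 16)  # full (non-scattered) activations
residual = torch.randn(8, 16)

# Old path: standalone add kernel first, then scatter, then a norm call
# without a residual argument.
old_chunk = (hidden_states + residual).tensor_split(tp_size)[tp_rank]
old_out = fused_add_norm(old_chunk)

# New path: scatter first, split the residual locally, and let the fused
# add inside the norm absorb the addition -- one fewer kernel launch.
new_chunk = hidden_states.tensor_split(tp_size)[tp_rank]
new_res = residual.tensor_split(tp_size)[tp_rank]
new_out, _ = fused_add_norm(new_chunk, new_res)

assert torch.allclose(old_out, new_out)

Since tensor_split returns row-aligned views, chunk r of (hidden + residual) is elementwise identical to (chunk r of hidden) + (chunk r of residual), so moving the add after the split is exact, not merely approximate.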
@@ -1859,7 +1847,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                         q_a_proj_name in cached_a_proj
                         and kv_a_proj_name in cached_a_proj
                     ):
-
                         q_a_proj_weight = cached_a_proj[q_a_proj_name]
                         kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                         fused_weight = torch.cat(
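The second hunk is the minor refactor: it removes a stray blank line in the weight-loading path of DeepseekV2ForCausalLM, inside the branch that fuses the cached q_a_proj and kv_a_proj_with_mqa checkpoint shards. For context, here is a minimal sketch of what that branch does, with a hypothetical accumulate_and_fuse helper and illustrative DeepSeek-V3-like shapes (the real code does this inline while iterating over checkpoint tensors):

import torch

cached_a_proj = {}

def accumulate_and_fuse(name, weight, q_a_proj_name, kv_a_proj_name):
    # Cache each shard as it arrives; once both are present, concatenate
    # them so a single GEMM serves both low-rank projections.
    cached_a_proj[name] = weight
    if q_a_proj_name in cached_a_proj and kv_a_proj_name in cached_a_proj:
        q_a_proj_weight = cached_a_proj[q_a_proj_name]
        kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
        fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
        return fused_weight
    return None

q_name = "model.layers.0.self_attn.q_a_proj.weight"
kv_name = "model.layers.0.self_attn.kv_a_proj_with_mqa.weight"

# Shards may arrive in either order; fusion fires when the second lands.
assert accumulate_and_fuse(q_name, torch.randn(1536, 7168), q_name, kv_name) is None
fused = accumulate_and_fuse(kv_name, torch.randn(576, 7168), q_name, kv_name)
assert fused.shape == (1536 + 576, 7168)

Concatenating along dim 0 stacks the output dimensions of the two nn.Linear-style (out_features, in_features) weights, so one matrix multiply produces both projections.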