Fix DeepSeek bug causing 2.2% MMLU drop when TP != DP (#4883)
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -1102,6 +1102,10 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
assert not (
|
||||
self.attn_tp_size != 1 and self.input_is_scattered
|
||||
), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
@@ -1109,22 +1113,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if self.attn_tp_size != 1 and self.input_is_scattered:
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
residual, local_residual = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
residual,
|
||||
)
|
||||
tp_all_gather(
|
||||
list(residual.tensor_split(self.attn_tp_size)), local_residual
|
||||
)
|
||||
|
||||
# Gather
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
# all gather and all reduce
|
||||
@@ -1223,6 +1211,8 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
||||
|
||||
if self.is_last_layer and self.attn_tp_size != 1:
|
||||
hidden_states += residual
|
||||
residual = None
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
@@ -1230,19 +1220,11 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
residual, local_residual = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
residual,
|
||||
)
|
||||
tp_all_gather(
|
||||
list(residual.tensor_split(self.attn_tp_size)), local_residual
|
||||
)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class DeepseekV2Model(nn.Module):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
@@ -1296,7 +1278,10 @@ class DeepseekV2Model(nn.Module):
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
if residual is None:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
else:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user