Fix DeepSeek bug causing 2.2% MMLU drop when TP!=DP (#4883)
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -1102,6 +1102,10 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
assert not (
|
||||||
|
self.attn_tp_size != 1 and self.input_is_scattered
|
||||||
|
), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@@ -1109,22 +1113,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.attn_tp_size != 1 and self.input_is_scattered:
|
|
||||||
hidden_states, local_hidden_states = (
|
|
||||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
|
||||||
hidden_states,
|
|
||||||
)
|
|
||||||
tp_all_gather(
|
|
||||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
|
||||||
)
|
|
||||||
residual, local_residual = (
|
|
||||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
|
||||||
residual,
|
|
||||||
)
|
|
||||||
tp_all_gather(
|
|
||||||
list(residual.tensor_split(self.attn_tp_size)), local_residual
|
|
||||||
)
|
|
||||||
|
|
||||||
# Gather
|
# Gather
|
||||||
if get_tensor_model_parallel_world_size() > 1:
|
if get_tensor_model_parallel_world_size() > 1:
|
||||||
# all gather and all reduce
|
# all gather and all reduce
|
||||||
@@ -1223,6 +1211,8 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
||||||
|
|
||||||
if self.is_last_layer and self.attn_tp_size != 1:
|
if self.is_last_layer and self.attn_tp_size != 1:
|
||||||
|
hidden_states += residual
|
||||||
|
residual = None
|
||||||
hidden_states, local_hidden_states = (
|
hidden_states, local_hidden_states = (
|
||||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -1230,19 +1220,11 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
tp_all_gather(
|
tp_all_gather(
|
||||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||||
)
|
)
|
||||||
residual, local_residual = (
|
|
||||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
|
||||||
residual,
|
|
||||||
)
|
|
||||||
tp_all_gather(
|
|
||||||
list(residual.tensor_split(self.attn_tp_size)), local_residual
|
|
||||||
)
|
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2Model(nn.Module):
|
class DeepseekV2Model(nn.Module):
|
||||||
|
|
||||||
fall_back_to_pt_during_load = False
|
fall_back_to_pt_during_load = False
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1296,7 +1278,10 @@ class DeepseekV2Model(nn.Module):
|
|||||||
positions, hidden_states, forward_batch, residual
|
positions, hidden_states, forward_batch, residual
|
||||||
)
|
)
|
||||||
if not forward_batch.forward_mode.is_idle():
|
if not forward_batch.forward_mode.is_idle():
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
if residual is None:
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user