[DP] fix: engine crash when decode batch is padded (#8995)
This commit is contained in:
@@ -408,9 +408,9 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
):
|
||||
if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
|
||||
residual, local_residual = (
|
||||
forward_batch.gathered_buffer[
|
||||
: forward_batch.input_ids.shape[0]
|
||||
].clone(),
|
||||
torch.empty_like(
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
|
||||
),
|
||||
residual,
|
||||
)
|
||||
attn_tp_all_gather_into_tensor(residual, local_residual)
|
||||
@@ -420,13 +420,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
|
||||
# Perform layernorm on the smaller, pre-gather data before communication. Only valid when attn_tp_size is 1 (tp_size == dp_size).
|
||||
use_layer_norm_before_gather = context.attn_tp_size == 1
|
||||
if use_layer_norm_before_gather:
|
||||
residual.copy_(hidden_states)
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = layernorm(hidden_states)
|
||||
|
||||
if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
|
||||
residual = hidden_states
|
||||
hidden_states = layernorm(hidden_states)
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
torch.empty_like(forward_batch.gathered_buffer),
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
@@ -552,10 +550,6 @@ class CommunicateSummableTensorPairFn:
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
if hidden_states.data_ptr() is global_hidden_states.data_ptr():
|
||||
hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
|
||||
# When using max-len padding, all_reduce is skipped after the MLP and MoE layers, and reduce-scatter is used here instead.
|
||||
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user