[hotfix] use the original implementation in 8785 (#8994)
This commit is contained in:
@@ -553,6 +553,10 @@ class CommunicateSummableTensorPairFn:
|
|||||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||||
hidden_states,
|
hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if hidden_states.data_ptr() is global_hidden_states.data_ptr():
|
||||||
|
hidden_states = torch.empty_like(hidden_states)
|
||||||
|
|
||||||
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
|
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
|
||||||
# When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
|
# When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
|
||||||
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
|
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
|
||||||
|
|||||||
Reference in New Issue
Block a user