Use reduce scatter for DP (#8539)
This commit is contained in:
@@ -628,8 +628,10 @@ class ForwardBatch:
|
||||
self.dp_padding_mode = dp_padding_mode
|
||||
|
||||
if dp_padding_mode.is_max_len():
|
||||
- # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
|
||||
- # where transferred tokens should be padded to the same length.
|
||||
+ # when DP gather mode is all gather, we will use
|
||||
+ # all_gather_into_tensor to gather hidden states, where transferred
|
||||
+ # tokens should be padded to the same length. We will also use
|
||||
+ # reduce-scatter instead of all-reduce after MLP.
|
||||
max_num_tokens = max(global_num_tokens)
|
||||
global_num_tokens = [max_num_tokens] * sync_group_size
|
||||
buffer_len = max_num_tokens * sync_group_size
|
||||
|
||||
Reference in New Issue
Block a user