Use reduce scatter for DP (#8539)

This commit is contained in:
Trevor Morris
2025-08-06 16:21:26 -07:00
committed by GitHub
parent 92cc32d9fc
commit c0e84297c2
6 changed files with 73 additions and 18 deletions

View File

@@ -628,8 +628,10 @@ class ForwardBatch:
self.dp_padding_mode = dp_padding_mode
if dp_padding_mode.is_max_len():
# when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
# where transferred tokens should be padded to the same length.
# when DP gather mode is all gather, we will use
# all_gather_into_tensor to gather hidden states, where transferred
# tokens should be padded to the same length. We will also use
# reduce-scatter instead of all-reduce after MLP.
max_num_tokens = max(global_num_tokens)
global_num_tokens = [max_num_tokens] * sync_group_size
buffer_len = max_num_tokens * sync_group_size