Use reduce scatter for DP (#8539)

This commit is contained in:
Trevor Morris
2025-08-06 16:21:26 -07:00
committed by GitHub
parent 92cc32d9fc
commit c0e84297c2
6 changed files with 73 additions and 18 deletions

View File

@@ -12,6 +12,7 @@ import triton.language as tl
from sglang.srt.distributed import (
GroupCoordinator,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
@@ -355,6 +356,17 @@ def dp_scatter(
)
def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    """Reduce-scatter `input` across the tensor-parallel group into `output`.

    Fast path: when the TP world size equals the attention data-parallel
    size, a single reduce-scatter over the TP group produces `output`
    directly.  Otherwise, reduce-scatter over the full TP group into this
    rank's shard of `input`, then all-gather those shards within the
    attention-TP group to assemble `output`.
    NOTE(review): assumes `input` splits evenly into per-rank shards whose
    concatenation over the attention-TP group matches `output`'s shape —
    confirm against the project's distributed group layout.
    """
    tp_world_size = get_tensor_model_parallel_world_size()

    if tp_world_size == get_attention_dp_size():
        get_tp_group().reduce_scatter_tensor(output, input)
        return

    # Slow path: land the TP-wide reduction in this rank's shard first.
    local_shard = input.tensor_split(tp_world_size)[
        get_tensor_model_parallel_rank()
    ]
    get_tp_group().reduce_scatter_tensor(local_shard, input)
    get_attention_tp_group().all_gather_into_tensor(output, local_shard)
def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    """Reduce-scatter `input` into `output` within the attention-TP group.

    Thin wrapper delegating to the attention tensor-parallel group's
    `reduce_scatter_tensor`; returns whatever that call returns.
    """
    attn_tp_group = get_attention_tp_group()
    return attn_tp_group.reduce_scatter_tensor(output, input)