Use reduce scatter for DP (#8539)
@@ -12,6 +12,7 @@ import triton.language as tl
from sglang.srt.distributed import (
    GroupCoordinator,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tp_group,
    tensor_model_parallel_all_reduce,

@@ -355,6 +356,17 @@ def dp_scatter(
    )

def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    if get_tensor_model_parallel_world_size() == get_attention_dp_size():
        # Attention-TP size is 1 here, so a reduce-scatter over the full TP
        # group already leaves exactly this rank's DP shard in `output`.
        get_tp_group().reduce_scatter_tensor(output, input)
    else:
        # Reduce-scatter across the full TP group first. The output buffer is
        # this rank's slice of `input` itself (tensor_split returns views),
        # reused to avoid an extra allocation.
        scattered_local_tokens = input.tensor_split(
            get_tensor_model_parallel_world_size()
        )[get_tensor_model_parallel_rank()]
        get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)
        # Then all-gather within the attention-TP sub-group so every rank of
        # the same DP group ends up holding that group's full shard.
        get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)


def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    return get_attention_tp_group().reduce_scatter_tensor(output, input)
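
For context, here is a minimal single-process sketch of why the fallback path works: a reduce-scatter over the full TP group followed by an all-gather inside each attention-TP sub-group yields the same result as slicing the all-reduced tensor at DP granularity. The group sizes below are illustrative assumptions, not values from the commit.

```python
import torch

TP_SIZE = 4                              # assumed full tensor-parallel size
ATTN_DP_SIZE = 2                         # assumed attention-DP degree
ATTN_TP_SIZE = TP_SIZE // ATTN_DP_SIZE   # ranks per attention-TP sub-group

# Per-rank contributions that a plain all-reduce would sum.
inputs = [torch.randn(8, 4) for _ in range(TP_SIZE)]
reduced = torch.stack(inputs).sum(dim=0)

# Step 1: reduce-scatter over the full TP group -- rank r keeps shard r
# of the summed tensor.
shards = reduced.tensor_split(TP_SIZE)

# Step 2: all-gather inside each attention-TP sub-group. Concatenating the
# members' shards reconstructs that DP group's contiguous slice.
for dp in range(ATTN_DP_SIZE):
    members = range(dp * ATTN_TP_SIZE, (dp + 1) * ATTN_TP_SIZE)
    gathered = torch.cat([shards[r] for r in members])
    assert torch.equal(gathered, reduced.tensor_split(ATTN_DP_SIZE)[dp])
print("reduce-scatter + sub-group all-gather matches the DP-sliced all-reduce")
```

Compared with the previous all-reduce, each rank now receives only its 1/dp_size slice of the summed hidden states, which is what makes the reduce-scatter variant cheaper on the wire.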
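The fallback path also leans on `tensor_split` returning views, so the TP reduce-scatter writes its shard directly into `input`'s storage rather than into a fresh buffer. A tiny standalone check of that aliasing behavior:

```python
import torch

x = torch.arange(8.0)
shard = x.tensor_split(4)[1]            # a view into x (elements 2..3), not a copy
shard.copy_(torch.tensor([9.0, 9.0]))   # writing the shard mutates the original buffer
assert x[2] == 9.0 and x[3] == 9.0
```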