DP Enhancement (#8280)

This commit is contained in:
Cheng Wan
2025-07-24 21:36:21 -07:00
committed by GitHub
parent 28d4d47280
commit c0fb25e949
20 changed files with 665 additions and 1116 deletions

View File

@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
attn_tp_all_gather_into_tensor,
attn_tp_reduce_scatter_tensor,
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(context.attn_tp_size)),
attn_tp_all_gather_into_tensor(
hidden_states,
local_hidden_states,
)
return hidden_states
@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
].clone(),
residual,
)
attn_tp_all_gather(
list(residual.tensor_split(context.attn_tp_size)), local_residual
)
attn_tp_all_gather_into_tensor(residual, local_residual)
if context.attn_dp_size != 1:
if context.attn_tp_rank == 0:
hidden_states += residual
@@ -442,9 +440,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
*,
residual_input_mode,
):
tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
hidden_states = tensor_list[context.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
input_hidden_states = hidden_states
hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
context.attn_tp_rank
]
attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
if residual_input_mode == ScatterMode.TP_ATTN_FULL:
residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
if hidden_states.shape[0] != 0:
@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(context.attn_tp_size)),
attn_tp_all_gather_into_tensor(
hidden_states,
local_hidden_states,
)
return hidden_states, residual