DP Enhancement (#8280)
This commit is contained in:
@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_all_gather,
|
||||
attn_tp_reduce_scatter,
|
||||
attn_tp_all_gather_into_tensor,
|
||||
attn_tp_reduce_scatter_tensor,
|
||||
dp_gather_partial,
|
||||
dp_scatter,
|
||||
get_attention_dp_size,
|
||||
@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(context.attn_tp_size)),
|
||||
attn_tp_all_gather_into_tensor(
|
||||
hidden_states,
|
||||
local_hidden_states,
|
||||
)
|
||||
return hidden_states
|
||||
@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
].clone(),
|
||||
residual,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(residual.tensor_split(context.attn_tp_size)), local_residual
|
||||
)
|
||||
attn_tp_all_gather_into_tensor(residual, local_residual)
|
||||
if context.attn_dp_size != 1:
|
||||
if context.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
@@ -442,9 +440,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
*,
|
||||
residual_input_mode,
|
||||
):
|
||||
tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
|
||||
hidden_states = tensor_list[context.attn_tp_rank]
|
||||
attn_tp_reduce_scatter(hidden_states, tensor_list)
|
||||
input_hidden_states = hidden_states
|
||||
hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
|
||||
context.attn_tp_rank
|
||||
]
|
||||
attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
|
||||
if residual_input_mode == ScatterMode.TP_ATTN_FULL:
|
||||
residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
|
||||
if hidden_states.shape[0] != 0:
|
||||
@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(context.attn_tp_size)),
|
||||
attn_tp_all_gather_into_tensor(
|
||||
hidden_states,
|
||||
local_hidden_states,
|
||||
)
|
||||
return hidden_states, residual
|
||||
|
||||
Reference in New Issue
Block a user