[DP Attention] Refactor: add utility functions (#9136)

This commit is contained in:
Cheng Wan
2025-08-13 21:08:06 -07:00
committed by GitHub
parent b3363cc1aa
commit b87aacb5c5
21 changed files with 216 additions and 159 deletions

View File

@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
get_global_dp_buffer,
get_local_dp_buffer,
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
context: CommunicateContext,
) -> torch.Tensor:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
get_local_dp_buffer(),
hidden_states,
)
attn_tp_all_gather_into_tensor(
@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
):
if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
residual, local_residual = (
torch.empty_like(
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
),
get_local_dp_buffer(),
residual,
)
attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -424,7 +424,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
residual = hidden_states
hidden_states = layernorm(hidden_states)
hidden_states, local_hidden_states = (
torch.empty_like(forward_batch.gathered_buffer),
get_global_dp_buffer(),
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -548,7 +548,7 @@ class CommunicateSummableTensorPairFn:
allow_reduce_scatter: bool = False,
):
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
get_local_dp_buffer(),
hidden_states,
)
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
@@ -569,7 +569,7 @@ class CommunicateSummableTensorPairFn:
hidden_states += residual
residual = None
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
get_local_dp_buffer(),
hidden_states,
)
attn_tp_all_gather_into_tensor(