[DP Attention] Refactor: adding some utility functions (#9136)
This commit is contained in:
@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_global_dp_buffer,
|
||||
get_local_dp_buffer,
|
||||
)
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
|
||||
context: CommunicateContext,
|
||||
) -> torch.Tensor:
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
get_local_dp_buffer(),
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather_into_tensor(
|
||||
@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
):
|
||||
if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
|
||||
residual, local_residual = (
|
||||
torch.empty_like(
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
|
||||
),
|
||||
get_local_dp_buffer(),
|
||||
residual,
|
||||
)
|
||||
attn_tp_all_gather_into_tensor(residual, local_residual)
|
||||
@@ -424,7 +424,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
residual = hidden_states
|
||||
hidden_states = layernorm(hidden_states)
|
||||
hidden_states, local_hidden_states = (
|
||||
torch.empty_like(forward_batch.gathered_buffer),
|
||||
get_global_dp_buffer(),
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
@@ -548,7 +548,7 @@ class CommunicateSummableTensorPairFn:
|
||||
allow_reduce_scatter: bool = False,
|
||||
):
|
||||
hidden_states, global_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
get_local_dp_buffer(),
|
||||
hidden_states,
|
||||
)
|
||||
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
|
||||
@@ -569,7 +569,7 @@ class CommunicateSummableTensorPairFn:
|
||||
hidden_states += residual
|
||||
residual = None
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
get_local_dp_buffer(),
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather_into_tensor(
|
||||
|
||||
Reference in New Issue
Block a user