Do layernorm before allgather for DP attention (#8631)
This commit is contained in:
@@ -404,14 +404,24 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
|||||||
if context.attn_dp_size != 1:
|
if context.attn_dp_size != 1:
|
||||||
if context.attn_tp_rank == 0:
|
if context.attn_tp_rank == 0:
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
|
|
||||||
|
# Perform layernorm on the smaller local hidden states before the DP gather communication. Only valid when attn_tp_size is 1 (tp_size == dp_size).
|
||||||
|
use_layer_norm_before_gather = context.attn_tp_size == 1
|
||||||
|
if use_layer_norm_before_gather:
|
||||||
|
residual.copy_(hidden_states)
|
||||||
|
if hidden_states.shape[0] != 0:
|
||||||
|
hidden_states = layernorm(hidden_states)
|
||||||
|
|
||||||
hidden_states, local_hidden_states = (
|
hidden_states, local_hidden_states = (
|
||||||
forward_batch.gathered_buffer,
|
forward_batch.gathered_buffer,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
)
|
)
|
||||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||||
dp_scatter(residual, hidden_states, forward_batch)
|
|
||||||
if hidden_states.shape[0] != 0:
|
if not use_layer_norm_before_gather:
|
||||||
hidden_states = layernorm(hidden_states)
|
dp_scatter(residual, hidden_states, forward_batch)
|
||||||
|
if hidden_states.shape[0] != 0:
|
||||||
|
hidden_states = layernorm(hidden_states)
|
||||||
else:
|
else:
|
||||||
# According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
|
# According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
|
||||||
# We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
|
# We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
|
||||||
|
|||||||
Reference in New Issue
Block a user