refactor llama4 dp attention logic (#7729)
This commit is contained in:
@@ -27,9 +27,8 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
dp_gather_partial,
|
||||
dp_scatter,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
@@ -367,7 +366,10 @@ class Llama4DecoderLayer(nn.Module):
|
||||
bias_o_proj=False,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
|
||||
self.config = config
|
||||
is_moe_layer = self._is_moe_layer(layer_id)
|
||||
is_previous_moe_layer = self._is_moe_layer(layer_id - 1)
|
||||
|
||||
if is_moe_layer:
|
||||
self.feed_forward = Llama4MoE(
|
||||
config=config,
|
||||
@@ -387,6 +389,22 @@ class Llama4DecoderLayer(nn.Module):
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.layer_scatter_modes = LayerScatterModes.init_new(
|
||||
layer_id=layer_id,
|
||||
num_layers=config.num_hidden_layers,
|
||||
is_layer_sparse=is_moe_layer,
|
||||
is_previous_layer_sparse=is_previous_moe_layer,
|
||||
)
|
||||
|
||||
self.layer_communicator = LayerCommunicator(
|
||||
layer_scatter_modes=self.layer_scatter_modes,
|
||||
input_layernorm=self.input_layernorm,
|
||||
post_attention_layernorm=self.post_attention_layernorm,
|
||||
)
|
||||
|
||||
def _is_moe_layer(self, layer_id: int) -> bool:
|
||||
return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -394,57 +412,26 @@ class Llama4DecoderLayer(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if hidden_states.shape[0] == 0:
|
||||
residual = hidden_states
|
||||
else:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.layer_communicator.prepare_attn(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Gather
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
# all gather and all reduce
|
||||
if self.local_dp_size != 1:
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
dp_scatter(residual, hidden_states, forward_batch)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
else:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
hidden_states, residual = self.layer_communicator.prepare_mlp(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states = self.feed_forward(hidden_states, forward_batch)
|
||||
|
||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||
# Scatter
|
||||
if self.local_dp_size != 1:
|
||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||
# be careful about this!
|
||||
hidden_states, global_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
Reference in New Issue
Block a user