refactor llama4 dp attention logic (#7729)
@@ -27,9 +27,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -367,7 +366,10 @@ class Llama4DecoderLayer(nn.Module):
             bias_o_proj=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+        self.config = config
+        is_moe_layer = self._is_moe_layer(layer_id)
+        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)
+
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
@@ -387,6 +389,22 @@ class Llama4DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=is_moe_layer,
+            is_previous_layer_sparse=is_previous_moe_layer,
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def _is_moe_layer(self, layer_id: int) -> bool:
+        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
+
     def forward(
         self,
         positions: torch.Tensor,
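Note on the `_is_moe_layer` helper introduced above: it preserves the original interleaving rule, marking every `interleave_moe_layer_step`-th layer as sparse. A minimal sketch of the pattern follows; the step value of 2 is illustrative only and is not taken from this diff.

# Sketch: which layer indices the _is_moe_layer predicate marks as MoE.
# interleave_moe_layer_step = 2 is an assumed example value.
def is_moe_layer(layer_id: int, interleave_moe_layer_step: int) -> bool:
    return (layer_id + 1) % interleave_moe_layer_step == 0

print([i for i in range(8) if is_moe_layer(i, 2)])  # -> [1, 3, 5, 7]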
@@ -394,57 +412,26 @@ class Llama4DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            # Self Attention
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
             )
 
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-        else:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
 
         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)
 
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
 
         return hidden_states, residual
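For readers unfamiliar with the LayerCommunicator abstraction this commit adopts: the three calls in the new forward replace, in order, the input-layernorm/residual branching, the DP-gather/TP-all-reduce plus post-attention layernorm, and the final DP scatter deleted above. The stub below is a minimal sketch of that call pattern only; it is not sglang's implementation, and the real class dispatches on LayerScatterModes to run the actual layernorms and collectives.

import torch

class StubLayerCommunicator:
    # Interface-only stand-in for sglang.srt.layers.communicator.LayerCommunicator.
    # The bodies are placeholders; the real hooks also apply the layernorms passed
    # to the constructor and the DP/TP collectives removed in this diff.
    def prepare_attn(self, hidden_states, residual, forward_batch=None):
        if residual is None:  # seed the residual stream on first use
            residual = hidden_states
        return hidden_states, residual

    def prepare_mlp(self, hidden_states, residual, forward_batch=None):
        return hidden_states, residual

    def postprocess_layer(self, hidden_states, residual, forward_batch=None):
        return hidden_states, residual

def layer_forward(comm, attn, mlp, hidden_states, residual=None):
    # Mirrors the three-phase structure of the refactored forward above.
    hidden_states, residual = comm.prepare_attn(hidden_states, residual)
    if hidden_states.shape[0] != 0:  # skip attention for empty local batches
        hidden_states = attn(hidden_states)
    hidden_states, residual = comm.prepare_mlp(hidden_states, residual)
    hidden_states = mlp(hidden_states)
    return comm.postprocess_layer(hidden_states, residual)

out, res = layer_forward(StubLayerCommunicator(), torch.relu, torch.tanh, torch.randn(4, 8))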