From 5039d547724c557f915dc52e7317d7bd42271fe4 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Wed, 9 Apr 2025 05:55:14 +0800
Subject: [PATCH] Support 2x8xH100 for Llama 4 (#5159)

---
 python/sglang/srt/models/llama4.py | 96 ++++++++++++++++++++++++------
 1 file changed, 77 insertions(+), 19 deletions(-)

diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py
index b3dbd50f0..4e4ba9a1e 100644
--- a/python/sglang/srt/models/llama4.py
+++ b/python/sglang/srt/models/llama4.py
@@ -27,6 +27,13 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -38,6 +45,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
@@ -143,20 +151,24 @@ class Llama4Attention(nn.Module):
         self.hidden_size = hidden_size
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope
-        tp_size = get_tensor_model_parallel_world_size()
+
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = config.head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -183,6 +195,8 @@ class Llama4Attention(nn.Module):
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("qkv_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
 
         self.o_proj = RowParallelLinear(
@@ -191,6 +205,9 @@ class Llama4Attention(nn.Module):
             bias=bias_o_proj,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
         )
         is_neox_style = True
         is_gguf = quant_config and quant_config.get_name() == "gguf"
@@ -274,6 +291,9 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
+        self.dp_size = get_attention_dp_size()
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
 
         self.self_attn = Llama4Attention(
             config=config,
@@ -316,21 +336,58 @@ class Llama4DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Self Attention
-        if residual is None:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
+            # Self Attention
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.feed_forward(hidden_states)
+
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return hidden_states, residual
 
 
@@ -350,6 +407,7 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -385,7 +443,8 @@ class Llama4Model(nn.Module):
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states, _ = self.norm(hidden_states, residual)
 
         if len(aux_hidden_states) == 0:
             return hidden_states
@@ -394,7 +453,6 @@ class Llama4Model(nn.Module):
 
 
 class Llama4ForCausalLM(LlamaForCausalLM):
-
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
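
Context for the decoder-layer changes above (editor's sketch, not part of the patch): with DP attention enabled, each DP rank runs self-attention on only its slice of the batch using a smaller attention-TP group, the partial outputs are gathered into forward_batch.gathered_buffer so the fully tensor-parallel MLP sees the global batch, and the MLP output is scattered back to per-rank slices. The following is a minimal single-process sketch of that buffer bookkeeping with hypothetical shapes; dp_gather_partial and dp_scatter here are plain-tensor stand-ins for the collectives in sglang.srt.layers.dp_attention, not the real implementations.

import torch

dp_size, tokens_per_rank, hidden = 2, 4, 8
dp_rank = 0  # pretend we are DP rank 0

# Per-rank state after self-attention: only this rank's slice of the batch.
local_hidden = torch.randn(tokens_per_rank, hidden)
residual = local_hidden.clone()

# Stand-in for forward_batch.gathered_buffer: sized for the global batch and,
# as the in-code comment warns, reused both after gather and after scatter.
gathered_buffer = torch.zeros(dp_size * tokens_per_rank, hidden)

def dp_gather_partial(dst, src, rank):
    # Stand-in: a real implementation all-gathers src from every DP rank.
    dst[rank * tokens_per_rank : (rank + 1) * tokens_per_rank] = src

def dp_scatter(dst, src, rank):
    # Stand-in: copy this rank's slice of the global tensor into dst.
    dst.copy_(src[rank * tokens_per_rank : (rank + 1) * tokens_per_rank])

# Gather: hidden_states is rebound to the global buffer before the MLP.
hidden_states, local_hidden_states = gathered_buffer, local_hidden
dp_gather_partial(hidden_states, local_hidden_states, dp_rank)
dp_scatter(residual, hidden_states, dp_rank)  # residual stays per-rank

# The MLP (self.feed_forward) runs on the global batch and returns a fresh
# tensor, so hidden_states no longer aliases gathered_buffer afterwards.
hidden_states = torch.relu(hidden_states)

# Scatter: the leading slice of gathered_buffer is reused for the local output.
hidden_states, global_hidden_states = (
    gathered_buffer[:tokens_per_rank],
    hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, dp_rank)
print(hidden_states.shape)  # torch.Size([4, 8]) -- back to the per-rank batch

This is also why the patch passes reduce_results=False to o_proj and guards the residual add with attn_tp_rank == 0: the per-rank attention outputs stay partial until the gather step combines them, so the residual must be contributed exactly once per token rather than once per attention-TP rank.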