diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index c6743e344..ab884ad9d 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -90,7 +90,7 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_batch=None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py
index 73c707508..a84f3106f 100644
--- a/python/sglang/srt/models/llama4.py
+++ b/python/sglang/srt/models/llama4.py
@@ -46,7 +46,11 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
 
@@ -81,6 +85,7 @@ class Llama4MoE(nn.Module):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()
         intermediate_size_moe = config.intermediate_size
 
         self.router = ReplicatedLinear(
@@ -113,7 +118,25 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
@@ -121,12 +144,35 @@ class Llama4MoE(nn.Module):
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        out_aD = routed_out + shared_out
+        return shared_out, routed_out
 
-        if self.tp_size > 1:
-            out_aD = tensor_model_parallel_all_reduce(out_aD)
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)
 
-        return out_aD
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream
 
 
 class Llama4Attention(nn.Module):
@@ -380,7 +426,7 @@ class Llama4DecoderLayer(nn.Module):
         )
 
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)
 
         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
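
For reviewers, here is a minimal self-contained sketch of the dual-stream overlap pattern this diff introduces. The callable names (`overlap_shared_and_routed`, `shared_expert`, `router`, `experts`) are illustrative stand-ins for the corresponding `Llama4MoE` modules, not the actual SGLang objects:

```python
import torch

_alt_stream = None


def _get_or_create_alt_stream(device_module):
    # Lazily create a single auxiliary stream and reuse it for every call.
    global _alt_stream
    if _alt_stream is None:
        _alt_stream = device_module.Stream()
    return _alt_stream


def overlap_shared_and_routed(hidden_states, shared_expert, router, experts):
    # Sketch: the shared expert runs on the default stream while the router
    # and routed experts run on an auxiliary stream, then the streams join.
    device_module = torch.get_device_module()
    alt_stream = _get_or_create_alt_stream(device_module)

    # The auxiliary stream must first see all work that produced hidden_states.
    alt_stream.wait_stream(device_module.current_stream())

    shared_out = shared_expert(hidden_states)

    with device_module.stream(alt_stream):
        router_logits = router(hidden_states)
        routed_out = experts(hidden_states, router_logits)

    # The default stream waits for the routed path before the add reads it.
    device_module.current_stream().wait_stream(alt_stream)
    return shared_out + routed_out
```

Both `wait_stream` calls matter: the first orders the routed path after whatever produced `hidden_states`, and the second orders the final add after the routed path, so the two streams never race on the same tensors. The `hidden_states.shape[0] < 4` dispatch in the diff restricts this path to small (decode-sized) batches, where the individual kernels are small enough that overlapping them can actually help.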