From 570d33437bf0b4ac42e00ad468ddc43f9e0b376f Mon Sep 17 00:00:00 2001
From: Xiaoze Fan
Date: Thu, 17 Jul 2025 01:57:46 +0800
Subject: [PATCH] [Feature] Layer-wise Prefill (#7634)

Signed-off-by: jason-fxz
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 python/sglang/srt/managers/schedule_batch.py  |  5 ++
 .../srt/model_executor/forward_batch_info.py  | 13 ++++
 .../sglang/srt/model_executor/model_runner.py | 37 ++++++++++-
 python/sglang/srt/models/gemma.py             | 48 ++++++++++++++
 python/sglang/srt/models/gemma2.py            | 51 +++++++++++++++
 python/sglang/srt/models/gemma3_causal.py     | 63 +++++++++++++++++++
 python/sglang/srt/models/llama.py             | 41 ++++++++++++
 python/sglang/srt/models/qwen.py              | 37 +++++++++++
 python/sglang/srt/models/qwen2.py             | 41 ++++++++++++
 python/sglang/srt/models/qwen2_moe.py         | 44 +++++++++++++
 python/sglang/srt/models/qwen3.py             | 42 ++++++++++++-
 python/sglang/srt/models/qwen3_moe.py         | 43 +++++++++++++
 python/sglang/srt/two_batch_overlap.py        |  1 +
 13 files changed, 464 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 1a48b0553..c2750d072 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1328,6 +1328,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.model_config.vocab_size,
         )
 
+    def prepare_for_split_prefill(self):
+        self.prepare_for_extend()
+        # For split prefill, we need to set the forward mode to SPLIT_PREFILL
+        self.forward_mode = ForwardMode.SPLIT_PREFILL
+
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
         running_bs = running_batch.batch_size()
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 7ed8eb1d4..fde60e0e5 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -68,6 +68,8 @@ class ForwardMode(IntEnum):
     MIXED = auto()
     # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence are allocated.
     IDLE = auto()
+    # Split Prefill for PD multiplexing
+    SPLIT_PREFILL = auto()
 
     # Used in speculative decoding: verify a batch in the target model.
     TARGET_VERIFY = auto()
@@ -95,6 +97,9 @@ class ForwardMode(IntEnum):
     def is_mixed(self):
         return self == ForwardMode.MIXED
 
+    def is_split_prefill(self):
+        return self == ForwardMode.SPLIT_PREFILL
+
     def is_idle(self):
         return self == ForwardMode.IDLE
 
@@ -194,6 +199,14 @@ class ForwardBatch:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
 
+    # For split prefill
+    # Intermediate values carried across split prefill forward passes
+    hidden_states: Optional[torch.Tensor] = None
+    residual: Optional[torch.Tensor] = None
+    model_specific_states: Optional[Dict[str, Any]] = None
+    # Index of the layer at which the next split forward pass resumes
+    split_index: int = 0
+
     # For MLA chunked prefix cache used in chunked prefill
     # Tell attention backend whether the kv cache needs to be attended in current pass
     attn_attend_prefix_cache: Optional[bool] = None
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index a7885a5e3..12db1d055 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1513,11 +1513,34 @@ class ModelRunner:
             **kwargs,
         )
 
+    def forward_split_prefill(
+        self,
+        forward_batch: ForwardBatch,
+        reinit_attn_backend: bool = False,
+        forward_count: int = 1,
+    ) -> Optional[LogitsProcessorOutput]:
+        if forward_batch.split_index == 0 or reinit_attn_backend:
+            self.attn_backend.init_forward_metadata(forward_batch)
+        next_split_index = min(
+            forward_batch.split_index + forward_count,
+            self.model_config.num_hidden_layers,
+        )
+        ret = self.model.forward_split_prefill(
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            (forward_batch.split_index, next_split_index),
+        )
+        forward_batch.split_index = next_split_index
+        return ret
+
     def forward(
         self,
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        reinit_attn_backend: bool = False,
+        split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         self.forward_pass_id += 1
 
@@ -1526,7 +1549,11 @@
             forward_batch,
         ):
             output = self._forward_raw(
-                forward_batch, skip_attn_backend_init, pp_proxy_tensors
+                forward_batch,
+                skip_attn_backend_init,
+                pp_proxy_tensors,
+                reinit_attn_backend,
+                split_forward_count,
             )
 
         if self.eplb_manager is not None:
@@ -1539,6 +1566,8 @@
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool,
         pp_proxy_tensors: Optional[PPProxyTensors],
+        reinit_attn_backend: bool = False,
+        split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
@@ -1559,6 +1588,12 @@
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
+        elif forward_batch.forward_mode.is_split_prefill():
+            ret = self.forward_split_prefill(
+                forward_batch,
+                reinit_attn_backend=reinit_attn_backend,
+                forward_count=split_forward_count,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py
index d8074487c..1ecb5011f 100644
--- a/python/sglang/srt/models/gemma.py
+++ b/python/sglang/srt/models/gemma.py
@@ -318,6 +318,54 @@ class GemmaForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+            # Normalize the embedding by sqrt(hidden_size)
+            forward_batch.hidden_states *= self.model.config.hidden_size**0.5
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index 9ee892bb7..ee490d083 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -381,6 +381,57 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+            # Normalize
+            normalizer = torch.tensor(
+                self.model.config.hidden_size**0.5, dtype=torch.float16
+            )
+            forward_batch.hidden_states *= normalizer
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
         if module_name in ["q_proj", "qkv_proj"]:
diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py
index f5bff8fc4..5b6145aff 100644
--- a/python/sglang/srt/models/gemma3_causal.py
+++ b/python/sglang/srt/models/gemma3_causal.py
@@ -647,6 +647,69 @@ class Gemma3ForCausalLM(PreTrainedModel):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+
+            if positions.dim() == 1:
+                positions = einops.rearrange(positions, "s -> 1 s")
+            position_embeddings_global = self.model.rotary_emb(hidden_states, positions)
+            position_embeddings_local = self.model.rotary_emb_local(
+                hidden_states, positions
+            )
+
+            forward_batch.hidden_states = hidden_states
+            forward_batch.model_specific_states = {
+                "positions": positions,
+                "position_embeddings_global": position_embeddings_global,
+                "position_embeddings_local": position_embeddings_local,
+            }
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            layer_output = layer(
+                positions=forward_batch.model_specific_states["positions"],
+                position_embeddings_global=forward_batch.model_specific_states[
+                    "position_embeddings_global"
+                ],
+                position_embeddings_local=forward_batch.model_specific_states[
+                    "position_embeddings_local"
+                ],
+                hidden_states=forward_batch.hidden_states,
+                forward_batch=forward_batch,
+            )
+            forward_batch.hidden_states = layer_output[0]
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states = self.model.norm(forward_batch.hidden_states)
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index f8cfe859b..d1614935b 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -480,6 +480,47 @@ class LlamaForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ) -> Optional[LogitsProcessorOutput]:
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py
index f0660f62d..009650411 100644
--- a/python/sglang/srt/models/qwen.py
+++ b/python/sglang/srt/models/qwen.py
@@ -15,6 +15,7 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
 
+import time
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -286,6 +287,42 @@ class QWenLMHeadModel(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            forward_batch.hidden_states = self.transformer.wte(input_ids)
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.transformer.h[i]
+            forward_batch.hidden_states = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+            )
+
+        if end == self.transformer.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states = self.transformer.ln_f(
+                forward_batch.hidden_states
+            )
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index e3670bb55..1696bdfa9 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -481,6 +481,47 @@ class Qwen2ForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 92637d73b..fe2636ab7 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -406,6 +406,7 @@ class Qwen2MoeModel(nn.Module):
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
+        self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.pp_group = get_pp_group()
@@ -554,6 +555,49 @@ class Qwen2MoeForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+        # decoder layer
+        for i in range(start, end):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.model.layers[i]
+                forward_batch.hidden_states, forward_batch.residual = layer(
+                    positions,
+                    forward_batch.hidden_states,
+                    forward_batch,
+                    forward_batch.residual,
+                )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index 9c3659839..6289e61e7 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -1,5 +1,4 @@
 # Adapted from qwen2.py
-
 import logging
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
@@ -367,6 +366,47 @@ class Qwen3ForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 7c7c7551b..75d3b475c 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -745,6 +745,49 @@ class Qwen3MoeForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+        # decoder layer
+        for i in range(start, end):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.model.layers[i]
+                forward_batch.hidden_states, forward_batch.residual = layer(
+                    positions,
+                    forward_batch.hidden_states,
+                    forward_batch,
+                    forward_batch.residual,
+                )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index fc419b03c..3fdf2a1f7 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -500,6 +500,7 @@ class TboForwardBatchPreparer:
             "capture_hidden_mode",
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
+            "split_index",  # for split prefill
         ]:
             output_dict[key] = getattr(batch, key)
         if not batch.forward_mode.is_target_verify():
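
Usage sketch (illustrative only, not part of the patch): the split-prefill path
is driven by calling ModelRunner.forward() repeatedly on a batch whose forward
mode is SPLIT_PREFILL. Each call runs up to `split_forward_count` decoder
layers starting at forward_batch.split_index, carries activations across calls
in forward_batch.hidden_states / forward_batch.residual, and returns logits
only once the final layer has run. The driver loop below is a minimal sketch
under assumptions: `batch`, `model_runner`, and the chunk size of 4 are
illustrative names and values, and the real call sites belong to the
PD-multiplexing scheduler, which this patch does not include.

    # Hypothetical driver loop for layer-wise (split) prefill.
    # Assumes `batch` is a ScheduleBatch and `model_runner` is a ModelRunner.
    batch.prepare_for_split_prefill()  # prepare_for_extend() + SPLIT_PREFILL mode
    forward_batch = ForwardBatch.init_new(batch.get_model_worker_batch(), model_runner)

    num_layers = model_runner.model_config.num_hidden_layers
    logits_output = None
    while forward_batch.split_index < num_layers:
        # Each call advances split_index by up to split_forward_count layers;
        # the scheduler may interleave other work (e.g. decode) between calls,
        # which is the point of PD multiplexing.
        logits_output, _ = model_runner.forward(
            forward_batch,
            split_forward_count=4,  # assumed chunk size
        )
    # Only the final call returns a LogitsProcessorOutput; earlier calls
    # return None while intermediate layers are still pending.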