From c8452551ceec95d098df96c3e3c7a1c30066828b Mon Sep 17 00:00:00 2001 From: ykcombat <99869808+ykcombat@users.noreply.github.com> Date: Sat, 11 Oct 2025 22:03:28 +0800 Subject: [PATCH] [Fix] Fix split prefill with fa3. (#11428) --- python/sglang/srt/layers/logits_processor.py | 9 ++++++--- python/sglang/srt/model_executor/forward_batch_info.py | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index a95e2011a..6a7d330b5 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -137,7 +137,10 @@ class LogitsMetadata: @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): if ( - forward_batch.forward_mode.is_extend() + ( + forward_batch.forward_mode.is_extend() + or forward_batch.forward_mode.is_split_prefill() + ) and forward_batch.return_logprob and not forward_batch.forward_mode.is_target_verify() ): @@ -389,8 +392,8 @@ class LogitsProcessor(nn.Module): input_logprob_indices = None elif ( logits_metadata.forward_mode.is_extend() - and not logits_metadata.extend_return_logprob - ): + or logits_metadata.forward_mode.is_split_prefill() + ) and not logits_metadata.extend_return_logprob: # Prefill without input logprobs. if logits_metadata.padded_static_len < 0: last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 59e912022..297aef2d2 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -112,6 +112,7 @@ class ForwardMode(IntEnum): self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND or self == ForwardMode.MIXED + or self == ForwardMode.SPLIT_PREFILL ) def is_cuda_graph(self):