diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index a95e2011a..6a7d330b5 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -137,7 +137,10 @@ class LogitsMetadata: @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): if ( - forward_batch.forward_mode.is_extend() + ( + forward_batch.forward_mode.is_extend() + or forward_batch.forward_mode.is_split_prefill() + ) and forward_batch.return_logprob and not forward_batch.forward_mode.is_target_verify() ): @@ -389,8 +392,8 @@ class LogitsProcessor(nn.Module): input_logprob_indices = None elif ( logits_metadata.forward_mode.is_extend() - and not logits_metadata.extend_return_logprob - ): + or logits_metadata.forward_mode.is_split_prefill() + ) and not logits_metadata.extend_return_logprob: # Prefill without input logprobs. if logits_metadata.padded_static_len < 0: last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 59e912022..297aef2d2 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -112,6 +112,7 @@ class ForwardMode(IntEnum): self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND or self == ForwardMode.MIXED + or self == ForwardMode.SPLIT_PREFILL ) def is_cuda_graph(self):