[Fix] Fix split prefill with fa3. (#11428)

This commit is contained in:
ykcombat
2025-10-11 22:03:28 +08:00
committed by GitHub
parent bf3e7149be
commit c8452551ce
2 changed files with 7 additions and 3 deletions

View File

@@ -137,7 +137,10 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
-            forward_batch.forward_mode.is_extend()
+            (
+                forward_batch.forward_mode.is_extend()
+                or forward_batch.forward_mode.is_split_prefill()
+            )
             and forward_batch.return_logprob
             and not forward_batch.forward_mode.is_target_verify()
         ):
@@ -389,8 +392,8 @@ class LogitsProcessor(nn.Module):
             input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
-            and not logits_metadata.extend_return_logprob
-        ):
+            or logits_metadata.forward_mode.is_split_prefill()
+        ) and not logits_metadata.extend_return_logprob:
             # Prefill without input logprobs.
             if logits_metadata.padded_static_len < 0:
                 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1

View File

@@ -112,6 +112,7 @@ class ForwardMode(IntEnum):
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
+            or self == ForwardMode.SPLIT_PREFILL
         )

     def is_cuda_graph(self):