[Fix] Fix split prefill with fa3. (#11428)
This commit is contained in:
@@ -137,7 +137,10 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
-            forward_batch.forward_mode.is_extend()
+            (
+                forward_batch.forward_mode.is_extend()
+                or forward_batch.forward_mode.is_split_prefill()
+            )
             and forward_batch.return_logprob
             and not forward_batch.forward_mode.is_target_verify()
         ):
@@ -389,8 +392,8 @@ class LogitsProcessor(nn.Module):
             input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
-            and not logits_metadata.extend_return_logprob
-        ):
+            or logits_metadata.forward_mode.is_split_prefill()
+        ) and not logits_metadata.extend_return_logprob:
             # Prefill without input logprobs.
             if logits_metadata.padded_static_len < 0:
                 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1

@@ -112,6 +112,7 @@ class ForwardMode(IntEnum):
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
+            or self == ForwardMode.SPLIT_PREFILL
         )
 
     def is_cuda_graph(self):

Reference in New Issue
Block a user