[Fix] Fix split prefill with fa3. (#11428)
This commit is contained in:
@@ -137,7 +137,10 @@ class LogitsMetadata:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||||
if (
|
if (
|
||||||
forward_batch.forward_mode.is_extend()
|
(
|
||||||
|
forward_batch.forward_mode.is_extend()
|
||||||
|
or forward_batch.forward_mode.is_split_prefill()
|
||||||
|
)
|
||||||
and forward_batch.return_logprob
|
and forward_batch.return_logprob
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
):
|
):
|
||||||
@@ -389,8 +392,8 @@ class LogitsProcessor(nn.Module):
|
|||||||
input_logprob_indices = None
|
input_logprob_indices = None
|
||||||
elif (
|
elif (
|
||||||
logits_metadata.forward_mode.is_extend()
|
logits_metadata.forward_mode.is_extend()
|
||||||
and not logits_metadata.extend_return_logprob
|
or logits_metadata.forward_mode.is_split_prefill()
|
||||||
):
|
) and not logits_metadata.extend_return_logprob:
|
||||||
# Prefill without input logprobs.
|
# Prefill without input logprobs.
|
||||||
if logits_metadata.padded_static_len < 0:
|
if logits_metadata.padded_static_len < 0:
|
||||||
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ class ForwardMode(IntEnum):
|
|||||||
self == ForwardMode.EXTEND
|
self == ForwardMode.EXTEND
|
||||||
or self == ForwardMode.DRAFT_EXTEND
|
or self == ForwardMode.DRAFT_EXTEND
|
||||||
or self == ForwardMode.MIXED
|
or self == ForwardMode.MIXED
|
||||||
|
or self == ForwardMode.SPLIT_PREFILL
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_cuda_graph(self):
|
def is_cuda_graph(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user