Add fast decode plan for flashinfer mla (#3987)
This commit is contained in:
@@ -582,6 +582,9 @@ class ScheduleBatch:
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
|
||||
# For decode
|
||||
decode_seq_lens: List[int] = None
|
||||
|
||||
# For extend and mixed chunekd prefill
|
||||
prefix_lens: List[int] = None
|
||||
extend_lens: List[int] = None
|
||||
@@ -1168,8 +1171,10 @@ class ScheduleBatch:
|
||||
|
||||
def get_model_worker_batch(self):
|
||||
if self.forward_mode.is_decode_or_idle():
|
||||
decode_seq_lens = self.seq_lens.cpu()
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
decode_seq_lens = None
|
||||
extend_seq_lens = self.extend_lens
|
||||
extend_prefix_lens = self.prefix_lens
|
||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||
@@ -1194,6 +1199,7 @@ class ScheduleBatch:
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
decode_seq_lens=decode_seq_lens,
|
||||
extend_num_tokens=self.extend_num_tokens,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
@@ -1267,6 +1273,9 @@ class ModelWorkerBatch:
|
||||
global_num_tokens: Optional[List[int]]
|
||||
can_run_dp_cuda_graph: bool
|
||||
|
||||
# For decode
|
||||
decode_seq_lens: Optional[torch.Tensor]
|
||||
|
||||
# For extend
|
||||
extend_num_tokens: Optional[int]
|
||||
extend_seq_lens: Optional[List[int]]
|
||||
|
||||
Reference in New Issue
Block a user