Revert "Add fast decode plan for flashinfer mla" (#4008)

This commit is contained in:
Lianmin Zheng
2025-03-02 19:29:10 -08:00
committed by GitHub
parent fa56106731
commit 9e1014cf99
9 changed files with 52 additions and 156 deletions

View File

@@ -582,9 +582,6 @@ class ScheduleBatch:
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
# For decode
decode_seq_lens: List[int] = None
# For extend and mixed chunked prefill
prefix_lens: List[int] = None
extend_lens: List[int] = None
@@ -1171,10 +1168,8 @@ class ScheduleBatch:
def get_model_worker_batch(self):
if self.forward_mode.is_decode_or_idle():
decode_seq_lens = self.seq_lens.cpu()
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
decode_seq_lens = None
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
@@ -1199,7 +1194,6 @@ class ScheduleBatch:
top_logprobs_nums=self.top_logprobs_nums,
global_num_tokens=self.global_num_tokens,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
decode_seq_lens=decode_seq_lens,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
@@ -1273,9 +1267,6 @@ class ModelWorkerBatch:
global_num_tokens: Optional[List[int]]
can_run_dp_cuda_graph: bool
# For decode
decode_seq_lens: Optional[torch.Tensor]
# For extend
extend_num_tokens: Optional[int]
extend_seq_lens: Optional[List[int]]