From b730aa6b9e577670f1967b65ea9f24a32e0aca8d Mon Sep 17 00:00:00 2001 From: 996_icu <85502239+josephydu@users.noreply.github.com> Date: Tue, 21 Jan 2025 09:46:43 +0800 Subject: [PATCH] [EAGLE] Fix some boundary situation when retract reqs and req's max token = 1 (#2939) Co-authored-by: josephyou --- python/sglang/srt/managers/schedule_batch.py | 2 ++ python/sglang/srt/speculative/eagle_utils.py | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 040afe3d3..d9af81515 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1112,6 +1112,8 @@ class ScheduleBatch: self.has_grammar = any(req.grammar for req in self.reqs) self.sampling_info.filter_batch(keep_indices, new_indices) + if self.spec_info: + self.spec_info.filter_batch(new_indices) def merge_batch(self, other: "ScheduleBatch"): # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 1a324000c..ac16f6c53 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -228,6 +228,14 @@ class EAGLEDraftInput(SpecInfo): assert len(batch.extend_lens) == 1 batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id)) + def filter_batch( + self, + new_indices: torch.Tensor, + ): + self.sample_output = self.sample_output[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] + def prepare_for_decode(self, batch: ScheduleBatch): prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab) top = torch.topk(prob, self.topk, dim=-1)