diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index aa49e4fc7..d4741144d 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -177,11 +177,24 @@ class EagleDraftInput:
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, None
 
-    def filter_batch(self, new_indices: torch.Tensor):
-        self.topk_p = self.topk_p[: len(new_indices)]
-        self.topk_index = self.topk_index[: len(new_indices)]
-        self.hidden_states = self.hidden_states[: len(new_indices)]
-        self.verified_id = self.verified_id[: len(new_indices)]
+    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+        if has_been_filtered:
+            # In eagle_utils.py:verify, the batch has already been filtered by
+            # `unfinished_index`, so we don't need to filter it again in the scheduler.
+            if len(new_indices) != len(self.topk_p):
+                logger.warning(
+                    f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
+                )
+            self.topk_p = self.topk_p[: len(new_indices)]
+            self.topk_index = self.topk_index[: len(new_indices)]
+            self.hidden_states = self.hidden_states[: len(new_indices)]
+            self.verified_id = self.verified_id[: len(new_indices)]
+        else:
+            # In some cases (e.g. draft_extend), the batch has not been filtered yet, so gather by `new_indices`.
+            self.topk_p = self.topk_p[new_indices]
+            self.topk_index = self.topk_index[new_indices]
+            self.hidden_states = self.hidden_states[new_indices]
+            self.verified_id = self.verified_id[new_indices]
 
     def merge_batch(self, spec_info: EagleDraftInput):
         if self.hidden_states is None:
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 8da0549e9..972d7182d 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -836,6 +836,21 @@ class EAGLEWorker(TpModelWorker):
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
+        has_finished, unfinished_req_index = False, []
+        for i, req in enumerate(batch.reqs):
+            if req.finished():
+                has_finished = True
+            else:
+                unfinished_req_index.append(i)
+        if has_finished:
+            unfinished_index_device = torch.tensor(
+                unfinished_req_index,
+                dtype=torch.int64,
+                device=batch.spec_info.topk_p.device,
+            )
+            batch.spec_info.filter_batch(
+                unfinished_index_device, has_been_filtered=False
+            )
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         assert isinstance(batch.spec_info, EagleDraftInput)
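Why the new `has_been_filtered` flag matters: when verify() has already compacted the per-request tensors down to the unfinished requests, only the length of `new_indices` is meaningful and plain truncation suffices; after draft_extend the tensors still hold every request, so the surviving rows must be gathered by index. Below is a minimal standalone sketch of the two modes; the `topk_p` tensor here is a hypothetical stand-in, not the sglang API.

    import torch

    topk_p = torch.tensor([0.9, 0.8, 0.7, 0.6])  # one draft probability per request
    new_indices = torch.tensor([0, 2])           # requests 1 and 3 have finished

    # has_been_filtered=True: verify() already dropped the finished rows, so the
    # tensor is compacted and truncating to len(new_indices) changes nothing.
    compacted = topk_p[new_indices]              # what verify() would have produced
    assert torch.equal(compacted[: len(new_indices)], compacted)

    # has_been_filtered=False (e.g. after draft_extend): the tensor still holds
    # all four requests, so the surviving rows must be gathered by index.
    assert torch.equal(topk_p[new_indices], torch.tensor([0.9, 0.7]))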
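On the worker side, the loop only materializes an index tensor when at least one request actually finished, so the common all-unfinished path skips the gather entirely; the diff also places the tensor on `topk_p.device` so the indexing stays on the same device as the speculative state. A self-contained sketch of that pattern, with a hypothetical `FakeReq` standing in for the scheduler's request type:

    import torch

    class FakeReq:
        """Hypothetical stand-in for a scheduler request."""
        def __init__(self, done: bool):
            self._done = done

        def finished(self) -> bool:
            return self._done

    reqs = [FakeReq(False), FakeReq(True), FakeReq(False)]  # request 1 finished
    unfinished = [i for i, r in enumerate(reqs) if not r.finished()]
    if len(unfinished) != len(reqs):  # only build the tensor when something finished
        idx = torch.tensor(unfinished, dtype=torch.int64)
        print(idx)  # tensor([0, 2]); passed to filter_batch(idx, has_been_filtered=False)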