diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index e899bcb7e..99708135a 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -947,7 +947,7 @@ class FlashInferMultiStepDraftBackend: triton.next_power_of_2(bs), ) - for i in range(self.speculative_num_steps): + for i in range(self.speculative_num_steps - 1): forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1] forward_batch.spec_info.kv_indices = kv_indices_buffer[i][ : seq_lens_sum * self.topk + bs * (i + 1) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 45798ba58..8c24f2aa5 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -234,6 +234,10 @@ class EAGLEWorker(TpModelWorker): token_list.append(tree_info[1]) parents_list.append(tree_info[2]) + # we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here + if i == self.speculative_num_steps - 1: + break + # Set inputs forward_batch.input_ids = input_ids forward_batch.out_cache_loc = out_cache_loc[