From d23cb9a01ed7f7e39f40e3f5ad7d271d3ac52ce2 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Mon, 10 Feb 2025 04:21:49 -0800 Subject: [PATCH] [Eagle] reduce one draft forward (#3468) --- python/sglang/srt/layers/attention/flashinfer_backend.py | 2 +- python/sglang/srt/speculative/eagle_worker.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index e899bcb7e..99708135a 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -947,7 +947,7 @@ class FlashInferMultiStepDraftBackend: triton.next_power_of_2(bs), ) - for i in range(self.speculative_num_steps): + for i in range(self.speculative_num_steps - 1): forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1] forward_batch.spec_info.kv_indices = kv_indices_buffer[i][ : seq_lens_sum * self.topk + bs * (i + 1) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 45798ba58..8c24f2aa5 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -234,6 +234,10 @@ class EAGLEWorker(TpModelWorker): token_list.append(tree_info[1]) parents_list.append(tree_info[2]) + # we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here + if i == self.speculative_num_steps - 1: + break + # Set inputs forward_batch.input_ids = input_ids forward_batch.out_cache_loc = out_cache_loc[