From f8d48fd31146f7e0c26ca24c861a3cd8856a3f9d Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Mon, 23 Jun 2025 11:23:25 -0700 Subject: [PATCH] Fix dtype for idle input in spec decoding (#7456) --- python/sglang/srt/speculative/eagle_utils.py | 6 +++--- python/sglang/srt/speculative/eagle_worker.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index cb9d86cf7..1db3448a1 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -89,14 +89,13 @@ class EagleDraftInput: cls, device: torch.device, hidden_size: int, + dtype: torch.dtype, topk: int, capture_hidden_mode: CaptureHiddenMode, ): return cls( verified_id=None, - hidden_states=torch.empty( - (0, hidden_size), device=device, dtype=torch.float32 - ), + hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype), topk_p=torch.empty((0, topk), device=device, dtype=torch.float32), topk_index=torch.empty((0, topk), device=device, dtype=torch.int64), capture_hidden_mode=capture_hidden_mode, @@ -334,6 +333,7 @@ class EagleVerifyInput: draft_input=EagleDraftInput.create_idle_input( device=batch.device, hidden_size=batch.model_config.hidden_size, + dtype=batch.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ), diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 6d0482f46..c9e57702d 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -498,6 +498,7 @@ class EAGLEWorker(TpModelWorker): batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, hidden_size=self.model_config.hidden_size, + dtype=self.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) @@ -838,6 +839,7 @@ class EAGLEWorker(TpModelWorker): batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, hidden_size=self.model_config.hidden_size, + dtype=self.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, )