From 4a4772ae03c8b29834efbfa1175ba6abeafa77c9 Mon Sep 17 00:00:00 2001 From: Qiaolin Yu Date: Thu, 28 Aug 2025 01:11:42 -0700 Subject: [PATCH] Support speculative decoding in hybrid attention backend (#9573) --- .../layers/attention/hybrid_attn_backend.py | 74 +++++++++++++------ .../sglang/srt/model_executor/model_runner.py | 4 +- test/srt/test_hybrid_attn_backend.py | 31 +++++++- 3 files changed, 83 insertions(+), 26 deletions(-) diff --git a/python/sglang/srt/layers/attention/hybrid_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_attn_backend.py index b9f829e41..30bbe6279 100644 --- a/python/sglang/srt/layers/attention/hybrid_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_attn_backend.py @@ -5,6 +5,7 @@ import torch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput @@ -12,19 +13,27 @@ class HybridAttnBackend(AttentionBackend): """Support different backends for prefill and decode.""" def __init__( - self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend + self, + model_runner: ModelRunner, + prefill_backend: AttentionBackend, + decode_backend: AttentionBackend, ): + self.model_runner = model_runner self.prefill_backend = prefill_backend self.decode_backend = decode_backend def init_forward_metadata(self, forward_batch: ForwardBatch): - if forward_batch.forward_mode.is_decode(): + if forward_batch.forward_mode.is_decode_or_idle(): self.decode_backend.init_forward_metadata(forward_batch) else: self.prefill_backend.init_forward_metadata(forward_batch) def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens) + if self.model_runner.server_args.speculative_algorithm is not None: + # When speculative decoding is enabled, we also need to initialize the + # prefill backend's cuda graph state to support target_verify. + self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens) def init_forward_metadata_capture_cuda_graph( self, @@ -36,15 +45,26 @@ class HybridAttnBackend(AttentionBackend): forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): - self.decode_backend.init_forward_metadata_capture_cuda_graph( - bs, - num_tokens, - req_pool_indices, - seq_lens, - encoder_lens, - forward_mode, - spec_info, - ) + if forward_mode.is_decode_or_idle(): + self.decode_backend.init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) + else: + self.prefill_backend.init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) def init_forward_metadata_replay_cuda_graph( self, @@ -57,16 +77,28 @@ class HybridAttnBackend(AttentionBackend): spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], seq_lens_cpu: Optional[torch.Tensor], ): - self.decode_backend.init_forward_metadata_replay_cuda_graph( - bs, - req_pool_indices, - seq_lens, - seq_lens_sum, - encoder_lens, - forward_mode, - spec_info, - seq_lens_cpu, - ) + if forward_mode.is_decode_or_idle(): + self.decode_backend.init_forward_metadata_replay_cuda_graph( + bs, + req_pool_indices, + seq_lens, + seq_lens_sum, + encoder_lens, + forward_mode, + spec_info, + seq_lens_cpu, + ) + else: + self.prefill_backend.init_forward_metadata_replay_cuda_graph( + bs, + req_pool_indices, + seq_lens, + seq_lens_sum, + encoder_lens, + forward_mode, + spec_info, + seq_lens_cpu, + ) def get_cuda_graph_seq_len_fill_value(self): return self.decode_backend.get_cuda_graph_seq_len_fill_value() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8d5b7c715..bbb0a3674 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1440,14 +1440,12 @@ class ModelRunner: else self.server_args.attention_backend ) if self.decode_attention_backend_str != self.prefill_attention_backend_str: - assert ( - self.server_args.speculative_algorithm is None - ), "Currently HybridAttentionBackend does not support speculative decoding." from sglang.srt.layers.attention.hybrid_attn_backend import ( HybridAttnBackend, ) attn_backend = HybridAttnBackend( + self, decode_backend=self._get_attention_backend_from_str( self.decode_attention_backend_str ), diff --git a/test/srt/test_hybrid_attn_backend.py b/test/srt/test_hybrid_attn_backend.py index 6791447f4..a527818fd 100644 --- a/test/srt/test_hybrid_attn_backend.py +++ b/test/srt/test_hybrid_attn_backend.py @@ -7,6 +7,8 @@ import requests from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -36,7 +38,7 @@ class TestHybridAttnBackendBase(CustomTestCase): base_url = DEFAULT_URL_FOR_TEST accuracy_threshold = 0.65 # derived tests need to override this speculative_decode = False - spec_decode_threshold = 1.0 # derived spec decoding tests need to override this + spec_decode_threshold = 2.2 # derived spec decoding tests need to override this @classmethod def get_server_args(cls): @@ -49,8 +51,12 @@ class TestHybridAttnBackendBase(CustomTestCase): # please don't do this if you want to make your inference workload faster os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + if cls.speculative_decode: + model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST + else: + model = cls.model cls.process = popen_launch_server( - cls.model, + model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=cls.get_server_args(), @@ -105,5 +111,26 @@ class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase): return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"] +class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase): + speculative_decode = True + # This eagle test uses a very small model, so the accuracy is low. + accuracy_threshold = 0.2 + + @classmethod + def get_server_args(cls): + return DEFAULT_SERVER_ARGS + [ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "2", + "--speculative-num-draft-tokens", + "4", + ] + + if __name__ == "__main__": unittest.main()