From 8c5930f08a2ebc5e44409cb815ec819304fae36e Mon Sep 17 00:00:00 2001 From: cicirori <32845984+cicirori@users.noreply.github.com> Date: Mon, 8 Sep 2025 06:44:36 +0200 Subject: [PATCH] Add speculator attention backend switch (#9981) --- .../layers/attention/hybrid_attn_backend.py | 108 ++++++++++-------- python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/models/deepseek_v2.py | 9 ++ python/sglang/srt/server_args.py | 8 ++ python/sglang/srt/speculative/eagle_worker.py | 12 +- test/srt/test_hybrid_attn_backend.py | 46 ++++++++ 6 files changed, 130 insertions(+), 54 deletions(-) diff --git a/python/sglang/srt/layers/attention/hybrid_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_attn_backend.py index 30bbe6279..bf3918c70 100644 --- a/python/sglang/srt/layers/attention/hybrid_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_attn_backend.py @@ -22,17 +22,45 @@ class HybridAttnBackend(AttentionBackend): self.prefill_backend = prefill_backend self.decode_backend = decode_backend - def init_forward_metadata(self, forward_batch: ForwardBatch): - if forward_batch.forward_mode.is_decode_or_idle(): - self.decode_backend.init_forward_metadata(forward_batch) + def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend: + """ + Select the appropriate attention backend based on the forward mode. + + Args: + forward_mode: The current forward mode indicating the operation type + + Returns: + The selected attention backend (prefill or decode) + + Note: + - decode_or_idle: Always uses decode backend + - target_verify or draft_extend: Uses decode backend if speculative_attention_backend is "decode", otherwise prefill backend + - prefill: Always uses prefill backend + """ + if forward_mode.is_decode_or_idle(): + return self.decode_backend + elif forward_mode.is_target_verify() or forward_mode.is_draft_extend(): + return ( + self.decode_backend + if self.model_runner.server_args.speculative_attention_backend + == "decode" + else self.prefill_backend + ) else: - self.prefill_backend.init_forward_metadata(forward_batch) + return self.prefill_backend + + def init_forward_metadata(self, forward_batch: ForwardBatch): + backend = self._select_backend(forward_batch.forward_mode) + backend.init_forward_metadata(forward_batch) def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens) - if self.model_runner.server_args.speculative_algorithm is not None: - # When speculative decoding is enabled, we also need to initialize the - # prefill backend's cuda graph state to support target_verify. + if ( + self.model_runner.server_args.speculative_algorithm is not None + and self.model_runner.server_args.speculative_attention_backend == "prefill" + ): + # When speculative decoding is enabled, we need to initialize the backend + # that will be used for target_verify. self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens) def init_forward_metadata_capture_cuda_graph( @@ -45,26 +73,16 @@ class HybridAttnBackend(AttentionBackend): forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): - if forward_mode.is_decode_or_idle(): - self.decode_backend.init_forward_metadata_capture_cuda_graph( - bs, - num_tokens, - req_pool_indices, - seq_lens, - encoder_lens, - forward_mode, - spec_info, - ) - else: - self.prefill_backend.init_forward_metadata_capture_cuda_graph( - bs, - num_tokens, - req_pool_indices, - seq_lens, - encoder_lens, - forward_mode, - spec_info, - ) + backend = self._select_backend(forward_mode) + backend.init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) def init_forward_metadata_replay_cuda_graph( self, @@ -77,28 +95,17 @@ class HybridAttnBackend(AttentionBackend): spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], seq_lens_cpu: Optional[torch.Tensor], ): - if forward_mode.is_decode_or_idle(): - self.decode_backend.init_forward_metadata_replay_cuda_graph( - bs, - req_pool_indices, - seq_lens, - seq_lens_sum, - encoder_lens, - forward_mode, - spec_info, - seq_lens_cpu, - ) - else: - self.prefill_backend.init_forward_metadata_replay_cuda_graph( - bs, - req_pool_indices, - seq_lens, - seq_lens_sum, - encoder_lens, - forward_mode, - spec_info, - seq_lens_cpu, - ) + backend = self._select_backend(forward_mode) + backend.init_forward_metadata_replay_cuda_graph( + bs, + req_pool_indices, + seq_lens, + seq_lens_sum, + encoder_lens, + forward_mode, + spec_info, + seq_lens_cpu, + ) def get_cuda_graph_seq_len_fill_value(self): return self.decode_backend.get_cuda_graph_seq_len_fill_value() @@ -127,6 +134,7 @@ class HybridAttnBackend(AttentionBackend): save_kv_cache: bool = True, **kwargs, ): - return self.prefill_backend.forward_extend( + backend = self._select_backend(forward_batch.forward_mode) + return backend.forward_extend( q, k, v, layer, forward_batch, save_kv_cache, **kwargs ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5dc5dce39..fb6009e5b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -98,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "sampling_backend", "speculative_accept_threshold_single", "speculative_accept_threshold_acc", + "speculative_attention_backend", "torchao_config", "triton_attention_reduce_in_fp32", "num_reserved_decode_tokens", diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 252d08d8b..06ebf7f78 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1045,6 +1045,15 @@ class DeepseekV2AttentionMLA(nn.Module): # Determine attention backend used by current forward batch if forward_batch.forward_mode.is_decode_or_idle(): attention_backend = global_server_args_dict["decode_attention_backend"] + elif ( + forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend() + ): + # Use the specified backend for speculative operations (both verify and draft extend) + if global_server_args_dict["speculative_attention_backend"] == "decode": + attention_backend = global_server_args_dict["decode_attention_backend"] + else: # default to prefill + attention_backend = global_server_args_dict["prefill_attention_backend"] else: attention_backend = global_server_args_dict["prefill_attention_backend"] self.current_attention_backend = attention_backend diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 04aba8f04..36d76f7ec 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -262,6 +262,7 @@ class ServerArgs: speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None + speculative_attention_backend: str = "prefill" # Expert parallelism ep_size: int = 1 @@ -1561,6 +1562,13 @@ class ServerArgs: help="The path of the draft model's small vocab table.", default=ServerArgs.speculative_token_map, ) + parser.add_argument( + "--speculative-attention-backend", + type=str, + choices=["prefill", "decode"], + help="Attention backend to use for speculative decoding operations (both target verify and draft extend). 'prefill' (default) or 'decode'.", + default=ServerArgs.speculative_attention_backend, + ) # Expert parallelism parser.add_argument( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index daa5c30e0..45781aab2 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker): # Initialize decode attention backend self.draft_attn_backend = self._create_decode_backend() - # Initialize prefill attention backend + # Initialize draft extend attention backend (respects speculative_attention_backend setting) self.draft_extend_attn_backend = self._create_draft_extend_backend() self.draft_model_runner.draft_attn_backend = self.draft_attn_backend @@ -234,11 +234,15 @@ class EAGLEWorker(TpModelWorker): "trtllm_mha": self._create_trtllm_mha_prefill_backend, "trtllm_mla": self._create_trtllm_mla_prefill_backend, } - + backend_name = ( + "decode_attention_backend" + if self.server_args.speculative_attention_backend == "decode" + else "prefill_attention_backend" + ) return self._create_backend( - "prefill_attention_backend", + backend_name, backend_map, - "EAGLE is not supported in prefill attention backend {backend_type}", + "EAGLE is not supported in attention backend {backend_type}", ) def _create_flashinfer_decode_backend(self): diff --git a/test/srt/test_hybrid_attn_backend.py b/test/srt/test_hybrid_attn_backend.py index 9251f34dc..306259df9 100644 --- a/test/srt/test_hybrid_attn_backend.py +++ b/test/srt/test_hybrid_attn_backend.py @@ -132,5 +132,51 @@ class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase): ] +class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBackendBase): + speculative_decode = True + # This eagle test uses a very small model, so the accuracy is low. + accuracy_threshold = 0.2 + + @classmethod + def get_server_args(cls): + return DEFAULT_SERVER_ARGS + [ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "2", + "--speculative-num-draft-tokens", + "4", + "--speculative-attention-backend", + "prefill", + ] + + +class TestHybridAttnBackendSpeculativeDecodingDecodeBackend(TestHybridAttnBackendBase): + speculative_decode = True + # This eagle test uses a very small model, so the accuracy is low. + accuracy_threshold = 0.2 + + @classmethod + def get_server_args(cls): + return DEFAULT_SERVER_ARGS + [ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "2", + "--speculative-num-draft-tokens", + "4", + "--speculative-attention-backend", + "decode", + ] + + if __name__ == "__main__": unittest.main()