diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 04e3b962d..873fa8b05 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -209,6 +209,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--speculative-accept-threshold-single` | Accept a draft token if its probability in the target model is greater than this threshold. | 1.0 | | `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | 1.0 | | `--speculative-token-map` | The path of the draft model's small vocab table. | None | +| `--speculative-attention-mode` | Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'. | prefill | ## Expert parallelism diff --git a/python/sglang/srt/layers/attention/hybrid_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_attn_backend.py index bf3918c70..580a977ec 100644 --- a/python/sglang/srt/layers/attention/hybrid_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_attn_backend.py @@ -34,7 +34,7 @@ class HybridAttnBackend(AttentionBackend): Note: - decode_or_idle: Always uses decode backend - - target_verify or draft_extend: Uses decode backend if speculative_attention_backend is "decode", otherwise prefill backend + - target_verify or draft_extend: Uses decode backend if speculative_attention_mode is "decode", otherwise prefill backend - prefill: Always uses prefill backend """ if forward_mode.is_decode_or_idle(): @@ -42,8 +42,7 @@ class HybridAttnBackend(AttentionBackend): elif forward_mode.is_target_verify() or forward_mode.is_draft_extend(): return ( self.decode_backend - if self.model_runner.server_args.speculative_attention_backend - == "decode" + if self.model_runner.server_args.speculative_attention_mode == "decode" else 
self.prefill_backend ) else: @@ -57,7 +56,7 @@ class HybridAttnBackend(AttentionBackend): self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens) if ( self.model_runner.server_args.speculative_algorithm is not None - and self.model_runner.server_args.speculative_attention_backend == "prefill" + and self.model_runner.server_args.speculative_attention_mode == "prefill" ): # When speculative decoding is enabled, we need to initialize the backend # that will be used for target_verify. diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index fb6009e5b..df5ade906 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -98,7 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "sampling_backend", "speculative_accept_threshold_single", "speculative_accept_threshold_acc", - "speculative_attention_backend", + "speculative_attention_mode", "torchao_config", "triton_attention_reduce_in_fp32", "num_reserved_decode_tokens", diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 06ebf7f78..168ad9f29 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1050,7 +1050,7 @@ class DeepseekV2AttentionMLA(nn.Module): or forward_batch.forward_mode.is_draft_extend() ): # Use the specified backend for speculative operations (both verify and draft extend) - if global_server_args_dict["speculative_attention_backend"] == "decode": + if global_server_args_dict["speculative_attention_mode"] == "decode": attention_backend = global_server_args_dict["decode_attention_backend"] else: # default to prefill attention_backend = global_server_args_dict["prefill_attention_backend"] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 36d76f7ec..efe690750 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -262,7 +262,7 @@ class ServerArgs: 
speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None - speculative_attention_backend: str = "prefill" + speculative_attention_mode: str = "prefill" # Expert parallelism ep_size: int = 1 @@ -1563,11 +1563,11 @@ class ServerArgs: default=ServerArgs.speculative_token_map, ) parser.add_argument( - "--speculative-attention-backend", + "--speculative-attention-mode", type=str, choices=["prefill", "decode"], - help="Attention backend to use for speculative decoding operations (both target verify and draft extend). 'prefill' (default) or 'decode'.", - default=ServerArgs.speculative_attention_backend, + help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.", + default=ServerArgs.speculative_attention_mode, ) # Expert parallelism diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 45781aab2..3ca2f464e 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker): # Initialize decode attention backend self.draft_attn_backend = self._create_decode_backend() - # Initialize draft extend attention backend (respects speculative_attention_backend setting) + # Initialize draft extend attention backend (respects speculative_attention_mode setting) self.draft_extend_attn_backend = self._create_draft_extend_backend() self.draft_model_runner.draft_attn_backend = self.draft_attn_backend @@ -236,7 +236,7 @@ class EAGLEWorker(TpModelWorker): } backend_name = ( "decode_attention_backend" - if self.server_args.speculative_attention_backend == "decode" + if self.server_args.speculative_attention_mode == "decode" else "prefill_attention_backend" ) return self._create_backend( diff --git a/test/srt/test_hybrid_attn_backend.py 
b/test/srt/test_hybrid_attn_backend.py index cd93f434d..1574ff873 100644 --- a/test/srt/test_hybrid_attn_backend.py +++ b/test/srt/test_hybrid_attn_backend.py @@ -111,27 +111,6 @@ class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase): return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"] -class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase): - speculative_decode = True - # This eagle test uses a very small model, so the accuracy is low. - accuracy_threshold = 0.2 - - @classmethod - def get_server_args(cls): - return DEFAULT_SERVER_ARGS + [ - "--speculative-algorithm", - "EAGLE", - "--speculative-draft-model-path", - DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, - "--speculative-num-steps", - "3", - "--speculative-eagle-topk", - "2", - "--speculative-num-draft-tokens", - "4", - ] - - class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBackendBase): speculative_decode = True # This eagle test uses a very small model, so the accuracy is low. @@ -150,7 +129,7 @@ class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBacke "2", "--speculative-num-draft-tokens", "4", - "--speculative-attention-backend", + "--speculative-attention-mode", "prefill", ] @@ -173,7 +152,7 @@ class TestHybridAttnBackendSpeculativeDecodingDecodeBackend(TestHybridAttnBacken "2", "--speculative-num-draft-tokens", "4", - "--speculative-attention-backend", + "--speculative-attention-mode", "decode", ]