From 8cda5a622c4502eac9181d1019e2ad6c56046af4 Mon Sep 17 00:00:00 2001
From: Qiaolin Yu
Date: Sun, 7 Sep 2025 20:55:09 -0700
Subject: [PATCH] Standalone speculative decoding (#10090)

---
 python/sglang/srt/managers/schedule_batch.py  |   2 +-
 python/sglang/srt/managers/scheduler.py       |  12 ++
 .../srt/model_executor/cuda_graph_runner.py   |  10 +-
 python/sglang/srt/server_args.py              |  22 +++-
 .../eagle_draft_extend_cuda_graph_runner.py   |   6 +-
 python/sglang/srt/speculative/eagle_worker.py |   8 ++
 python/sglang/srt/speculative/spec_info.py    |   5 +
 .../srt/speculative/standalone_worker.py      | 109 +++++++++++++++++
 python/sglang/test/test_utils.py              |   4 +
 test/srt/run_suite.py                         |   1 +
 .../test_standalone_speculative_decoding.py   | 115 ++++++++++++++++++
 11 files changed, 285 insertions(+), 9 deletions(-)
 create mode 100644 python/sglang/srt/speculative/standalone_worker.py
 create mode 100644 test/srt/test_standalone_speculative_decoding.py

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index fdef179a1..aff5eacc1 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1539,7 +1539,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)
 
-        if self.spec_algorithm.is_eagle():
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
             return
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 8daa8afe2..807c4eda9 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -349,6 +349,18 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None
 
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 8413b164b..14da84e42 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -271,7 +271,10 @@ class CudaGraphRunner:
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
-        if model_runner.spec_algorithm.is_eagle():
+        if (
+            model_runner.spec_algorithm.is_eagle()
+            or model_runner.spec_algorithm.is_standalone()
+        ):
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen")
             else:
@@ -827,7 +830,10 @@ class CudaGraphRunner:
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None
-        if self.model_runner.spec_algorithm.is_eagle():
+        if (
+            self.model_runner.spec_algorithm.is_eagle()
+            or self.model_runner.spec_algorithm.is_standalone()
+        ):
             from sglang.srt.speculative.eagle_utils import EagleVerifyInput
 
             if self.model_runner.is_draft_worker:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c7f5a69a1..04aba8f04 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -473,9 +473,14 @@ class ServerArgs:
                 # B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
                 reserved_mem = 32 * 1024
 
+            # draft model and larger cuda graph buffers
             if self.speculative_algorithm is not None:
-                # draft model and larger cuda graph buffers
-                reserved_mem += 2 * 1024
+                if self.speculative_algorithm == "STANDALONE":
+                    # Standalone speculative decoding needs more memory than other speculative
+                    # decoding algorithms since the draft model is typically larger.
+                    reserved_mem += 6 * 1024
+                else:
+                    reserved_mem += 2 * 1024
             if self.enable_dp_attention:
                 reserved_mem += 4 * 1024
@@ -704,7 +709,12 @@ class ServerArgs:
             # NEXTN shares the same implementation of EAGLE
             self.speculative_algorithm = "EAGLE"
 
-        if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
+        if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"):
+            if self.speculative_algorithm == "STANDALONE":
+                # TODO: support dp attention for standalone speculative decoding
+                assert (
+                    self.enable_dp_attention is False
+                ), "Currently standalone speculative decoding does not support dp attention."
             if self.max_running_requests is None:
                 self.max_running_requests = 48
             self.disable_overlap_schedule = True
@@ -1499,7 +1509,7 @@ class ServerArgs:
         parser.add_argument(
             "--speculative-algorithm",
             type=str,
-            choices=["EAGLE", "EAGLE3", "NEXTN"],
+            choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE"],
             help="Speculative algorithm.",
         )
         parser.add_argument(
@@ -2635,7 +2645,9 @@ def auto_choose_speculative_params(self: ServerArgs):
     """
     hf_config = self.get_hf_config()
     arch = hf_config.architectures[0]
-
+    if self.speculative_algorithm == "STANDALONE":
+        # The default value for standalone speculative decoding
+        return (3, 1, 4)
     if arch in ["LlamaForCausalLM"]:
         # The default value for llama
         return (5, 4, 8)
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index 18ab617bd..8340b0ca8 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -341,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
-        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+        if (
+            forward_batch.spec_info.hidden_states.shape[1]
+            == self.hidden_states.shape[1]
+        ):
+            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
         if forward_batch.spec_info.accept_length is not None:
             self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 2c3940943..daa5c30e0 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -730,6 +730,14 @@ class EAGLEWorker(TpModelWorker):
 
             # Set inputs
             forward_batch.input_ids = input_ids
+            # This is a temporary fix for the case where the user runs standalone
+            # speculative decoding and the draft model architecture is gpt-oss;
+            # the gpt-oss rope kernel needs out_cache_loc to be contiguous.
+            if (
+                self.server_args.speculative_algorithm == "STANDALONE"
+                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+            ):
+                out_cache_loc = out_cache_loc.contiguous()
             forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py
index af556b99c..a80963471 100644
--- a/python/sglang/srt/speculative/spec_info.py
+++ b/python/sglang/srt/speculative/spec_info.py
@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
     EAGLE3 = auto()
+    STANDALONE = auto()
 
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
     def is_eagle3(self):
         return self == SpeculativeAlgorithm.EAGLE3
 
+    def is_standalone(self):
+        return self == SpeculativeAlgorithm.STANDALONE
+
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
+            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
diff --git a/python/sglang/srt/speculative/standalone_worker.py b/python/sglang/srt/speculative/standalone_worker.py
new file mode 100644
index 000000000..b6004ea01
--- /dev/null
+++ b/python/sglang/srt/speculative/standalone_worker.py
@@ -0,0 +1,109 @@
+import logging
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
+
+if is_cuda():
+    from sgl_kernel import segment_packbits
+
+logger = logging.getLogger(__name__)
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+
+
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with patch_tensor_parallel_group(tp_group):
+        yield
+
+
+class StandaloneWorker(EAGLEWorker):
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        moe_ep_rank: int,
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+        self.page_size = server_args.page_size
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.padded_static_len = -1
+
+        # Override the context length of the draft model to be the same as the target model.
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+
+        # Do not capture cuda graph in `super().__init__()`.
+        # It will be captured later.
+        backup_disable_cuda_graph = server_args.disable_cuda_graph
+        server_args.disable_cuda_graph = True
+        # Share the allocator with the target worker; the draft and target
+        # workers still own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Load hot token ids
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # Init draft worker
+        with empty_context():
+            TpModelWorker.__init__(
+                self,
+                server_args=server_args,
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                pp_rank=0,  # FIXME
+                dp_rank=dp_rank,
+                moe_ep_rank=moe_ep_rank,
+                nccl_port=nccl_port,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+
+        # Init attention backend and cuda graphs
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+
+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 953fb76df..bd962a7f8 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -72,6 +72,10 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
 DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
+DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
+    "meta-llama/Llama-3.1-8B-Instruct"
+)
+DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
 
 # Other use cases
 DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index f4e5871de..f9f77ecdd 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -76,6 +76,7 @@ suites = {
         TestFile("test_harmony_parser.py", 20),
         TestFile("test_hidden_states.py", 55),
         TestFile("test_hybrid_attn_backend.py", 100),
+        TestFile("test_standalone_speculative_decoding.py", 250),
         TestFile("test_input_embeddings.py", 38),
         TestFile("test_io_struct.py", 8),
         TestFile("test_jinja_template_utils.py", 1),
diff --git a/test/srt/test_standalone_speculative_decoding.py b/test/srt/test_standalone_speculative_decoding.py
new file mode 100644
index 000000000..e2962b716
--- /dev/null
+++ b/test/srt/test_standalone_speculative_decoding.py
@@ -0,0 +1,115 @@
+import os
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST,
+    DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+GSM_DATASET_PATH = None
+
+
+# Default server arguments shared across all tests
+DEFAULT_SERVER_ARGS = [
+    "--trust-remote-code",
+    "--cuda-graph-max-bs",
+    "8",
+    "--speculative-algorithm",
+    "STANDALONE",
+    "--speculative-draft-model-path",
+    DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST,
+    "--speculative-num-steps",
+    "4",
+    "--speculative-eagle-topk",
+    "2",
+    "--speculative-num-draft-tokens",
+    "7",
+    "--mem-fraction-static",
+    "0.7",
+]
+
+
+class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
+
+    model = DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST
+    draft_model = DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST
+    base_url = DEFAULT_URL_FOR_TEST
+    accuracy_threshold = 0.7  # derived tests may override this
+    spec_decode_threshold = 3.6  # derived spec decoding tests may override this
+
+    @classmethod
+    def get_server_args(cls):
+        """Return the arguments for the server launch. Override in subclasses."""
+        return DEFAULT_SERVER_ARGS + ["--attention-backend", "fa3"]
+
+    @classmethod
+    def setUpClass(cls):
+        # Disable DeepGEMM precompile to make the server launch faster.
+        # Don't do this if you want to make your inference workload faster.
+        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
+        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
+        model = cls.model
+        cls.process = popen_launch_server(
+            model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=cls.get_server_args(),
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        requests.get(self.base_url + "/flush_cache")
+
+        args = SimpleNamespace(
+            num_shots=4,
+            num_questions=100,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+            data_path=GSM_DATASET_PATH,
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"{metrics=}")
+
+        # Check accuracy against the per-class threshold
+        metric_key = "accuracy"
+        self.assertGreater(metrics[metric_key], self.accuracy_threshold)
+
+        server_info = requests.get(self.base_url + "/get_server_info")
+        avg_spec_accept_length = server_info.json()["internal_states"][0][
+            "avg_spec_accept_length"
+        ]
+        print(f"{avg_spec_accept_length=}")
+        self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
+
+
+class TestStandaloneSpeculativeDecodingTriton(TestStandaloneSpeculativeDecodingBase):
+
+    @classmethod
+    def get_server_args(cls):
+        return DEFAULT_SERVER_ARGS + ["--attention-backend", "triton"]
+
+
+class TestStandaloneSpeculativeDecodingFlashinfer(
+    TestStandaloneSpeculativeDecodingBase
+):
+    @classmethod
+    def get_server_args(cls):
+        return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"]
+
+
+if __name__ == "__main__":
+    unittest.main()
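
A minimal launch sketch, assuming this patch is applied and the two Llama checkpoints from the test defaults are available. The three speculative knobs restate the (3, 1, 4) STANDALONE defaults from auto_choose_speculative_params above, so they could equally be omitted; the tests instead use (4, 2, 7) plus --mem-fraction-static 0.7 and a fixed attention backend per subclass.

# Hypothetical launch script, not part of the patch; subprocess is just one
# way to start the server, the CLI flags are the ones this patch adds.
import subprocess

cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # target model
    "--speculative-algorithm", "STANDALONE",
    "--speculative-draft-model-path", "meta-llama/Llama-3.2-1B-Instruct",
    "--speculative-num-steps", "3",         # STANDALONE defaults from
    "--speculative-eagle-topk", "1",        # auto_choose_speculative_params
    "--speculative-num-draft-tokens", "4",
]
subprocess.run(cmd, check=True)  # blocks until the server process exits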