Standalone speculative decoding (#10090)

This commit is contained in:
Qiaolin Yu
2025-09-07 20:55:09 -07:00
committed by GitHub
parent 400d3b97ae
commit 8cda5a622c
11 changed files with 285 additions and 9 deletions

View File

@@ -1539,7 +1539,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.forward_mode = ForwardMode.DECODE
bs = len(self.reqs)
if self.spec_algorithm.is_eagle():
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
# if spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running draft models.
return

View File

@@ -349,6 +349,18 @@ class Scheduler(
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_standalone():
from sglang.srt.speculative.standalone_worker import StandaloneWorker
self.draft_worker = StandaloneWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
else:
self.draft_worker = None

View File

@@ -271,7 +271,10 @@ class CudaGraphRunner:
self.capture_forward_mode = ForwardMode.DECODE
self.capture_hidden_mode = CaptureHiddenMode.NULL
self.num_tokens_per_bs = 1
if model_runner.spec_algorithm.is_eagle():
if (
model_runner.spec_algorithm.is_eagle()
or model_runner.spec_algorithm.is_standalone()
):
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen")
else:
@@ -827,7 +830,10 @@ class CudaGraphRunner:
def get_spec_info(self, num_tokens: int):
spec_info = None
if self.model_runner.spec_algorithm.is_eagle():
if (
self.model_runner.spec_algorithm.is_eagle()
or self.model_runner.spec_algorithm.is_standalone()
):
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
if self.model_runner.is_draft_worker:

View File

@@ -473,9 +473,14 @@ class ServerArgs:
# B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
reserved_mem = 32 * 1024
# draft model and larger cuda graph buffers
if self.speculative_algorithm is not None:
# draft model and larger cuda graph buffers
reserved_mem += 2 * 1024
if self.speculative_algorithm == "STANDALONE":
# Standalone speculative decoding needs more memory than other speculative
# decoding algorithms since the draft model is typically larger.
reserved_mem += 6 * 1024
else:
reserved_mem += 2 * 1024
if self.enable_dp_attention:
reserved_mem += 4 * 1024
@@ -704,7 +709,12 @@ class ServerArgs:
# NEXTN shares the same implementation of EAGLE
self.speculative_algorithm = "EAGLE"
if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"):
if self.speculative_algorithm == "STANDALONE":
# TODO: support dp attention for standalone speculative decoding
assert (
self.enable_dp_attention is False
), "Currently standalone speculative decoding does not support dp attention."
if self.max_running_requests is None:
self.max_running_requests = 48
self.disable_overlap_schedule = True
@@ -1499,7 +1509,7 @@ class ServerArgs:
parser.add_argument(
"--speculative-algorithm",
type=str,
choices=["EAGLE", "EAGLE3", "NEXTN"],
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE"],
help="Speculative algorithm.",
)
parser.add_argument(
@@ -2635,7 +2645,9 @@ def auto_choose_speculative_params(self: ServerArgs):
"""
hf_config = self.get_hf_config()
arch = hf_config.architectures[0]
if self.speculative_algorithm == "STANDALONE":
# The default value for standalone speculative decoding
return (3, 1, 4)
if arch in ["LlamaForCausalLM"]:
# The default value for llama
return (5, 4, 8)

View File

@@ -341,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
self.positions[:num_tokens].copy_(forward_batch.positions)
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
if (
forward_batch.spec_info.hidden_states.shape[1]
== self.hidden_states.shape[1]
):
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
if forward_batch.spec_info.accept_length is not None:
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

View File

@@ -730,6 +730,14 @@ class EAGLEWorker(TpModelWorker):
# Set inputs
forward_batch.input_ids = input_ids
# This is a temporary fix for the case that the user is using standalone
# speculative decoding and the draft model architecture is gpt-oss. gpt-oss
# rope kernel needs cache_loc to be contiguous.
if (
self.server_args.speculative_algorithm == "STANDALONE"
and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
):
out_cache_loc = out_cache_loc.contiguous()
forward_batch.out_cache_loc = out_cache_loc[i]
forward_batch.positions.add_(1)
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]

View File

@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
NONE = auto()
EAGLE = auto()
EAGLE3 = auto()
STANDALONE = auto()
def is_none(self):
return self == SpeculativeAlgorithm.NONE
@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
def is_eagle3(self):
return self == SpeculativeAlgorithm.EAGLE3
def is_standalone(self):
return self == SpeculativeAlgorithm.STANDALONE
@staticmethod
def from_string(name: str):
name_map = {
"EAGLE": SpeculativeAlgorithm.EAGLE,
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
"STANDALONE": SpeculativeAlgorithm.STANDALONE,
None: SpeculativeAlgorithm.NONE,
}
if name is not None:

View File

@@ -0,0 +1,109 @@
import logging
from contextlib import contextmanager
from typing import Optional
import torch
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
if is_cuda():
from sgl_kernel import segment_packbits
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Value of the RETURN_ORIGINAL_LOGPROB environment variable, read once at import
# time through get_bool_env_var (presumably parsed as a boolean — confirm).
# NOTE(review): not referenced in the visible part of this file; confirm it is needed.
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    """Context manager that wraps its body in ``patch_tensor_parallel_group``.

    Args:
        tp_group: The group handed to ``patch_tensor_parallel_group`` for the
            duration of the ``with`` block.

    Yields:
        None. Control returns to the caller while the patch is active.
    """
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    patched_group = patch_tensor_parallel_group(tp_group)
    with patched_group:
        yield
class StandaloneWorker(EAGLEWorker):
    """Draft worker used when the speculative algorithm is STANDALONE.

    Subclasses ``EAGLEWorker`` but initializes itself by calling
    ``TpModelWorker.__init__`` directly, so ``EAGLEWorker.__init__`` is not run.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        """Set up the draft worker next to an existing target worker.

        Note that ``server_args`` is mutated in place: ``context_length`` is
        overwritten, ``disable_cuda_graph`` is forced to ``True`` while the
        base worker is constructed, and ``json_model_override_args`` is
        replaced when a speculative token map is configured.

        Args:
            server_args: Server configuration; read and mutated as described.
            gpu_id: Forwarded to ``TpModelWorker.__init__`` and stored.
            tp_rank: Forwarded to ``TpModelWorker.__init__``.
            dp_rank: Forwarded to ``TpModelWorker.__init__``.
            moe_ep_rank: Forwarded to ``TpModelWorker.__init__``.
            nccl_port: Forwarded to ``TpModelWorker.__init__``.
            target_worker: The target-model worker whose memory pools and
                context length are reused.
        """
        # Direct references to the constructor arguments.
        self.server_args = server_args
        self.gpu_id = gpu_id
        self.target_worker = target_worker

        # Settings copied out of the server arguments.
        self.device = server_args.device
        self.page_size = server_args.page_size
        self.enable_nan_detection = server_args.enable_nan_detection
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.padded_static_len = -1

        # The draft model must use the same context length as the target model.
        target_model_config = target_worker.model_runner.model_config
        server_args.context_length = target_model_config.context_len

        # Keep the base-class constructor from capturing cuda graphs;
        # they are captured further down, once the attention backend exists.
        cuda_graph_was_disabled = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True

        # Share the allocator with a target worker.
        # Draft and target worker own their own KV cache pools.
        shared_pools = target_worker.get_memory_pool()
        self.req_to_token_pool, self.token_to_kv_pool_allocator = shared_pools

        # Hot token ids are only loaded when a token map is configured.
        token_map_path = server_args.speculative_token_map
        if token_map_path is None:
            self.hot_token_id = None
        else:
            self.hot_token_id = load_token_map(token_map_path)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )

        # Build the underlying draft worker.
        with empty_context():
            TpModelWorker.__init__(
                self,
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        # Restore the caller's cuda graph setting, then initialize the
        # attention backend and cuda graphs under the chosen tp context.
        # NOTE(review): `draft_model_runner` is not defined in this file;
        # presumably provided by EAGLEWorker — confirm.
        self.draft_model_runner.server_args.disable_cuda_graph = (
            cuda_graph_was_disabled
        )
        if server_args.enable_dp_attention:
            self.draft_tp_context = draft_tp_context
        else:
            self.draft_tp_context = empty_context
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Some dummy tensors
        scalar_spec = {"dtype": torch.int64, "device": self.device}
        self.num_new_pages_per_topk = torch.empty((), **scalar_spec)
        self.extend_lens = torch.empty((), **scalar_spec)

View File

@@ -72,6 +72,10 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
"meta-llama/Llama-3.1-8B-Instruct"
)
DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
# Other use cases
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (