[Eagle] Refactor eagle speculative decoding (#3986)

Co-authored-by: Ke Bao <ISPObaoke@163.com>
Authored by Ying Sheng on 2025-03-05 08:06:07 -08:00, committed by GitHub
parent 5be8f1ed98
commit d3d4d76758
22 changed files with 670 additions and 352 deletions

@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
    TokenToKVPoolAllocator,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
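
The newly imported TokenToKVPoolAllocator is the piece that makes the rest of this refactor possible: it separates the allocation of KV cache slot indices from the KV cache storage itself, so two runners can draw indices from one shared space. As a rough sketch of the role such an allocator plays (illustrative only, not sglang's actual implementation):

    class ToySlotAllocator:
        """Minimal free-list allocator over KV cache slot indices.

        Illustrative only; sglang's TokenToKVPoolAllocator is more involved.
        """

        def __init__(self, size: int):
            # Slot 0 is reserved as a padding index in this sketch.
            self.free_slots = list(range(1, size + 1))

        def available_size(self) -> int:
            return len(self.free_slots)

        def alloc(self, need: int):
            # Return None so the caller can evict and retry instead of crashing.
            if need > len(self.free_slots):
                return None
            out, self.free_slots = self.free_slots[:need], self.free_slots[need:]
            return out

        def free(self, indices) -> None:
            self.free_slots.extend(indices)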
@@ -98,6 +99,8 @@ class ModelRunner:
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.model_config = model_config
@@ -115,6 +118,8 @@ class ModelRunner:
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

        # Model-specific adjustment
        if (
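
These two assignments record the pools handed in through the new constructor arguments. When the scheduler builds the EAGLE draft worker, it can now pass the target runner's pools instead of letting the draft runner create its own. A hedged sketch of such a call site (the first six arguments are assumed from the surrounding signature, which this hunk does not show):

    # Hypothetical call site: reuse the target runner's pools in the draft
    # runner, so both sides allocate request slots and KV indices from the
    # same space.
    draft_runner = ModelRunner(
        model_config=draft_model_config,  # hypothetical: the draft model's config
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        tp_size=tp_size,
        nccl_port=nccl_port,
        server_args=server_args,
        is_draft_worker=True,
        req_to_token_pool=target_runner.req_to_token_pool,
        token_to_kv_pool_allocator=target_runner.token_to_kv_pool_allocator,
    )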
@@ -257,8 +262,8 @@ class ModelRunner:
    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")
        torch.get_device_module(self.device).set_device(self.gpu_id)
        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
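
torch.get_device_module(self.device) resolves to the matching accelerator module (torch.cuda, torch.xpu, and so on), so one set_device call covers every device type; only the distributed backend name still needs a branch. A small standalone illustration (the backend mapping past "cuda" is an assumption, since the hunk is truncated):

    import torch

    def setup_device(device: str, gpu_id: int) -> str:
        # One call covers cuda/xpu/etc. instead of an if/elif per device type.
        torch.get_device_module(device).set_device(gpu_id)
        # Illustrative mapping; the diff above is cut off after the "xpu" branch.
        backends = {"cuda": "nccl", "xpu": "gloo"}
        return backends.get(device, "gloo")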
@@ -660,12 +665,25 @@ class ModelRunner:
        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft
                # tokens can be allocated concurrently, so we should leave headroom.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # The target worker and the draft worker share the same indices in
                # the token_to_kv_pool, so max_total_num_tokens must match.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
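
The headroom terms reserve one KV slot for every draft tree position (speculative_num_steps times speculative_eagle_topk per request) plus speculative_num_draft_tokens verify slots per request and a fixed buffer of 100. With illustrative values (not defaults taken from this diff), the arithmetic works out as:

    # Illustrative numbers, not sglang defaults:
    max_total_num_tokens = 100_000
    max_num_reqs = 32
    speculative_num_steps = 5
    speculative_eagle_topk = 4
    speculative_num_draft_tokens = 8

    draft_runner_cache_size = (
        max_total_num_tokens
        + max_num_reqs * speculative_num_steps * speculative_eagle_topk  # draft: 640
        + max_num_reqs * speculative_num_draft_tokens  # verify: 256
        + 100  # buffer
    )
    assert draft_runner_cache_size == 100_996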
@@ -681,12 +699,25 @@ class ModelRunner:
"Not enough memory. Please try to increase --mem-fraction-static."
)
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if self.req_to_token_pool is None:
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
else:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
if self.token_to_kv_pool_allocator is None:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
)
else:
assert self.is_draft_worker
if (
self.model_config.attention_arch == AttentionArch.MLA
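
Taken together, the else branches pin down the sharing contract: only a draft worker may arrive with pre-built pools, and both workers must agree on max_total_num_tokens so their KV indices stay interchangeable. A sanity check of that contract could look like this (make_target_runner and make_draft_runner are hypothetical helpers):

    # Hypothetical helpers that construct the two runners as sketched above.
    target = make_target_runner(server_args)
    draft = make_draft_runner(server_args, target)

    # Sharing must be by identity, not by copy, so slot indices allocated by
    # one runner are visible to the other.
    assert draft.is_draft_worker
    assert draft.req_to_token_pool is target.req_to_token_pool
    assert draft.token_to_kv_pool_allocator is target.token_to_kv_pool_allocator
    assert draft.max_total_num_tokens == target.max_total_num_tokens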