[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -98,6 +99,8 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -115,6 +118,8 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
         # Model-specific adjustment
         if (
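The two new constructor arguments are what let a draft worker run against the target worker's memory pools instead of allocating its own. A minimal sketch of the wiring, assuming a `target_runner` built first; only `is_draft_worker`, `req_to_token_pool`, and `token_to_kv_pool_allocator` come from this diff, the rest of the call is hypothetical:

    # Hypothetical call site: construct the target runner first, then hand
    # its pools to the draft runner so both index the same KV cache.
    draft_runner = ModelRunner(
        ...,  # model config, ranks, ports, etc. (elided; not shown in the diff)
        server_args=server_args,
        is_draft_worker=True,
        req_to_token_pool=target_runner.req_to_token_pool,
        token_to_kv_pool_allocator=target_runner.token_to_kv_pool_allocator,
    )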
@@ -257,8 +262,8 @@ class ModelRunner:
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 
         torch.get_device_module(self.device).set_device(self.gpu_id)
 
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
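`torch.get_device_module` keeps the device setup device-agnostic: it maps a device string to its backend module, so one call path covers CUDA, XPU, and friends. A standalone sketch of the pattern; the hunk above shows only the `cuda` case, so the fallback backend here is an assumption for illustration:

    import torch

    def init_device(device: str, gpu_id: int) -> str:
        # torch.get_device_module("cuda") returns torch.cuda, so this one
        # call replaces per-device branching for set_device.
        torch.get_device_module(device).set_device(gpu_id)
        if device == "cuda":
            backend = "nccl"
        else:
            backend = "gloo"  # assumed fallback; not taken from the diff
        return backend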
@@ -660,12 +665,25 @@ class ModelRunner:
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                max_num_reqs = self.server_args.max_num_reqs
             else:
+                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
+                # can be concurrently allocated, so we should give a headroom for it.
                 self.server_args.draft_runner_cache_size = (
                     self.max_total_num_tokens
-                    + max_num_reqs * self.server_args.speculative_num_steps
+                    # draft
+                    + max_num_reqs
+                    * self.server_args.speculative_num_steps
+                    * self.server_args.speculative_eagle_topk
+                    # verify
+                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
+                    # buffer
                     + 100
                 )
+                # Target worker and draft worker shares the same indices for the
+                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                self.server_args.max_num_reqs = max_num_reqs
 
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
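The sizing arithmetic above is easiest to check with concrete numbers. A worked example with illustrative values (none of these are sglang defaults):

    max_total_num_tokens = 500_000       # KV slots for ordinary decoding
    max_num_reqs = 256                   # concurrent requests
    speculative_num_steps = 5            # autoregressive draft steps
    speculative_eagle_topk = 4           # branches kept per draft step
    speculative_num_draft_tokens = 8     # tokens sent to verification

    draft_runner_cache_size = (
        max_total_num_tokens
        # draft: each request may hold num_steps * topk tentative tokens
        + max_num_reqs * speculative_num_steps * speculative_eagle_topk  # 5,120
        # verify: tokens the target model checks per request
        + max_num_reqs * speculative_num_draft_tokens                    # 2,048
        # buffer
        + 100
    )
    print(draft_runner_cache_size)  # 507268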
@@ -681,12 +699,25 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
-        self.req_to_token_pool = ReqToTokenPool(
-            size=max_num_reqs + 1,
-            max_context_len=self.model_config.context_len + 4,
-            device=self.device,
-            enable_memory_saver=self.server_args.enable_memory_saver,
-        )
+        if self.req_to_token_pool is None:
+            self.req_to_token_pool = ReqToTokenPool(
+                size=max_num_reqs + 1,
+                max_context_len=self.model_config.context_len + 4,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+            )
+        else:
+            # Draft worker shares req_to_token_pool with the target worker.
+            assert self.is_draft_worker
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        else:
+            assert self.is_draft_worker
 
         if (
             self.model_config.attention_arch == AttentionArch.MLA
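The branches above encode one rule: the target worker creates each pool, and the draft worker must have received a shared one. The new `TokenToKVPoolAllocator` separates slot-index bookkeeping from the KV storage itself; a toy free-list version sketches that contract (a simplified stand-in, not sglang's implementation):

    import torch

    class ToyTokenSlotAllocator:
        # Hands out KV-cache slot indices from a shared free list, so the
        # draft and target workers draw from a single token budget.
        def __init__(self, size: int, device: str = "cpu"):
            self.device = device
            # slot 0 is reserved as a padding index in this sketch
            self.free_slots = list(range(1, size + 1))

        def available_size(self) -> int:
            return len(self.free_slots)

        def alloc(self, need: int):
            if need > len(self.free_slots):
                return None  # caller handles OOM, e.g. by retracting requests
            out = self.free_slots[:need]
            self.free_slots = self.free_slots[need:]
            return torch.tensor(out, dtype=torch.int64, device=self.device)

        def free(self, indices: torch.Tensor):
            self.free_slots.extend(indices.tolist())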