# Source: sglang/python/sglang/srt/speculative/standalone_worker.py
# Snapshot: 2025-09-07 20:55:09 -07:00 (110 lines, 4.0 KiB, Python)

import logging
from contextlib import contextmanager
from typing import Optional
import torch
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
# sgl_kernel is only imported when running on CUDA.
# NOTE(review): `segment_packbits` is not referenced anywhere in this module;
# possibly kept for re-export or import side effects -- confirm before removing.
if is_cuda():
    from sgl_kernel import segment_packbits

# Per-module logger (standard `getLogger(__name__)` convention).
logger: logging.Logger = logging.getLogger(__name__)

# Evaluated once, at import time, from the RETURN_ORIGINAL_LOGPROB env var.
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    """Execute the ``with`` body under ``patch_tensor_parallel_group(tp_group)``.

    The draft model does not use data parallelism and has its own
    tensor-parallel group. This context lets draft-model work run with that
    group patched in for the duration of the ``with`` body.

    Args:
        tp_group: The draft model's tensor-parallel group coordinator.

    NOTE(review): a previous comment here stated that mscclpp is disabled in
    this context because it does not support two communication groups.
    Nothing in this function does so directly; presumably that happens inside
    ``patch_tensor_parallel_group`` -- confirm.
    """
    patcher = patch_tensor_parallel_group(tp_group)
    with patcher:
        yield
class StandaloneWorker(EAGLEWorker):
    """Speculative-decoding draft worker built on top of ``EAGLEWorker``.

    Construction does not run ``EAGLEWorker.__init__``. Instead it sets the
    speculative-decoding attributes itself and then initializes the model
    worker via ``TpModelWorker.__init__(..., is_draft_worker=True)``, sharing
    the target worker's request-to-token pool and KV allocator.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        """Create the draft worker next to an existing ``target_worker``.

        Args:
            server_args: Server configuration. NOTE: this object is mutated:
                ``context_length`` is overwritten with the target model's
                context length; ``disable_cuda_graph`` is forced to True while
                ``TpModelWorker.__init__`` runs and then written back via
                ``self.draft_model_runner.server_args`` (presumably the same
                object -- confirm); ``json_model_override_args`` is overwritten
                when ``speculative_token_map`` is set.
            gpu_id: GPU index, forwarded to ``TpModelWorker.__init__``.
            tp_rank: Tensor-parallel rank, forwarded to ``TpModelWorker.__init__``.
            dp_rank: Data-parallel rank (may be None), forwarded likewise.
            moe_ep_rank: MoE expert-parallel rank, forwarded likewise.
            nccl_port: NCCL port, forwarded likewise.
            target_worker: Target worker; its ``model_runner.model_config``
                supplies the context length, and ``get_memory_pool()`` supplies
                the pools shared with this draft worker.
        """
        # Speculative-decoding settings taken straight from the server config.
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        algorithm_name = server_args.speculative_algorithm
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(algorithm_name)
        self.padded_static_len = -1

        # The draft model uses the same context length as the target model.
        target_context_len = target_worker.model_runner.model_config.context_len
        server_args.context_length = target_context_len

        # Keep TpModelWorker.__init__ (below) from capturing CUDA graphs; the
        # original setting is restored afterwards and graphs are set up later
        # by self.init_cuda_graphs().
        saved_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True

        # Reuse the target worker's request-to-token pool and KV allocator.
        # The draft and target workers still own separate KV cache pools.
        shared_pools = target_worker.get_memory_pool()
        self.req_to_token_pool, self.token_to_kv_pool_allocator = shared_pools

        # Optional hot-token map: when configured, load it and pass its size to
        # the draft model config as `hot_vocab_size`.
        token_map_path = server_args.speculative_token_map
        if token_map_path is None:
            self.hot_token_id = None
        else:
            self.hot_token_id = load_token_map(token_map_path)
            # NOTE(review): this replaces any json_model_override_args that
            # were already set on server_args rather than merging into them.
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )

        # Initialize the underlying model worker as a draft worker on the
        # shared pools.
        with empty_context():
            TpModelWorker.__init__(
                self,
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        # Put the caller's CUDA graph setting back before graph initialization.
        self.draft_model_runner.server_args.disable_cuda_graph = (
            saved_disable_cuda_graph
        )

        # With DP attention, draft work runs under the draft TP group;
        # otherwise a no-op context is used.
        if server_args.enable_dp_attention:
            self.draft_tp_context = draft_tp_context
        else:
            self.draft_tp_context = empty_context

        # Attention backend and CUDA graphs are set up inside that context.
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Zero-dimensional int64 placeholder tensors (contents uninitialized).
        self.num_new_pages_per_topk = torch.empty(
            (), dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)