import logging
from contextlib import contextmanager
from typing import Optional

import torch

from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda

if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")


@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    # The draft model does not use data parallelism and has its own TP group.
    # mscclpp is disabled for now because it does not support two comm groups.
    with patch_tensor_parallel_group(tp_group):
        yield


class StandaloneWorker(EAGLEWorker):

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.padded_static_len = -1

        # Override the context length of the draft model to match the target model.
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture CUDA graphs in `super().__init__()`;
        # they are captured later in `init_cuda_graphs()`.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True

        # Share the req-to-token pool and KV allocator with the target worker.
        # The draft and target workers each own their own KV cache pool.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init the draft worker
        with empty_context():
            TpModelWorker.__init__(
                self,
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        # Init the attention backend and CUDA graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Some dummy tensors
        self.num_new_pages_per_topk = torch.empty(
            (), dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
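

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the module). It shows how a
# StandaloneWorker is constructed alongside an existing target TpModelWorker.
# The model paths, ranks, and port below are placeholder assumptions; in the
# real server, the scheduler performs this wiring.
#
#   server_args = ServerArgs(
#       model_path="path/to/target-model",             # placeholder
#       speculative_algorithm="STANDALONE",
#       speculative_draft_model_path="path/to/draft",  # placeholder
#       speculative_num_steps=3,
#       speculative_eagle_topk=4,
#       speculative_num_draft_tokens=8,
#   )
#   target_worker = TpModelWorker(...)  # created by the scheduler in practice
#   draft_worker = StandaloneWorker(
#       server_args=server_args,
#       gpu_id=0,
#       tp_rank=0,
#       dp_rank=None,
#       moe_ep_rank=0,
#       nccl_port=29500,  # placeholder port
#       target_worker=target_worker,
#   )
#
# The draft worker shares the target worker's req-to-token pool and KV
# allocator (see `get_memory_pool()` above), so both workers address the same
# request slots while keeping separate KV caches.
# ---------------------------------------------------------------------------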