from __future__ import annotations

import logging
import multiprocessing as mp
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

if TYPE_CHECKING:
    from sglang.srt.managers.scheduler import GenerationBatchResult

logger = logging.getLogger(__name__)


def validate_input_length(
    req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Validate and potentially truncate input length.

    When truncation is allowed, ``req.origin_input_ids`` is replaced in place
    by its first ``max_req_input_len`` entries and a warning is logged.

    Args:
        req: The request containing input_ids to validate
        max_req_input_len: Maximum allowed input length
        allow_auto_truncate: Whether to truncate long inputs

    Returns:
        Error message if validation fails, None if successful
    """
    # NOTE(review): the comparison is `>=`, so an input of exactly
    # `max_req_input_len` tokens is also treated as too long, while the
    # truncation below keeps `max_req_input_len` tokens. Confirm this
    # boundary is intended before changing it.
    if len(req.origin_input_ids) >= max_req_input_len:
        if allow_auto_truncate:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated. "
                f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
            )
            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
            return None
        else:
            error_msg = (
                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
                f"the maximum allowed length ({max_req_input_len} tokens). "
                f"Use a shorter input or enable --allow-auto-truncate."
            )
            return error_msg

    return None


def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict:
    """Flatten the logprob-related fields of a batch result into a plain dict.

    The keys are the same names that ``get_logprob_from_pp_outputs`` reads
    back when rebuilding a ``LogitsProcessorOutput``.

    Args:
        result: Batch result whose ``logits_output`` must not be None.

    Returns:
        A dict with the two per-request extend lists and the logprob fields
        copied from ``result.logits_output``.
    """
    logits_output = result.logits_output
    assert logits_output is not None

    # Read every field through the checked local instead of re-fetching
    # `result.logits_output` for each key.
    return {
        "extend_input_len_per_req": result.extend_input_len_per_req,
        "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
        "next_token_logprobs": logits_output.next_token_logprobs,
        "next_token_top_logprobs_val": logits_output.next_token_top_logprobs_val,
        "next_token_top_logprobs_idx": logits_output.next_token_top_logprobs_idx,
        "next_token_token_ids_logprobs_val": logits_output.next_token_token_ids_logprobs_val,
        "next_token_token_ids_logprobs_idx": logits_output.next_token_token_ids_logprobs_idx,
        "input_token_logprobs": logits_output.input_token_logprobs,
        "input_top_logprobs_val": logits_output.input_top_logprobs_val,
        "input_top_logprobs_idx": logits_output.input_top_logprobs_idx,
        "input_token_ids_logprobs_val": logits_output.input_token_ids_logprobs_val,
        "input_token_ids_logprobs_idx": logits_output.input_token_ids_logprobs_idx,
    }


def get_logprob_from_pp_outputs(
    next_pp_outputs: PPProxyTensors,
) -> tuple[LogitsProcessorOutput, list[int], list[int]]:
    """Rebuild logprob outputs from the entries of ``next_pp_outputs``.

    Args:
        next_pp_outputs: Mapping-like object indexed by the logprob field
            names and the two ``extend_*_per_req`` keys.

    Returns:
        A tuple ``(logits_output, extend_input_len_per_req,
        extend_logprob_start_len_per_req)``. ``logits_output`` has
        ``next_token_logits`` and ``hidden_states`` set to None.
    """
    logits_output = LogitsProcessorOutput(
        # Do not send logits and hidden states because they are large
        next_token_logits=None,
        hidden_states=None,
        next_token_logprobs=next_pp_outputs["next_token_logprobs"],
        next_token_top_logprobs_val=next_pp_outputs["next_token_top_logprobs_val"],
        next_token_top_logprobs_idx=next_pp_outputs["next_token_top_logprobs_idx"],
        next_token_token_ids_logprobs_val=next_pp_outputs[
            "next_token_token_ids_logprobs_val"
        ],
        next_token_token_ids_logprobs_idx=next_pp_outputs[
            "next_token_token_ids_logprobs_idx"
        ],
        input_token_logprobs=next_pp_outputs["input_token_logprobs"],
        input_top_logprobs_val=next_pp_outputs["input_top_logprobs_val"],
        input_top_logprobs_idx=next_pp_outputs["input_top_logprobs_idx"],
        input_token_ids_logprobs_val=next_pp_outputs["input_token_ids_logprobs_val"],
        input_token_ids_logprobs_idx=next_pp_outputs["input_token_ids_logprobs_idx"],
    )
    extend_input_len_per_req = next_pp_outputs["extend_input_len_per_req"]
    extend_logprob_start_len_per_req = next_pp_outputs[
        "extend_logprob_start_len_per_req"
    ]

    return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req


class DPBalanceMeta:
    """
    Manager-backed shared state (a lock, a per-worker token list and a
    per-worker dict list) used in the scheduler and the DP controller.

    The instance that ran ``__init__`` owns the ``multiprocessing`` manager
    and must release it by calling ``destructor()``. Pickled copies drop the
    manager and keep only the lock and the shared namespace.
    """

    def __init__(self, num_workers: int):
        """Start a manager and create one shared slot per worker.

        Args:
            num_workers: Number of workers; sizes both shared lists.
        """
        self.num_workers = num_workers
        self._manager = mp.Manager()
        self.mutex = self._manager.Lock()

        init_local_tokens = [0] * self.num_workers
        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]

        self.shared_state = self._manager.Namespace()
        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
        self.shared_state.onfly_info = self._manager.list(init_onfly_info)

    def destructor(self):
        """Shut down the manager; a no-op on copies that do not own one."""
        # we must destructor this class manually
        # `__setstate__` leaves `_manager` as None on unpickled copies, so
        # guard against calling `.shutdown()` on None.
        manager = self._manager
        if manager is None:
            return
        manager.shutdown()

    def get_shared_onfly(self) -> List[Dict[int, int]]:
        """Return a snapshot of the shared on-the-fly info as plain dicts."""
        return [dict(d) for d in self.shared_state.onfly_info]

    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
        """Replace the shared on-the-fly info with ``data``."""
        self.shared_state.onfly_info = data

    def get_shared_local_tokens(self) -> List[int]:
        """Return a snapshot of the shared local-token list as a plain list."""
        return list(self.shared_state.local_tokens)

    def set_shared_local_tokens(self, data: List[int]):
        """Replace the shared local-token list with ``data``."""
        self.shared_state.local_tokens = data

    def __getstate__(self):
        """Return the instance state without the manager handle."""
        state = self.__dict__.copy()
        # `pop` with a default does not depend on the key being present.
        state.pop("_manager", None)
        return state

    def __setstate__(self, state):
        """Restore state; the restored copy has no manager of its own."""
        self.__dict__.update(state)
        self._manager = None