Files
sglang/python/sglang/srt/managers/utils.py
Shangming Cai 384f8ab5ce [PD] Support PD disaggregation with Prefill PP (#8846)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: root <huzhiyuan@xiaohongshu.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com>
Co-authored-by: zitto <zhjc1124@gmail.com>
2025-08-16 18:31:31 -07:00

143 lines
5.6 KiB
Python

from __future__ import annotations
import logging
import multiprocessing as mp
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import GenerationBatchResult
logger = logging.getLogger(__name__)
def validate_input_length(
    req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Check a request's prompt length against the server's input limit.

    When the prompt is too long and truncation is allowed, the request's
    ``origin_input_ids`` are clipped in place and the request proceeds.
    When truncation is not allowed, an error message is returned instead.

    Args:
        req: The request whose ``origin_input_ids`` are validated.
        max_req_input_len: Maximum number of input tokens accepted.
        allow_auto_truncate: Whether to clip over-long inputs instead of
            rejecting them.

    Returns:
        An error message string if the request must be rejected, else None.
    """
    input_len = len(req.origin_input_ids)
    if input_len < max_req_input_len:
        return None

    if not allow_auto_truncate:
        return (
            f"Input length ({input_len} tokens) exceeds "
            f"the maximum allowed length ({max_req_input_len} tokens). "
            f"Use a shorter input or enable --allow-auto-truncate."
        )

    # Log before clipping so the message reports the original length.
    logger.warning(
        "Request length is longer than the KV cache pool size or "
        "the max context length. Truncated. "
        f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
    )
    req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
    return None
def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict:
    """Pack the logprob fields of a batch result into a plain dict.

    The dict mirrors the keys consumed by ``get_logprob_from_pp_outputs`` on
    the receiving pipeline-parallel rank.

    Args:
        result: A generation batch result whose ``logits_output`` is set.

    Returns:
        A dict with per-request extend lengths plus all next-token and
        input-token logprob fields from ``result.logits_output``.
    """
    logits_output = result.logits_output
    assert logits_output is not None
    # Use the checked local rather than re-reading result.logits_output for
    # every field: the original fetched the local, asserted it, then ignored
    # it, re-traversing the attribute ten more times.
    return {
        "extend_input_len_per_req": result.extend_input_len_per_req,
        "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
        "next_token_logprobs": logits_output.next_token_logprobs,
        "next_token_top_logprobs_val": logits_output.next_token_top_logprobs_val,
        "next_token_top_logprobs_idx": logits_output.next_token_top_logprobs_idx,
        "next_token_token_ids_logprobs_val": logits_output.next_token_token_ids_logprobs_val,
        "next_token_token_ids_logprobs_idx": logits_output.next_token_token_ids_logprobs_idx,
        "input_token_logprobs": logits_output.input_token_logprobs,
        "input_top_logprobs_val": logits_output.input_top_logprobs_val,
        "input_top_logprobs_idx": logits_output.input_top_logprobs_idx,
        "input_token_ids_logprobs_val": logits_output.input_token_ids_logprobs_val,
        "input_token_ids_logprobs_idx": logits_output.input_token_ids_logprobs_idx,
    }
def get_logprob_from_pp_outputs(
    next_pp_outputs: PPProxyTensors,
) -> tuple[LogitsProcessorOutput, list[int], list[int]]:
    """Rebuild logprob outputs from proxy tensors received over a PP link.

    Inverse of ``get_logprob_dict_from_result``: reads the logprob fields
    out of the proxy-tensor mapping and reassembles a
    ``LogitsProcessorOutput`` plus the per-request extend metadata.

    Args:
        next_pp_outputs: Proxy tensors sent by the previous pipeline stage.

    Returns:
        Tuple of (logits_output, extend_input_len_per_req,
        extend_logprob_start_len_per_req).
    """
    logprob_fields = (
        "next_token_logprobs",
        "next_token_top_logprobs_val",
        "next_token_top_logprobs_idx",
        "next_token_token_ids_logprobs_val",
        "next_token_token_ids_logprobs_idx",
        "input_token_logprobs",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    )
    logits_output = LogitsProcessorOutput(
        # Logits and hidden states are deliberately omitted: they are too
        # large to send between pipeline stages.
        next_token_logits=None,
        hidden_states=None,
        **{field: next_pp_outputs[field] for field in logprob_fields},
    )
    return (
        logits_output,
        next_pp_outputs["extend_input_len_per_req"],
        next_pp_outputs["extend_logprob_start_len_per_req"],
    )
class DPBalanceMeta:
    """Cross-process shared state for data-parallel load balancing.

    Used by both the scheduler and the DP controller. A ``multiprocessing``
    Manager process owns the shared state; the proxy objects it hands out
    (``mutex``, ``shared_state``) remain usable after this instance is
    pickled into a child process, while the Manager handle itself does not
    (see ``__getstate__`` / ``__setstate__``).
    """

    def __init__(self, num_workers: int):
        self.num_workers = num_workers
        # The Manager spawns a server process that owns the shared objects.
        self._manager = mp.Manager()
        self.mutex = self._manager.Lock()

        init_local_tokens = [0] * self.num_workers
        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]

        self.shared_state = self._manager.Namespace()
        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
        self.shared_state.onfly_info = self._manager.list(init_onfly_info)

    def destructor(self):
        # We must destruct this class manually from the owning process.
        # Guard against unpickled copies, where __setstate__ left _manager
        # as None: calling shutdown() there raised AttributeError before.
        if getattr(self, "_manager", None) is not None:
            self._manager.shutdown()

    def get_shared_onfly(self) -> List[Dict[int, int]]:
        # Copy proxies into plain dicts so callers get a stable snapshot.
        return [dict(d) for d in self.shared_state.onfly_info]

    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
        self.shared_state.onfly_info = data

    def get_shared_local_tokens(self) -> List[int]:
        # Copy the proxy list into a plain list snapshot.
        return list(self.shared_state.local_tokens)

    def set_shared_local_tokens(self, data: List[int]):
        self.shared_state.local_tokens = data

    def __getstate__(self):
        # The Manager handle is not picklable; drop it so the instance can
        # be sent to child processes. The proxies in shared_state survive.
        state = self.__dict__.copy()
        del state["_manager"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Unpickled copies do not own the Manager process.
        self._manager = None