sglang/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

from __future__ import annotations

import logging
from http import HTTPStatus
from typing import TYPE_CHECKING

import torch

from sglang.srt.disaggregation.utils import prepare_abort
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import ScheduleBatch
    from sglang.srt.server_args import ServerArgs
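

# On the decode side of prefill-decode (PD) disaggregation, the prefill worker
# has already computed the prompt's KV cache and sampled the first output
# token; this mixin replays that state on the decode worker as a "prebuilt"
# extend batch.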
class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
"""
Prepare a prebuilt extend by populate metadata
Adapted from .prepare_for_extend().
"""
        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
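        # Tokens to run through the model in this extend step: everything in
        # fill_ids that is not covered by the cached prefix.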
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
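
        # These positions were already mapped to KV-cache slots when the request
        # was pre-allocated on the decode side, so gather them from req_to_token
        # instead of allocating new ones.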
        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)
            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len
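
            # seq_len counts tokens whose KV cache already exists; the newest
            # output token has no KV entry yet and becomes the next decode input.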
            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0

        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.orig_seq_lens = torch.tensor(
            seq_lens, dtype=torch.int32, device=self.device
        )
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
        self.multimodal_inputs = [r.multimodal_inputs for r in reqs]

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to the schedule batch."""
        self.output_ids = []
        for req in self.reqs:
            self.output_ids.append(req.output_ids[-1])
            self.tree_cache.cache_unfinished_req(req)
            if req.grammar is not None:
                # FIXME: this try-except block is for handling unexpected xgrammar issue.
                try:
                    # If current_token is not None, the grammar came from a retracted
                    # request and the token has already been accepted, so skip it.
                    if req.grammar.current_token is None:
                        req.grammar.accept_token(req.output_ids[-1])
                except ValueError as e:
                    # accept_token raises ValueError if the token is not reachable
                    # in the grammar, e.g. when the grammar was set up incorrectly
                    # or the token is invalid; abort the request in that case.
                    error_message = f"Grammar accept_token failed for req {req.rid} with token {req.output_ids[-1]}: {e}"
                    self.tree_cache.cache_finished_req(req)
                    prepare_abort(
                        req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                    )
                req.grammar.finished = req.finished()
        self.output_ids = torch.tensor(self.output_ids, device=self.device)

        # Simulate the eagle run.
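        # The draft inputs (top-k probabilities/indices and hidden states) were
        # produced on the prefill side and shipped with the request, so rebuild
        # the EagleDraftInput the decode worker would otherwise compute itself.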
        if self.spec_algorithm.is_eagle():
            b = len(self.reqs)
            topk = server_args.speculative_eagle_topk
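
            # Stack each request's buffered top-k draft probabilities and token
            # indices into (b, topk) tensors.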
            topk_p = torch.stack(
                [
                    torch.as_tensor(
                        req.output_topk_p[:topk],
                        device=self.device,
                        dtype=torch.float32,
                    )
                    for req in self.reqs
                ],
                dim=0,
            )
            topk_index = torch.stack(
                [
                    torch.as_tensor(
                        req.output_topk_index[:topk],
                        device=self.device,
                        dtype=torch.int64,
                    )
                    for req in self.reqs
                ],
                dim=0,
            )
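
            # Hidden states captured for each request's last position, stacked
            # along the batch dimension.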
            hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
            hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

            # local import to avoid circular import
            from sglang.srt.speculative.eagle_utils import EagleDraftInput

            spec_info = EagleDraftInput(
                topk_p=topk_p,
                topk_index=topk_index,
                hidden_states=hidden_states,
                verified_id=self.output_ids,
            )
            spec_info.prepare_for_extend(self)
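            # Only the last position's hidden state is needed for the next draft step.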
            spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
            self.spec_info = spec_info