from __future__ import annotations

import logging
import threading
import time
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch

if TYPE_CHECKING:
    from sglang.srt.managers.scheduler import (
        EmbeddingBatchResult,
        GenerationBatchResult,
        ScheduleBatch,
        Scheduler,
    )

logger = logging.getLogger(__name__)

# For non-streaming requests, force an intermediate flush to the detokenizer
# every this many output tokens (generation path only; see stream_output_generation).
DEFAULT_FORCE_STREAM_INTERVAL = 50


class SchedulerOutputProcessorMixin:
    """
    This class implements the output processing logic for Scheduler.
    We put them into a separate file to make the `scheduler.py` shorter.

    All methods run with ``self`` being a ``Scheduler`` instance (the mixin is
    only ever combined into that class), which is why they are annotated
    ``self: Scheduler``.
    """

    def process_batch_result_prefill(
        self: Scheduler,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Process the result of one prefill (extend) batch.

        Appends the first sampled token to each request, updates the radix
        tree cache for finished/unfinished requests, accumulates input/output
        logprobs and hidden states, advances grammar state, and finally
        streams outputs to the detokenizer.

        Args:
            batch: The schedule batch that was executed.
            result: Generation or embedding result, depending on model kind.
            launch_done: Event used by the overlap scheduler to resolve the
                batch result that was launched asynchronously.
        """
        # The request currently being chunked (if any) must not be streamed
        # yet, since its prefill is not finished.
        skip_stream_req = None

        if self.is_generation:
            (
                logits_output,
                next_token_ids,
                extend_input_len_per_req,
                extend_logprob_start_len_per_req,
            ) = (
                result.logits_output,
                result.next_token_ids,
                result.extend_input_len_per_req,
                result.extend_logprob_start_len_per_req,
            )

            if self.enable_overlap:
                # Overlap schedule: the result tensors are resolved lazily by
                # the TP worker once the launched batch completes.
                logits_output, next_token_ids, _ = (
                    self.tp_worker.resolve_last_batch_result(launch_done)
                )
            else:
                # Move next_token_ids and logprobs to cpu
                next_token_ids = next_token_ids.tolist()
                if batch.return_logprob:
                    if logits_output.next_token_logprobs is not None:
                        logits_output.next_token_logprobs = (
                            logits_output.next_token_logprobs.tolist()
                        )
                    if logits_output.input_token_logprobs is not None:
                        # Tuple conversion matters for downstream slicing; see
                        # the isinstance assert in add_input_logprob_return_values.
                        logits_output.input_token_logprobs = tuple(
                            logits_output.input_token_logprobs.tolist()
                        )

            # Running offset into logits_output.hidden_states; each request
            # consumes len(req.origin_input_ids) rows of that tensor.
            hidden_state_offset = 0

            # Check finish conditions
            # logprob_pt is a running pointer into the flattened
            # input_token_logprobs across all requests of this batch.
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_chunked <= 0:
                    # req output_ids are set here
                    req.output_ids.append(next_token_id)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                        req.time_stats.completion_time = time.time()
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        # This updates radix so others can match
                        self.tree_cache.cache_unfinished_req(req)

                    if batch.return_logprob:
                        assert extend_logprob_start_len_per_req is not None
                        assert extend_input_len_per_req is not None
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        num_input_logprobs = extend_input_len - extend_logprob_start_len
                        if req.return_logprob:
                            self.add_logprob_return_values(
                                i,
                                req,
                                logprob_pt,
                                next_token_ids,
                                num_input_logprobs,
                                logits_output,
                            )
                        # Advance the flat pointer even when this particular
                        # request does not want logprobs, so subsequent
                        # requests index the right slice.
                        logprob_pt += num_input_logprobs

                    if (
                        req.return_hidden_states
                        and logits_output.hidden_states is not None
                    ):
                        # The walrus expression advances hidden_state_offset by
                        # this request's prompt length while slicing its rows.
                        req.hidden_states.append(
                            logits_output.hidden_states[
                                hidden_state_offset : (
                                    hidden_state_offset := hidden_state_offset
                                    + len(req.origin_input_ids)
                                )
                            ]
                            .cpu()
                            .clone()
                            .tolist()
                        )

                    if req.grammar is not None:
                        # FIXME: this try-except block is for handling unexpected xgrammar issue.
                        try:
                            req.grammar.accept_token(next_token_id)
                        except ValueError as e:
                            # Grammar accept_token can raise ValueError if the token is not in the grammar.
                            # This can happen if the grammar is not set correctly or the token is invalid.
                            logger.error(
                                f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                            )
                            self.abort_request(AbortReq(req.rid))
                        req.grammar.finished = req.finished()
                else:
                    # being chunked reqs' prefill is not finished
                    req.is_chunked -= 1
                    # There is only at most one request being currently chunked.
                    # Because this request does not finish prefill,
                    # we don't want to stream the request currently being chunked.
                    skip_stream_req = req

                    # Incrementally update input logprobs.
                    if batch.return_logprob:
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        if extend_logprob_start_len < extend_input_len:
                            # Update input logprobs.
                            num_input_logprobs = (
                                extend_input_len - extend_logprob_start_len
                            )
                            if req.return_logprob:
                                self.add_input_logprob_return_values(
                                    i,
                                    req,
                                    logits_output,
                                    logprob_pt,
                                    num_input_logprobs,
                                    last_prefill_chunk=False,
                                )
                            logprob_pt += num_input_logprobs

            self.set_next_batch_sampling_info_done(batch)

        else:  # embedding or reward model
            # NOTE(review): `bid` is unpacked but unused here — presumably kept
            # for interface symmetry with the generation result; confirm.
            embeddings, bid = result.embeddings, result.bid
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_chunked <= 0:
                    # Dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
                else:
                    # being chunked reqs' prefill is not finished
                    req.is_chunked -= 1

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

    def process_batch_result_decode(
        self: Scheduler,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
        launch_done: Optional[threading.Event] = None,
    ):
        """Process the result of one decode batch.

        Appends the sampled token per request, frees KV-cache pages for
        requests that finished one step earlier under the overlap schedule,
        records output logprobs/hidden states, advances grammar state, and
        streams outputs. Also maintains the decode forward counter used for
        periodic stats logging.

        Args:
            batch: The decode schedule batch that was executed.
            result: The generation result for this batch.
            launch_done: Event used by the overlap scheduler to resolve the
                asynchronously launched batch result.
        """
        logits_output, next_token_ids, can_run_cuda_graph = (
            result.logits_output,
            result.next_token_ids,
            result.can_run_cuda_graph,
        )
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids, can_run_cuda_graph = (
                self.tp_worker.resolve_last_batch_result(launch_done)
            )
            next_token_logprobs = logits_output.next_token_logprobs
        elif batch.spec_algorithm.is_none():
            # spec decoding handles output logprobs inside verify process.
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        # Batch the KV-cache frees below into one group operation.
        self.token_to_kv_pool_allocator.free_group_begin()

        # Check finish condition
        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
        # We should ignore using next_token_ids for spec decoding cases.
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one extra delayed token
                if self.page_size == 1:
                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
                else:
                    # Only free when the extra token is in a new page
                    if (
                        len(req.origin_input_ids) + len(req.output_ids) - 1
                    ) % self.page_size == 0:
                        self.token_to_kv_pool_allocator.free(
                            batch.out_cache_loc[i : i + 1]
                        )
                continue

            if batch.spec_algorithm.is_none():
                # speculative worker will solve the output_ids in speculative decoding
                req.output_ids.append(next_token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)
                req.time_stats.completion_time = time.time()

            if req.return_logprob and batch.spec_algorithm.is_none():
                # speculative worker handles logprob in speculative decoding
                req.output_token_logprobs_val.append(next_token_logprobs[i])
                req.output_token_logprobs_idx.append(next_token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.next_token_top_logprobs_val[i]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.next_token_top_logprobs_idx[i]
                    )
                if req.token_ids_logprob is not None:
                    req.output_token_ids_logprobs_val.append(
                        logits_output.next_token_token_ids_logprobs_val[i]
                    )
                    req.output_token_ids_logprobs_idx.append(
                        logits_output.next_token_token_ids_logprobs_idx[i]
                    )

            if req.return_hidden_states and logits_output.hidden_states is not None:
                req.hidden_states.append(
                    logits_output.hidden_states[i].cpu().clone().tolist()
                )

            if req.grammar is not None and batch.spec_algorithm.is_none():
                # FIXME: this try-except block is for handling unexpected xgrammar issue.
                try:
                    req.grammar.accept_token(next_token_id)
                except ValueError as e:
                    # Grammar accept_token can raise ValueError if the token is not in the grammar.
                    # This can happen if the grammar is not set correctly or the token is invalid.
                    logger.error(
                        f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                    )
                    self.abort_request(AbortReq(req.rid))
                req.grammar.finished = req.finished()

        self.set_next_batch_sampling_info_done(batch)
        self.stream_output(batch.reqs, batch.return_logprob)
        self.token_to_kv_pool_allocator.free_group_end()

        # Wrap the counter to avoid unbounded growth; only its modulus below matters.
        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.current_scheduler_metrics_enabled()
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats(can_run_cuda_graph, running_batch=batch)

    def add_input_logprob_return_values(
        self: Scheduler,
        i: int,
        req: Req,
        output: LogitsProcessorOutput,
        logprob_pt: int,
        num_input_logprobs: int,
        last_prefill_chunk: bool,  # If True, it means prefill is finished.
    ):
        """Incrementally add input logprobs to `req`.

        Args:
            i: The request index in a batch.
            req: The request. Input logprobs inside req are modified as a
                consequence of the API
            fill_ids: The prefill ids processed.
            output: Logit processor output that's used to compute input logprobs
            last_prefill_chunk: True if it is the last prefill (when chunked).
                Some of input logprob operation should only happen at the last
                prefill (e.g., computing input token logprobs).
        """
        assert output.input_token_logprobs is not None
        # Lazily initialize the temp accumulators that collect per-chunk
        # pieces across chunked-prefill calls.
        if req.input_token_logprobs is None:
            req.input_token_logprobs = []
        if req.temp_input_top_logprobs_val is None:
            req.temp_input_top_logprobs_val = []
        if req.temp_input_top_logprobs_idx is None:
            req.temp_input_top_logprobs_idx = []
        if req.temp_input_token_ids_logprobs_val is None:
            req.temp_input_token_ids_logprobs_val = []
        if req.temp_input_token_ids_logprobs_idx is None:
            req.temp_input_token_ids_logprobs_idx = []

        if req.input_token_logprobs_val is not None:
            # The input logprob has been already computed. It only happens
            # upon retract.
            if req.top_logprobs_num > 0:
                assert req.input_token_logprobs_val is not None
            return

        # Important for the performance.
        assert isinstance(output.input_token_logprobs, tuple)
        input_token_logprobs: Tuple[int] = output.input_token_logprobs
        # Take only this request's slice of the batch-flattened logprobs.
        input_token_logprobs = input_token_logprobs[
            logprob_pt : logprob_pt + num_input_logprobs
        ]
        req.input_token_logprobs.extend(input_token_logprobs)

        if req.top_logprobs_num > 0:
            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.temp_input_token_ids_logprobs_val.append(
                output.input_token_ids_logprobs_val[i]
            )
            req.temp_input_token_ids_logprobs_idx.append(
                output.input_token_ids_logprobs_idx[i]
            )

        if last_prefill_chunk:
            # Prefill is complete: fold the accumulated pieces into the final
            # per-request fields and drop the temp accumulators.
            input_token_logprobs = req.input_token_logprobs
            req.input_token_logprobs = None
            assert req.input_token_logprobs_val is None
            assert req.input_token_logprobs_idx is None
            assert req.input_top_logprobs_val is None
            assert req.input_top_logprobs_idx is None

            # Compute input_token_logprobs_val
            # Always pad the first one with None.
            req.input_token_logprobs_val = [None]
            req.input_token_logprobs_val.extend(input_token_logprobs)
            # The last input logprob is for sampling, so just pop it out.
            req.input_token_logprobs_val.pop()

            # Compute input_token_logprobs_idx
            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            input_token_logprobs_idx = [
                x if x < self.model_config.vocab_size - 1 else 0
                for x in input_token_logprobs_idx
            ]
            req.input_token_logprobs_idx = input_token_logprobs_idx

            if req.top_logprobs_num > 0:
                req.input_top_logprobs_val = [None]
                req.input_top_logprobs_idx = [None]

                # NOTE(review): this assert checks the token_ids temp lists
                # inside the top_logprobs branch — possibly intended to check
                # temp_input_top_logprobs_*; the strict=True zip below enforces
                # that invariant regardless. Confirm intent.
                assert len(req.temp_input_token_ids_logprobs_val) == len(
                    req.temp_input_token_ids_logprobs_idx
                )
                for val, idx in zip(
                    req.temp_input_top_logprobs_val,
                    req.temp_input_top_logprobs_idx,
                    strict=True,
                ):
                    req.input_top_logprobs_val.extend(val)
                    req.input_top_logprobs_idx.extend(idx)

                # Last token is a sample token.
                req.input_top_logprobs_val.pop()
                req.input_top_logprobs_idx.pop()
                req.temp_input_top_logprobs_idx = None
                req.temp_input_top_logprobs_val = None

            if req.token_ids_logprob is not None:
                req.input_token_ids_logprobs_val = [None]
                req.input_token_ids_logprobs_idx = [None]

                for val, idx in zip(
                    req.temp_input_token_ids_logprobs_val,
                    req.temp_input_token_ids_logprobs_idx,
                    strict=True,
                ):
                    req.input_token_ids_logprobs_val.extend(val)
                    req.input_token_ids_logprobs_idx.extend(idx)

                # Last token is a sample token.
                req.input_token_ids_logprobs_val.pop()
                req.input_token_ids_logprobs_idx.pop()
                req.temp_input_token_ids_logprobs_idx = None
                req.temp_input_token_ids_logprobs_val = None

            if req.return_logprob:
                # Sanity-check: every returned list covers exactly the tokens
                # from logprob_start_len to the end of the prompt.
                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
                assert len(req.input_token_logprobs_val) == relevant_tokens_len
                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
                if req.top_logprobs_num > 0:
                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
                if req.token_ids_logprob is not None:
                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

    def add_logprob_return_values(
        self: Scheduler,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        num_input_logprobs: int,
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values.

        Records the sampled token's output logprob, finalizes input logprobs
        (this is only called on the last prefill chunk), and appends top/
        token-id logprobs when requested.

        Returns:
            num_input_logprobs, unchanged, so callers can advance their
            flat logprob pointer.
        """
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        self.add_input_logprob_return_values(
            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
        )

        if req.top_logprobs_num > 0:
            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.output_token_ids_logprobs_val.append(
                output.next_token_token_ids_logprobs_val[i]
            )
            req.output_token_ids_logprobs_idx.append(
                output.next_token_token_ids_logprobs_idx[i]
            )

        return num_input_logprobs

    def stream_output(
        self: Scheduler,
        reqs: List[Req],
        return_logprob: bool,
        skip_req: Optional[Req] = None,
    ):
        """Stream the output to detokenizer.

        Dispatches to the generation or embedding variant based on model kind.

        Args:
            reqs: Requests to consider for output.
            return_logprob: Whether any request in the batch returns logprobs.
            skip_req: A request to exclude (the one still being chunked).
        """
        if self.is_generation:
            self.stream_output_generation(reqs, return_logprob, skip_req)
        else:  # embedding or reward model
            self.stream_output_embedding(reqs)

    def stream_output_generation(
        self: Scheduler,
        reqs: List[Req],
        return_logprob: bool,
        skip_req: Optional[Req] = None,
    ):
        """Collect incremental outputs for generation requests and send one
        BatchTokenIDOut to the detokenizer.

        Only the not-yet-sent suffix of each request's tokens/logprobs is
        included; per-request send offsets (send_token_offset,
        send_decode_id_offset, send_output_token_logprobs_offset) are advanced
        as a side effect.
        """
        rids = []
        finished_reasons: List[BaseFinishReason] = []
        decoded_texts = []
        decode_ids_list = []
        read_offsets = []
        output_ids = []
        skip_special_tokens = []
        spaces_between_special_tokens = []
        no_stop_trim = []
        prompt_tokens = []
        completion_tokens = []
        cached_tokens = []
        spec_verify_ct = []
        output_hidden_states = None

        if return_logprob:
            input_token_logprobs_val = []
            input_token_logprobs_idx = []
            output_token_logprobs_val = []
            output_token_logprobs_idx = []
            input_top_logprobs_val = []
            input_top_logprobs_idx = []
            output_top_logprobs_val = []
            output_top_logprobs_idx = []
            input_token_ids_logprobs_val = []
            input_token_ids_logprobs_idx = []
            output_token_ids_logprobs_val = []
            output_token_ids_logprobs_idx = []
        else:
            # No logprobs requested anywhere in the batch: send None columns.
            input_token_logprobs_val = input_token_logprobs_idx = (
                output_token_logprobs_val
            ) = output_token_logprobs_idx = input_top_logprobs_val = (
                input_top_logprobs_idx
            ) = output_top_logprobs_val = output_top_logprobs_idx = (
                input_token_ids_logprobs_val
            ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
                output_token_ids_logprobs_idx
            ) = None

        for req in reqs:
            if req is skip_req:
                continue

            # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
            if self.model_config.is_multimodal_gen and req.to_abort:
                continue

            if req.finished():
                if req.finished_output:
                    # With the overlap schedule, a request will try to output twice and hit this line twice
                    # because of the one additional delayed token. This "continue" prevented the dummy output.
                    continue
                req.finished_output = True
                should_output = True
            else:
                if req.stream:
                    stream_interval = (
                        req.sampling_params.stream_interval or self.stream_interval
                    )
                    should_output = (
                        len(req.output_ids) % stream_interval == 1
                        if not self.model_config.is_multimodal_gen
                        and stream_interval > 1
                        else len(req.output_ids) % stream_interval == 0
                    )
                else:
                    # Non-streaming requests are still flushed periodically to
                    # bound scheduler-side buffering (never for multimodal gen).
                    should_output = (
                        len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
                        if not self.model_config.is_multimodal_gen
                        else False
                    )

            if should_output:
                # Snapshot offsets before they are advanced below.
                send_token_offset = req.send_token_offset
                send_output_token_logprobs_offset = (
                    req.send_output_token_logprobs_offset
                )
                rids.append(req.rid)
                finished_reasons.append(
                    req.finished_reason.to_json() if req.finished_reason else None
                )
                decoded_texts.append(req.decoded_text)
                decode_ids, read_offset = req.init_incremental_detokenize()
                if self.model_config.is_multimodal_gen:
                    decode_ids_list.append(decode_ids)
                else:
                    # Only send the ids not sent in a previous flush.
                    decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
                req.send_decode_id_offset = len(decode_ids)
                read_offsets.append(read_offset)
                output_ids.append(req.output_ids[send_token_offset:])
                req.send_token_offset = len(req.output_ids)
                skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )
                no_stop_trim.append(req.sampling_params.no_stop_trim)
                prompt_tokens.append(len(req.origin_input_ids))
                completion_tokens.append(len(req.output_ids))
                cached_tokens.append(req.cached_tokens)

                if not self.spec_algorithm.is_none():
                    spec_verify_ct.append(req.spec_verify_ct)

                if return_logprob:
                    if (
                        req.return_logprob
                        and not req.input_logprob_sent
                        # Decode server does not send input logprobs
                        and self.disaggregation_mode != DisaggregationMode.DECODE
                    ):
                        # Input logprobs are sent at most once per request.
                        input_token_logprobs_val.append(req.input_token_logprobs_val)
                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                        input_top_logprobs_val.append(req.input_top_logprobs_val)
                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                        input_token_ids_logprobs_val.append(
                            req.input_token_ids_logprobs_val
                        )
                        input_token_ids_logprobs_idx.append(
                            req.input_token_ids_logprobs_idx
                        )
                        req.input_logprob_sent = True
                    else:
                        input_token_logprobs_val.append([])
                        input_token_logprobs_idx.append([])
                        input_top_logprobs_val.append([])
                        input_top_logprobs_idx.append([])
                        input_token_ids_logprobs_val.append([])
                        input_token_ids_logprobs_idx.append([])

                    if req.return_logprob:
                        # Only the not-yet-sent suffix of output logprobs.
                        output_token_logprobs_val.append(
                            req.output_token_logprobs_val[
                                send_output_token_logprobs_offset:
                            ]
                        )
                        output_token_logprobs_idx.append(
                            req.output_token_logprobs_idx[
                                send_output_token_logprobs_offset:
                            ]
                        )
                        output_top_logprobs_val.append(
                            req.output_top_logprobs_val[
                                send_output_token_logprobs_offset:
                            ]
                        )
                        output_top_logprobs_idx.append(
                            req.output_top_logprobs_idx[
                                send_output_token_logprobs_offset:
                            ]
                        )
                        output_token_ids_logprobs_val.append(
                            req.output_token_ids_logprobs_val[
                                send_output_token_logprobs_offset:
                            ]
                        )
                        output_token_ids_logprobs_idx.append(
                            req.output_token_ids_logprobs_idx[
                                send_output_token_logprobs_offset:
                            ]
                        )
                        req.send_output_token_logprobs_offset = len(
                            req.output_token_logprobs_val
                        )
                    else:
                        output_token_logprobs_val.append([])
                        output_token_logprobs_idx.append([])
                        output_top_logprobs_val.append([])
                        output_top_logprobs_idx.append([])
                        output_token_ids_logprobs_val.append([])
                        output_token_ids_logprobs_idx.append([])

                if req.return_hidden_states:
                    if output_hidden_states is None:
                        output_hidden_states = []
                    output_hidden_states.append(req.hidden_states)

            if (
                req.finished()
                and self.tp_rank == 0
                and self.server_args.enable_request_time_stats_logging
            ):
                req.log_time_stats()

        # Send to detokenizer
        if rids:
            if self.model_config.is_multimodal_gen:
                # NOTE(review): multimodal generation outputs are not sent to
                # the detokenizer at all here — presumably handled elsewhere;
                # confirm against the multimodal output path.
                return
            self.send_to_detokenizer.send_pyobj(
                BatchTokenIDOut(
                    rids,
                    finished_reasons,
                    decoded_texts,
                    decode_ids_list,
                    read_offsets,
                    output_ids,
                    skip_special_tokens,
                    spaces_between_special_tokens,
                    no_stop_trim,
                    prompt_tokens,
                    completion_tokens,
                    cached_tokens,
                    spec_verify_ct,
                    input_token_logprobs_val,
                    input_token_logprobs_idx,
                    output_token_logprobs_val,
                    output_token_logprobs_idx,
                    input_top_logprobs_val,
                    input_top_logprobs_idx,
                    output_top_logprobs_val,
                    output_top_logprobs_idx,
                    input_token_ids_logprobs_val,
                    input_token_ids_logprobs_idx,
                    output_token_ids_logprobs_val,
                    output_token_ids_logprobs_idx,
                    output_hidden_states,
                )
            )

    def stream_output_embedding(self: Scheduler, reqs: List[Req]):
        """Send finished embedding/reward requests to the detokenizer.

        Unlike generation, only fully finished requests are sent — there is
        no incremental streaming for embeddings.
        """
        rids = []
        finished_reasons: List[BaseFinishReason] = []

        embeddings = []
        prompt_tokens = []
        cached_tokens = []
        for req in reqs:
            if req.finished():
                rids.append(req.rid)
                finished_reasons.append(req.finished_reason.to_json())
                embeddings.append(req.embedding)
                prompt_tokens.append(len(req.origin_input_ids))
                cached_tokens.append(req.cached_tokens)
        self.send_to_detokenizer.send_pyobj(
            BatchEmbeddingOut(
                rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
            )
        )