From e35a93fa8a871c02db4c0dd5b58918d6774fb47a Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 12 Mar 2025 16:21:49 -0700 Subject: [PATCH] Move output processing logic from scheduler.py into a separate file (#4354) --- python/sglang/srt/layers/sampler.py | 2 +- python/sglang/srt/managers/schedule_batch.py | 22 - python/sglang/srt/managers/scheduler.py | 580 +---------------- .../scheduler_output_processor_mixin.py | 602 ++++++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 19 +- python/sglang/srt/server_args.py | 18 +- 6 files changed, 634 insertions(+), 609 deletions(-) create mode 100644 python/sglang/srt/managers/scheduler_output_processor_mixin.py diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index ec041305c..37f22ec21 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import List import torch import torch.distributed as dist diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 0ac870767..f4667f574 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -441,28 +441,6 @@ class Req: all_ids = self.origin_input_ids_unpadded + self.output_ids return all_ids[self.surr_offset :], self.read_offset - self.surr_offset - def get_next_inc_detokenization(self): - if self.tokenizer is None: - return False, "" - read_ids, read_offset = self.init_incremental_detokenize() - surr_ids = read_ids[:read_offset] - - surr_text = self.tokenizer.decode( - surr_ids, - skip_special_tokens=self.sampling_params.skip_special_tokens, - spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens, - ) - new_text = self.tokenizer.decode( - read_ids, - skip_special_tokens=self.sampling_params.skip_special_tokens, - spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens, - ) - - if len(new_text) > len(surr_text) and not new_text.endswith("�"): - return True, new_text[len(surr_text) :] - - return False, "" - def check_finished(self): if self.finished(): return diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index af0bb825f..9c42c29f0 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -41,8 +41,6 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, - BatchEmbeddingOut, - BatchTokenIDOut, CloseSessionReqInput, FlushCacheReq, GetInternalStateReq, @@ -74,7 +72,6 @@ from sglang.srt.managers.io_struct import ( ) from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, - BaseFinishReason, ImageInputs, Req, ScheduleBatch, @@ -85,6 +82,9 @@ from sglang.srt.managers.schedule_policy import ( PrefillAdder, SchedulePolicy, ) +from sglang.srt.managers.scheduler_output_processor_mixin import ( + SchedulerOutputProcessorMixin, +) from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient @@ -132,7 +132,7 @@ class EmbeddingBatchResult: bid: int -class Scheduler: +class Scheduler(SchedulerOutputProcessorMixin): """A scheduler that manages a tensor parallel GPU worker.""" def __init__( @@ -1256,578 +1256,6 @@ class Scheduler: 
self.return_health_check_ct -= 1 self.send_to_tokenizer.send_pyobj(HealthCheckOutput()) - def process_batch_result_prefill( - self, - batch: ScheduleBatch, - result: Union[GenerationBatchResult, EmbeddingBatchResult], - ): - skip_stream_req = None - - if self.is_generation: - ( - logits_output, - next_token_ids, - extend_input_len_per_req, - extend_logprob_start_len_per_req, - bid, - ) = ( - result.logits_output, - result.next_token_ids, - result.extend_input_len_per_req, - result.extend_logprob_start_len_per_req, - result.bid, - ) - - if self.enable_overlap: - logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) - else: - # Move next_token_ids and logprobs to cpu - next_token_ids = next_token_ids.tolist() - if batch.return_logprob: - if logits_output.next_token_logprobs is not None: - logits_output.next_token_logprobs = ( - logits_output.next_token_logprobs.tolist() - ) - if logits_output.input_token_logprobs is not None: - logits_output.input_token_logprobs = tuple( - logits_output.input_token_logprobs.tolist() - ) - - hidden_state_offset = 0 - - # Check finish conditions - logprob_pt = 0 - for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): - if req.is_retracted: - continue - - if self.is_mixed_chunk and self.enable_overlap and req.finished(): - # Free the one delayed token for the mixed decode batch - j = len(batch.out_cache_loc) - len(batch.reqs) + i - self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1]) - continue - - if req.is_chunked <= 0: - # req output_ids are set here - req.output_ids.append(next_token_id) - req.check_finished() - - if req.finished(): - self.tree_cache.cache_finished_req(req) - elif not batch.decoding_reqs or req not in batch.decoding_reqs: - # This updates radix so others can match - self.tree_cache.cache_unfinished_req(req) - - if req.return_logprob: - assert extend_logprob_start_len_per_req is not None - assert extend_input_len_per_req is not None - extend_logprob_start_len = extend_logprob_start_len_per_req[i] - extend_input_len = extend_input_len_per_req[i] - num_input_logprobs = extend_input_len - extend_logprob_start_len - self.add_logprob_return_values( - i, - req, - logprob_pt, - next_token_ids, - num_input_logprobs, - logits_output, - ) - logprob_pt += num_input_logprobs - - if ( - req.return_hidden_states - and logits_output.hidden_states is not None - ): - req.hidden_states.append( - logits_output.hidden_states[ - hidden_state_offset : ( - hidden_state_offset := hidden_state_offset - + len(req.origin_input_ids) - ) - ] - .cpu() - .clone() - ) - - if req.grammar is not None: - req.grammar.accept_token(next_token_id) - req.grammar.finished = req.finished() - else: - # being chunked reqs' prefill is not finished - req.is_chunked -= 1 - # There is only at most one request being currently chunked. - # Because this request does not finish prefill, - # we don't want to stream the request currently being chunked. - skip_stream_req = req - - # Incrementally update input logprobs. - if req.return_logprob: - extend_logprob_start_len = extend_logprob_start_len_per_req[i] - extend_input_len = extend_input_len_per_req[i] - if extend_logprob_start_len < extend_input_len: - # Update input logprobs. 
- num_input_logprobs = ( - extend_input_len - extend_logprob_start_len - ) - self.add_input_logprob_return_values( - i, - req, - logits_output, - logprob_pt, - num_input_logprobs, - last_prefill_chunk=False, - ) - logprob_pt += num_input_logprobs - - if batch.next_batch_sampling_info: - batch.next_batch_sampling_info.update_regex_vocab_mask() - self.current_stream.synchronize() - batch.next_batch_sampling_info.sampling_info_done.set() - - else: # embedding or reward model - embeddings, bid = result.embeddings, result.bid - embeddings = embeddings.tolist() - - # Check finish conditions - for i, req in enumerate(batch.reqs): - if req.is_retracted: - continue - - req.embedding = embeddings[i] - if req.is_chunked <= 0: - # Dummy output token for embedding models - req.output_ids.append(0) - req.check_finished() - - if req.finished(): - self.tree_cache.cache_finished_req(req) - else: - self.tree_cache.cache_unfinished_req(req) - else: - # being chunked reqs' prefill is not finished - req.is_chunked -= 1 - - self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) - - def process_batch_result_decode( - self, - batch: ScheduleBatch, - result: GenerationBatchResult, - ): - logits_output, next_token_ids, bid = ( - result.logits_output, - result.next_token_ids, - result.bid, - ) - self.num_generated_tokens += len(batch.reqs) - - if self.enable_overlap: - assert batch.spec_algorithm.is_none() - logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) - next_token_logprobs = logits_output.next_token_logprobs - elif batch.spec_algorithm.is_none(): - # spec decoding handles output logprobs inside verify process. - next_token_ids = next_token_ids.tolist() - if batch.return_logprob: - next_token_logprobs = logits_output.next_token_logprobs.tolist() - - self.token_to_kv_pool_allocator.free_group_begin() - - # Check finish condition - # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding. - # We should ignore using next_token_ids for spec decoding cases. 
- for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): - if req.is_retracted: - continue - - if self.enable_overlap and req.finished(): - # Free the one delayed token - self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1]) - continue - - if batch.spec_algorithm.is_none(): - # speculative worker will solve the output_ids in speculative decoding - req.output_ids.append(next_token_id) - - req.check_finished() - if req.finished(): - self.tree_cache.cache_finished_req(req) - - if req.return_logprob and batch.spec_algorithm.is_none(): - # speculative worker handles logprob in speculative decoding - req.output_token_logprobs_val.append(next_token_logprobs[i]) - req.output_token_logprobs_idx.append(next_token_id) - if req.top_logprobs_num > 0: - req.output_top_logprobs_val.append( - logits_output.next_token_top_logprobs_val[i] - ) - req.output_top_logprobs_idx.append( - logits_output.next_token_top_logprobs_idx[i] - ) - if req.token_ids_logprob is not None: - req.output_token_ids_logprobs_val.append( - logits_output.next_token_token_ids_logprobs_val[i] - ) - req.output_token_ids_logprobs_idx.append( - logits_output.next_token_token_ids_logprobs_idx[i] - ) - - if req.return_hidden_states and logits_output.hidden_states is not None: - req.hidden_states.append(logits_output.hidden_states[i].cpu().clone()) - - if req.grammar is not None and batch.spec_algorithm.is_none(): - req.grammar.accept_token(next_token_id) - req.grammar.finished = req.finished() - - if batch.next_batch_sampling_info: - batch.next_batch_sampling_info.update_regex_vocab_mask() - self.current_stream.synchronize() - batch.next_batch_sampling_info.sampling_info_done.set() - - self.stream_output(batch.reqs, batch.return_logprob) - - self.token_to_kv_pool_allocator.free_group_end() - - self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30) - if ( - self.attn_tp_rank == 0 - and self.forward_ct_decode % self.server_args.decode_log_interval == 0 - ): - self.log_decode_stats() - - def add_input_logprob_return_values( - self, - i: int, - req: Req, - output: LogitsProcessorOutput, - logprob_pt: int, - num_input_logprobs: int, - last_prefill_chunk: bool, # If True, it means prefill is finished. - ): - """Incrementally add input logprobs to `req`. - - Args: - i: The request index in a batch. - req: The request. Input logprobs inside req are modified as a - consequence of the API - fill_ids: The prefill ids processed. - output: Logit processor output that's used to compute input logprobs - last_prefill_chunk: True if it is the last prefill (when chunked). - Some of input logprob operation should only happen at the last - prefill (e.g., computing input token logprobs). - """ - assert output.input_token_logprobs is not None - if req.input_token_logprobs is None: - req.input_token_logprobs = [] - if req.temp_input_top_logprobs_val is None: - req.temp_input_top_logprobs_val = [] - if req.temp_input_top_logprobs_idx is None: - req.temp_input_top_logprobs_idx = [] - if req.temp_input_token_ids_logprobs_val is None: - req.temp_input_token_ids_logprobs_val = [] - if req.temp_input_token_ids_logprobs_idx is None: - req.temp_input_token_ids_logprobs_idx = [] - - if req.input_token_logprobs_val is not None: - # The input logprob has been already computed. It only happens - # upon retract. - if req.top_logprobs_num > 0: - assert req.input_token_logprobs_val is not None - return - - # Important for the performance. 
- assert isinstance(output.input_token_logprobs, tuple) - input_token_logprobs: Tuple[int] = output.input_token_logprobs - input_token_logprobs = input_token_logprobs[ - logprob_pt : logprob_pt + num_input_logprobs - ] - req.input_token_logprobs.extend(input_token_logprobs) - - if req.top_logprobs_num > 0: - req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i]) - req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i]) - - if req.token_ids_logprob is not None: - req.temp_input_token_ids_logprobs_val.append( - output.input_token_ids_logprobs_val[i] - ) - req.temp_input_token_ids_logprobs_idx.append( - output.input_token_ids_logprobs_idx[i] - ) - - if last_prefill_chunk: - input_token_logprobs = req.input_token_logprobs - req.input_token_logprobs = None - assert req.input_token_logprobs_val is None - assert req.input_token_logprobs_idx is None - assert req.input_top_logprobs_val is None - assert req.input_top_logprobs_idx is None - - # Compute input_token_logprobs_val - # Always pad the first one with None. - req.input_token_logprobs_val = [None] - req.input_token_logprobs_val.extend(input_token_logprobs) - # The last input logprob is for sampling, so just pop it out. - req.input_token_logprobs_val.pop() - - # Compute input_token_logprobs_idx - input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :] - # Clip the padded hash values from image tokens. - # Otherwise, it will lead to detokenization errors. - input_token_logprobs_idx = [ - x if x < self.model_config.vocab_size - 1 else 0 - for x in input_token_logprobs_idx - ] - req.input_token_logprobs_idx = input_token_logprobs_idx - - if req.top_logprobs_num > 0: - req.input_top_logprobs_val = [None] - req.input_top_logprobs_idx = [None] - assert len(req.temp_input_token_ids_logprobs_val) == len( - req.temp_input_token_ids_logprobs_idx - ) - for val, idx in zip( - req.temp_input_top_logprobs_val, - req.temp_input_top_logprobs_idx, - strict=True, - ): - req.input_top_logprobs_val.extend(val) - req.input_top_logprobs_idx.extend(idx) - - # Last token is a sample token. - req.input_top_logprobs_val.pop() - req.input_top_logprobs_idx.pop() - req.temp_input_top_logprobs_idx = None - req.temp_input_top_logprobs_val = None - - if req.token_ids_logprob is not None: - req.input_token_ids_logprobs_val = [None] - req.input_token_ids_logprobs_idx = [None] - - for val, idx in zip( - req.temp_input_token_ids_logprobs_val, - req.temp_input_token_ids_logprobs_idx, - strict=True, - ): - req.input_token_ids_logprobs_val.extend(val) - req.input_token_ids_logprobs_idx.extend(idx) - - # Last token is a sample token. 
- req.input_token_ids_logprobs_val.pop() - req.input_token_ids_logprobs_idx.pop() - req.temp_input_token_ids_logprobs_idx = None - req.temp_input_token_ids_logprobs_val = None - - if req.return_logprob: - relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len - assert len(req.input_token_logprobs_val) == relevant_tokens_len - assert len(req.input_token_logprobs_idx) == relevant_tokens_len - if req.top_logprobs_num > 0: - assert len(req.input_top_logprobs_val) == relevant_tokens_len - assert len(req.input_top_logprobs_idx) == relevant_tokens_len - if req.token_ids_logprob is not None: - assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len - assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len - - def add_logprob_return_values( - self, - i: int, - req: Req, - pt: int, - next_token_ids: List[int], - num_input_logprobs: int, - output: LogitsProcessorOutput, - ): - """Attach logprobs to the return values.""" - req.output_token_logprobs_val.append(output.next_token_logprobs[i]) - req.output_token_logprobs_idx.append(next_token_ids[i]) - - self.add_input_logprob_return_values( - i, req, output, pt, num_input_logprobs, last_prefill_chunk=True - ) - - if req.top_logprobs_num > 0: - req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) - req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i]) - - if req.token_ids_logprob is not None: - req.output_token_ids_logprobs_val.append( - output.next_token_token_ids_logprobs_val[i] - ) - req.output_token_ids_logprobs_idx.append( - output.next_token_token_ids_logprobs_idx[i] - ) - - return num_input_logprobs - - def stream_output( - self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None - ): - """Stream the output to detokenizer.""" - rids = [] - finished_reasons: List[BaseFinishReason] = [] - - if self.is_generation: - decoded_texts = [] - decode_ids_list = [] - read_offsets = [] - output_ids = [] - - skip_special_tokens = [] - spaces_between_special_tokens = [] - no_stop_trim = [] - prompt_tokens = [] - completion_tokens = [] - cached_tokens = [] - spec_verify_ct = [] - output_hidden_states = None - - if return_logprob: - input_token_logprobs_val = [] - input_token_logprobs_idx = [] - output_token_logprobs_val = [] - output_token_logprobs_idx = [] - input_top_logprobs_val = [] - input_top_logprobs_idx = [] - output_top_logprobs_val = [] - output_top_logprobs_idx = [] - input_token_ids_logprobs_val = [] - input_token_ids_logprobs_idx = [] - output_token_ids_logprobs_val = [] - output_token_ids_logprobs_idx = [] - else: - input_token_logprobs_val = input_token_logprobs_idx = ( - output_token_logprobs_val - ) = output_token_logprobs_idx = input_top_logprobs_val = ( - input_top_logprobs_idx - ) = output_top_logprobs_val = output_top_logprobs_idx = ( - input_token_ids_logprobs_val - ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = ( - output_token_ids_logprobs_idx - ) = None - - for req in reqs: - if req is skip_req: - continue - - # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here. - if self.model_config.is_multimodal_gen and req.to_abort: - continue - - if ( - req.finished() - # If stream, follow the given stream_interval - or (req.stream and len(req.output_ids) % self.stream_interval == 0) - # If not stream, we still want to output some tokens to get the benefit of incremental decoding. 
- # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not - # always increase one-by-one. - or ( - not req.stream - and len(req.output_ids) % 50 == 0 - and not self.model_config.is_multimodal_gen - ) - ): - rids.append(req.rid) - finished_reasons.append( - req.finished_reason.to_json() if req.finished_reason else None - ) - decoded_texts.append(req.decoded_text) - decode_ids, read_offset = req.init_incremental_detokenize() - decode_ids_list.append(decode_ids) - read_offsets.append(read_offset) - if self.skip_tokenizer_init: - output_ids.append(req.output_ids) - skip_special_tokens.append(req.sampling_params.skip_special_tokens) - spaces_between_special_tokens.append( - req.sampling_params.spaces_between_special_tokens - ) - no_stop_trim.append(req.sampling_params.no_stop_trim) - - prompt_tokens.append(len(req.origin_input_ids)) - completion_tokens.append(len(req.output_ids)) - cached_tokens.append(req.cached_tokens) - - if not self.spec_algorithm.is_none(): - spec_verify_ct.append(req.spec_verify_ct) - - if return_logprob: - input_token_logprobs_val.append(req.input_token_logprobs_val) - input_token_logprobs_idx.append(req.input_token_logprobs_idx) - output_token_logprobs_val.append(req.output_token_logprobs_val) - output_token_logprobs_idx.append(req.output_token_logprobs_idx) - input_top_logprobs_val.append(req.input_top_logprobs_val) - input_top_logprobs_idx.append(req.input_top_logprobs_idx) - output_top_logprobs_val.append(req.output_top_logprobs_val) - output_top_logprobs_idx.append(req.output_top_logprobs_idx) - input_token_ids_logprobs_val.append( - req.input_token_ids_logprobs_val - ) - input_token_ids_logprobs_idx.append( - req.input_token_ids_logprobs_idx - ) - output_token_ids_logprobs_val.append( - req.output_token_ids_logprobs_val - ) - output_token_ids_logprobs_idx.append( - req.output_token_ids_logprobs_idx - ) - - if req.return_hidden_states: - if output_hidden_states is None: - output_hidden_states = [] - output_hidden_states.append(req.hidden_states) - - # Send to detokenizer - if rids: - if self.model_config.is_multimodal_gen: - raise NotImplementedError() - self.send_to_detokenizer.send_pyobj( - BatchTokenIDOut( - rids, - finished_reasons, - decoded_texts, - decode_ids_list, - read_offsets, - output_ids, - skip_special_tokens, - spaces_between_special_tokens, - no_stop_trim, - prompt_tokens, - completion_tokens, - cached_tokens, - spec_verify_ct, - input_token_logprobs_val, - input_token_logprobs_idx, - output_token_logprobs_val, - output_token_logprobs_idx, - input_top_logprobs_val, - input_top_logprobs_idx, - output_top_logprobs_val, - output_top_logprobs_idx, - input_token_ids_logprobs_val, - input_token_ids_logprobs_idx, - output_token_ids_logprobs_val, - output_token_ids_logprobs_idx, - output_hidden_states, - ) - ) - else: # embedding or reward model - embeddings = [] - prompt_tokens = [] - cached_tokens = [] - for req in reqs: - if req.finished(): - rids.append(req.rid) - finished_reasons.append(req.finished_reason.to_json()) - embeddings.append(req.embedding) - prompt_tokens.append(len(req.origin_input_ids)) - cached_tokens.append(req.cached_tokens) - self.send_to_detokenizer.send_pyobj( - BatchEmbeddingOut( - rids, finished_reasons, embeddings, prompt_tokens, cached_tokens - ) - ) - def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): # Check if other DP workers have running batches if local_batch is None: diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py 
b/python/sglang/srt/managers/scheduler_output_processor_mixin.py new file mode 100644 index 000000000..8728b9e7e --- /dev/null +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -0,0 +1,602 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut +from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch + +if TYPE_CHECKING: + from sglang.srt.managers.scheduler import ( + EmbeddingBatchResult, + GenerationBatchResult, + ScheduleBatch, + ) + + +class SchedulerOutputProcessorMixin: + """ + This class implements the output processing logic for Scheduler. + We put them into a separate file to make the `scheduler.py` shorter. + """ + + def process_batch_result_prefill( + self, + batch: ScheduleBatch, + result: Union[GenerationBatchResult, EmbeddingBatchResult], + ): + skip_stream_req = None + + if self.is_generation: + ( + logits_output, + next_token_ids, + extend_input_len_per_req, + extend_logprob_start_len_per_req, + bid, + ) = ( + result.logits_output, + result.next_token_ids, + result.extend_input_len_per_req, + result.extend_logprob_start_len_per_req, + result.bid, + ) + + if self.enable_overlap: + logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) + else: + # Move next_token_ids and logprobs to cpu + next_token_ids = next_token_ids.tolist() + if batch.return_logprob: + if logits_output.next_token_logprobs is not None: + logits_output.next_token_logprobs = ( + logits_output.next_token_logprobs.tolist() + ) + if logits_output.input_token_logprobs is not None: + logits_output.input_token_logprobs = tuple( + logits_output.input_token_logprobs.tolist() + ) + + hidden_state_offset = 0 + + # Check finish conditions + logprob_pt = 0 + for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): + if req.is_retracted: + continue + + if self.is_mixed_chunk and self.enable_overlap and req.finished(): + # Free the one delayed token for the mixed decode batch + j = len(batch.out_cache_loc) - len(batch.reqs) + i + self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1]) + continue + + if req.is_chunked <= 0: + # req output_ids are set here + req.output_ids.append(next_token_id) + req.check_finished() + + if req.finished(): + self.tree_cache.cache_finished_req(req) + elif not batch.decoding_reqs or req not in batch.decoding_reqs: + # This updates radix so others can match + self.tree_cache.cache_unfinished_req(req) + + if req.return_logprob: + assert extend_logprob_start_len_per_req is not None + assert extend_input_len_per_req is not None + extend_logprob_start_len = extend_logprob_start_len_per_req[i] + extend_input_len = extend_input_len_per_req[i] + num_input_logprobs = extend_input_len - extend_logprob_start_len + self.add_logprob_return_values( + i, + req, + logprob_pt, + next_token_ids, + num_input_logprobs, + logits_output, + ) + logprob_pt += num_input_logprobs + + if ( + req.return_hidden_states + and logits_output.hidden_states is not None + ): + req.hidden_states.append( + logits_output.hidden_states[ + hidden_state_offset : ( + hidden_state_offset := hidden_state_offset + + len(req.origin_input_ids) + ) + ] + .cpu() + .clone() + ) + + if req.grammar is not None: + req.grammar.accept_token(next_token_id) + req.grammar.finished = req.finished() + else: + # being chunked reqs' prefill is not finished + 
req.is_chunked -= 1 + # There is only at most one request being currently chunked. + # Because this request does not finish prefill, + # we don't want to stream the request currently being chunked. + skip_stream_req = req + + # Incrementally update input logprobs. + if req.return_logprob: + extend_logprob_start_len = extend_logprob_start_len_per_req[i] + extend_input_len = extend_input_len_per_req[i] + if extend_logprob_start_len < extend_input_len: + # Update input logprobs. + num_input_logprobs = ( + extend_input_len - extend_logprob_start_len + ) + self.add_input_logprob_return_values( + i, + req, + logits_output, + logprob_pt, + num_input_logprobs, + last_prefill_chunk=False, + ) + logprob_pt += num_input_logprobs + + if batch.next_batch_sampling_info: + batch.next_batch_sampling_info.update_regex_vocab_mask() + self.current_stream.synchronize() + batch.next_batch_sampling_info.sampling_info_done.set() + + else: # embedding or reward model + embeddings, bid = result.embeddings, result.bid + embeddings = embeddings.tolist() + + # Check finish conditions + for i, req in enumerate(batch.reqs): + if req.is_retracted: + continue + + req.embedding = embeddings[i] + if req.is_chunked <= 0: + # Dummy output token for embedding models + req.output_ids.append(0) + req.check_finished() + + if req.finished(): + self.tree_cache.cache_finished_req(req) + else: + self.tree_cache.cache_unfinished_req(req) + else: + # being chunked reqs' prefill is not finished + req.is_chunked -= 1 + + self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) + + def process_batch_result_decode( + self, + batch: ScheduleBatch, + result: GenerationBatchResult, + ): + logits_output, next_token_ids, bid = ( + result.logits_output, + result.next_token_ids, + result.bid, + ) + self.num_generated_tokens += len(batch.reqs) + + if self.enable_overlap: + logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) + next_token_logprobs = logits_output.next_token_logprobs + elif batch.spec_algorithm.is_none(): + # spec decoding handles output logprobs inside verify process. + next_token_ids = next_token_ids.tolist() + if batch.return_logprob: + next_token_logprobs = logits_output.next_token_logprobs.tolist() + + self.token_to_kv_pool_allocator.free_group_begin() + + # Check finish condition + # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding. + # We should ignore using next_token_ids for spec decoding cases. 
+ for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): + if req.is_retracted: + continue + + if self.enable_overlap and req.finished(): + # Free the one delayed token + self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1]) + continue + + if batch.spec_algorithm.is_none(): + # speculative worker will solve the output_ids in speculative decoding + req.output_ids.append(next_token_id) + + req.check_finished() + if req.finished(): + self.tree_cache.cache_finished_req(req) + + if req.return_logprob and batch.spec_algorithm.is_none(): + # speculative worker handles logprob in speculative decoding + req.output_token_logprobs_val.append(next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_id) + if req.top_logprobs_num > 0: + req.output_top_logprobs_val.append( + logits_output.next_token_top_logprobs_val[i] + ) + req.output_top_logprobs_idx.append( + logits_output.next_token_top_logprobs_idx[i] + ) + if req.token_ids_logprob is not None: + req.output_token_ids_logprobs_val.append( + logits_output.next_token_token_ids_logprobs_val[i] + ) + req.output_token_ids_logprobs_idx.append( + logits_output.next_token_token_ids_logprobs_idx[i] + ) + + if req.return_hidden_states and logits_output.hidden_states is not None: + req.hidden_states.append(logits_output.hidden_states[i].cpu().clone()) + + if req.grammar is not None and batch.spec_algorithm.is_none(): + req.grammar.accept_token(next_token_id) + req.grammar.finished = req.finished() + + if batch.next_batch_sampling_info: + batch.next_batch_sampling_info.update_regex_vocab_mask() + self.current_stream.synchronize() + batch.next_batch_sampling_info.sampling_info_done.set() + + self.stream_output(batch.reqs, batch.return_logprob) + + self.token_to_kv_pool_allocator.free_group_end() + + self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30) + if ( + self.attn_tp_rank == 0 + and self.forward_ct_decode % self.server_args.decode_log_interval == 0 + ): + self.log_decode_stats() + + def add_input_logprob_return_values( + self, + i: int, + req: Req, + output: LogitsProcessorOutput, + logprob_pt: int, + num_input_logprobs: int, + last_prefill_chunk: bool, # If True, it means prefill is finished. + ): + """Incrementally add input logprobs to `req`. + + Args: + i: The request index in a batch. + req: The request. Input logprobs inside req are modified as a + consequence of the API + fill_ids: The prefill ids processed. + output: Logit processor output that's used to compute input logprobs + last_prefill_chunk: True if it is the last prefill (when chunked). + Some of input logprob operation should only happen at the last + prefill (e.g., computing input token logprobs). + """ + assert output.input_token_logprobs is not None + if req.input_token_logprobs is None: + req.input_token_logprobs = [] + if req.temp_input_top_logprobs_val is None: + req.temp_input_top_logprobs_val = [] + if req.temp_input_top_logprobs_idx is None: + req.temp_input_top_logprobs_idx = [] + if req.temp_input_token_ids_logprobs_val is None: + req.temp_input_token_ids_logprobs_val = [] + if req.temp_input_token_ids_logprobs_idx is None: + req.temp_input_token_ids_logprobs_idx = [] + + if req.input_token_logprobs_val is not None: + # The input logprob has been already computed. It only happens + # upon retract. + if req.top_logprobs_num > 0: + assert req.input_token_logprobs_val is not None + return + + # Important for the performance. 
+ assert isinstance(output.input_token_logprobs, tuple) + input_token_logprobs: Tuple[int] = output.input_token_logprobs + input_token_logprobs = input_token_logprobs[ + logprob_pt : logprob_pt + num_input_logprobs + ] + req.input_token_logprobs.extend(input_token_logprobs) + + if req.top_logprobs_num > 0: + req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i]) + req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i]) + + if req.token_ids_logprob is not None: + req.temp_input_token_ids_logprobs_val.append( + output.input_token_ids_logprobs_val[i] + ) + req.temp_input_token_ids_logprobs_idx.append( + output.input_token_ids_logprobs_idx[i] + ) + + if last_prefill_chunk: + input_token_logprobs = req.input_token_logprobs + req.input_token_logprobs = None + assert req.input_token_logprobs_val is None + assert req.input_token_logprobs_idx is None + assert req.input_top_logprobs_val is None + assert req.input_top_logprobs_idx is None + + # Compute input_token_logprobs_val + # Always pad the first one with None. + req.input_token_logprobs_val = [None] + req.input_token_logprobs_val.extend(input_token_logprobs) + # The last input logprob is for sampling, so just pop it out. + req.input_token_logprobs_val.pop() + + # Compute input_token_logprobs_idx + input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :] + # Clip the padded hash values from image tokens. + # Otherwise, it will lead to detokenization errors. + input_token_logprobs_idx = [ + x if x < self.model_config.vocab_size - 1 else 0 + for x in input_token_logprobs_idx + ] + req.input_token_logprobs_idx = input_token_logprobs_idx + + if req.top_logprobs_num > 0: + req.input_top_logprobs_val = [None] + req.input_top_logprobs_idx = [None] + assert len(req.temp_input_token_ids_logprobs_val) == len( + req.temp_input_token_ids_logprobs_idx + ) + for val, idx in zip( + req.temp_input_top_logprobs_val, + req.temp_input_top_logprobs_idx, + strict=True, + ): + req.input_top_logprobs_val.extend(val) + req.input_top_logprobs_idx.extend(idx) + + # Last token is a sample token. + req.input_top_logprobs_val.pop() + req.input_top_logprobs_idx.pop() + req.temp_input_top_logprobs_idx = None + req.temp_input_top_logprobs_val = None + + if req.token_ids_logprob is not None: + req.input_token_ids_logprobs_val = [None] + req.input_token_ids_logprobs_idx = [None] + + for val, idx in zip( + req.temp_input_token_ids_logprobs_val, + req.temp_input_token_ids_logprobs_idx, + strict=True, + ): + req.input_token_ids_logprobs_val.extend(val) + req.input_token_ids_logprobs_idx.extend(idx) + + # Last token is a sample token. 
+ req.input_token_ids_logprobs_val.pop() + req.input_token_ids_logprobs_idx.pop() + req.temp_input_token_ids_logprobs_idx = None + req.temp_input_token_ids_logprobs_val = None + + if req.return_logprob: + relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len + assert len(req.input_token_logprobs_val) == relevant_tokens_len + assert len(req.input_token_logprobs_idx) == relevant_tokens_len + if req.top_logprobs_num > 0: + assert len(req.input_top_logprobs_val) == relevant_tokens_len + assert len(req.input_top_logprobs_idx) == relevant_tokens_len + if req.token_ids_logprob is not None: + assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len + assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len + + def add_logprob_return_values( + self, + i: int, + req: Req, + pt: int, + next_token_ids: List[int], + num_input_logprobs: int, + output: LogitsProcessorOutput, + ): + """Attach logprobs to the return values.""" + req.output_token_logprobs_val.append(output.next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_ids[i]) + + self.add_input_logprob_return_values( + i, req, output, pt, num_input_logprobs, last_prefill_chunk=True + ) + + if req.top_logprobs_num > 0: + req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) + req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i]) + + if req.token_ids_logprob is not None: + req.output_token_ids_logprobs_val.append( + output.next_token_token_ids_logprobs_val[i] + ) + req.output_token_ids_logprobs_idx.append( + output.next_token_token_ids_logprobs_idx[i] + ) + + return num_input_logprobs + + def stream_output( + self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None + ): + """Stream the output to detokenizer.""" + if self.is_generation: + self.stream_output_generation(reqs, return_logprob, skip_req) + else: # embedding or reward model + self.stream_output_embedding(reqs) + + def stream_output_generation( + self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None + ): + rids = [] + finished_reasons: List[BaseFinishReason] = [] + + decoded_texts = [] + decode_ids_list = [] + read_offsets = [] + output_ids = [] + + skip_special_tokens = [] + spaces_between_special_tokens = [] + no_stop_trim = [] + prompt_tokens = [] + completion_tokens = [] + cached_tokens = [] + spec_verify_ct = [] + output_hidden_states = None + + if return_logprob: + input_token_logprobs_val = [] + input_token_logprobs_idx = [] + output_token_logprobs_val = [] + output_token_logprobs_idx = [] + input_top_logprobs_val = [] + input_top_logprobs_idx = [] + output_top_logprobs_val = [] + output_top_logprobs_idx = [] + input_token_ids_logprobs_val = [] + input_token_ids_logprobs_idx = [] + output_token_ids_logprobs_val = [] + output_token_ids_logprobs_idx = [] + else: + input_token_logprobs_val = input_token_logprobs_idx = ( + output_token_logprobs_val + ) = output_token_logprobs_idx = input_top_logprobs_val = ( + input_top_logprobs_idx + ) = output_top_logprobs_val = output_top_logprobs_idx = ( + input_token_ids_logprobs_val + ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = ( + output_token_ids_logprobs_idx + ) = None + + for req in reqs: + if req is skip_req: + continue + + # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here. 
+ if self.model_config.is_multimodal_gen and req.to_abort: + continue + + if ( + req.finished() + # If stream, follow the given stream_interval + or (req.stream and len(req.output_ids) % self.stream_interval == 0) + # If not stream, we still want to output some tokens to get the benefit of incremental decoding. + # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not + # always increase one-by-one. + or ( + not req.stream + and len(req.output_ids) % 50 == 0 + and not self.model_config.is_multimodal_gen + ) + ): + rids.append(req.rid) + finished_reasons.append( + req.finished_reason.to_json() if req.finished_reason else None + ) + decoded_texts.append(req.decoded_text) + decode_ids, read_offset = req.init_incremental_detokenize() + decode_ids_list.append(decode_ids) + read_offsets.append(read_offset) + if self.skip_tokenizer_init: + output_ids.append(req.output_ids) + skip_special_tokens.append(req.sampling_params.skip_special_tokens) + spaces_between_special_tokens.append( + req.sampling_params.spaces_between_special_tokens + ) + no_stop_trim.append(req.sampling_params.no_stop_trim) + prompt_tokens.append(len(req.origin_input_ids)) + completion_tokens.append(len(req.output_ids)) + cached_tokens.append(req.cached_tokens) + + if not self.spec_algorithm.is_none(): + spec_verify_ct.append(req.spec_verify_ct) + + if return_logprob: + input_token_logprobs_val.append(req.input_token_logprobs_val) + input_token_logprobs_idx.append(req.input_token_logprobs_idx) + output_token_logprobs_val.append(req.output_token_logprobs_val) + output_token_logprobs_idx.append(req.output_token_logprobs_idx) + input_top_logprobs_val.append(req.input_top_logprobs_val) + input_top_logprobs_idx.append(req.input_top_logprobs_idx) + output_top_logprobs_val.append(req.output_top_logprobs_val) + output_top_logprobs_idx.append(req.output_top_logprobs_idx) + input_token_ids_logprobs_val.append( + req.input_token_ids_logprobs_val + ) + input_token_ids_logprobs_idx.append( + req.input_token_ids_logprobs_idx + ) + output_token_ids_logprobs_val.append( + req.output_token_ids_logprobs_val + ) + output_token_ids_logprobs_idx.append( + req.output_token_ids_logprobs_idx + ) + + if req.return_hidden_states: + if output_hidden_states is None: + output_hidden_states = [] + output_hidden_states.append(req.hidden_states) + + # Send to detokenizer + if rids: + if self.model_config.is_multimodal_gen: + return + self.send_to_detokenizer.send_pyobj( + BatchTokenIDOut( + rids, + finished_reasons, + decoded_texts, + decode_ids_list, + read_offsets, + output_ids, + skip_special_tokens, + spaces_between_special_tokens, + no_stop_trim, + prompt_tokens, + completion_tokens, + cached_tokens, + spec_verify_ct, + input_token_logprobs_val, + input_token_logprobs_idx, + output_token_logprobs_val, + output_token_logprobs_idx, + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + input_token_ids_logprobs_val, + input_token_ids_logprobs_idx, + output_token_ids_logprobs_val, + output_token_ids_logprobs_idx, + output_hidden_states, + ) + ) + + def stream_output_embedding(self, reqs: List[Req]): + rids = [] + finished_reasons: List[BaseFinishReason] = [] + + embeddings = [] + prompt_tokens = [] + cached_tokens = [] + for req in reqs: + if req.finished(): + rids.append(req.rid) + finished_reasons.append(req.finished_reason.to_json()) + embeddings.append(req.embedding) + prompt_tokens.append(len(req.origin_input_ids)) + cached_tokens.append(req.cached_tokens) + 
self.send_to_detokenizer.send_pyobj( + BatchEmbeddingOut( + rids, finished_reasons, embeddings, prompt_tokens, cached_tokens + ) + ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index cb2069dda..916d595ed 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -82,7 +82,6 @@ from sglang.srt.utils import ( logger = logging.getLogger(__name__) - SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None) UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300 @@ -119,6 +118,7 @@ class ModelRunner: self.spec_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm ) + self.page_size = server_args.page_size self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator @@ -161,6 +161,11 @@ class ModelRunner: # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + # If it is a draft model tp_group can be different. + self.initialize(min_per_gpu_memory) + + def initialize(self, min_per_gpu_memory: float): + server_args = self.server_args self.memory_saver_adapter = TorchMemorySaverAdapter.create( enable=self.server_args.enable_memory_saver ) @@ -300,15 +305,16 @@ class ModelRunner: min_per_gpu_memory = get_available_gpu_memory( self.device, self.gpu_id, distributed=self.tp_size > 1 ) - local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id) self.tp_group = get_tp_group() self.attention_tp_group = get_attention_tp_group() # Check memory for tensor parallelism + local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id) if self.tp_size > 1: if min_per_gpu_memory < local_gpu_memory * 0.9: raise ValueError( - "The memory capacity is unbalanced. Some GPUs may be occupied by other processes." + "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. " + f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}" ) logger.info( @@ -698,6 +704,12 @@ class ModelRunner: ) self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens) + self.max_total_num_tokens = ( + self.max_total_num_tokens + // self.server_args.page_size + * self.server_args.page_size + ) + if self.max_total_num_tokens <= 0: raise RuntimeError( "Not enough memory. Please try to increase --mem-fraction-static." 
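The hunk above floors `max_total_num_tokens` to a whole number of pages so the paged KV cache never has to allocate a partial page; with the default `page_size` of 1 (added to `ServerArgs` in this patch, settable via the new `--page-size` flag) the rounding is a no-op. A minimal sketch of the arithmetic, with illustrative numbers that are not taken from the patch:

    # Floor the token budget to a multiple of page_size, as in the hunk above.
    def floor_to_page(max_total_num_tokens: int, page_size: int) -> int:
        return max_total_num_tokens // page_size * page_size

    assert floor_to_page(130_003, 16) == 130_000  # drops the 3-token remainder
    assert floor_to_page(130_000, 1) == 130_000   # page_size=1 (default): no-op
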
@@ -783,7 +795,6 @@ class ModelRunner: # Init streams if self.server_args.speculative_algorithm == "EAGLE": self.plan_stream_for_flashinfer = torch.cuda.Stream() - self.attn_backend = FlashInferAttnBackend(self) elif self.server_args.attention_backend == "triton": assert self.sliding_window_size is None, ( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6c52709f7..bf6fe4b94 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,14 +20,13 @@ import random import tempfile from typing import List, Optional -import torch - from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( get_amdgpu_memory_capacity, get_hpu_memory_capacity, get_nvgpu_memory_capacity, + is_cuda, is_flashinfer_available, is_hip, is_port_available, @@ -71,6 +70,7 @@ class ServerArgs: schedule_policy: str = "fcfs" schedule_conservativeness: float = 1.0 cpu_offload_gb: int = 0 + page_size: int = 1 # Other runtime options tp_size: int = 1 @@ -190,10 +190,10 @@ class ServerArgs: if self.random_seed is None: self.random_seed = random.randint(0, 1 << 30) - if is_hip(): - gpu_mem = get_amdgpu_memory_capacity() - elif torch.cuda.is_available(): + if is_cuda(): gpu_mem = get_nvgpu_memory_capacity() + elif is_hip(): + gpu_mem = get_amdgpu_memory_capacity() elif self.device == "hpu": gpu_mem = get_hpu_memory_capacity() else: @@ -258,7 +258,7 @@ class ServerArgs: f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) - # Others + # Data parallelism attention if self.enable_dp_attention: self.dp_size = self.tp_size assert self.tp_size % self.dp_size == 0 @@ -507,6 +507,12 @@ class ServerArgs: default=ServerArgs.cpu_offload_gb, help="How many GBs of RAM to reserve for CPU offloading.", ) + parser.add_argument( + "--page-size", + type=int, + default=ServerArgs.page_size, + help="The number of tokens in a page.", + ) # Other runtime options parser.add_argument(
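
The core of this patch is a plain mixin refactor: the output-processing methods move verbatim into `SchedulerOutputProcessorMixin`, `Scheduler` inherits from it, and the `TYPE_CHECKING`-guarded imports in the new module avoid a circular import with `scheduler.py`. Attribute lookups such as `self.tree_cache` and `self.send_to_detokenizer` keep resolving inside the moved methods because `self` is always the concrete `Scheduler` instance. A minimal self-contained sketch of the pattern; the class and attribute names below are illustrative, not SGLang APIs:

    # A mixin may freely use attributes (self.processed) that it never
    # defines; the inheriting class is responsible for providing them.
    class OutputProcessorMixin:
        def process_result(self, token_id: int) -> None:
            self.processed.append(token_id)

    class MiniScheduler(OutputProcessorMixin):
        def __init__(self) -> None:
            self.processed: list[int] = []

    scheduler = MiniScheduler()
    scheduler.process_result(42)
    assert scheduler.processed == [42]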