From d18c6b3358185cee49db931727c9f99018da7f5d Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 12 May 2025 14:33:38 -0700 Subject: [PATCH] Support incremental streaming of logprob/token_ids between scheduler and detokenizer (#6225) Co-authored-by: SangBin Cho --- .../srt/constrained/base_grammar_backend.py | 2 +- .../srt/managers/detokenizer_manager.py | 29 ++-- python/sglang/srt/managers/schedule_batch.py | 43 +++++- python/sglang/srt/managers/scheduler.py | 23 ++-- .../scheduler_output_processor_mixin.py | 128 +++++++++++++----- .../sglang/srt/managers/tokenizer_manager.py | 100 +++++++++++--- .../srt/managers/tp_worker_overlap_thread.py | 9 +- .../srt/sampling/sampling_batch_info.py | 2 +- python/sglang/srt/server_args.py | 7 + 9 files changed, 257 insertions(+), 86 deletions(-) diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index f097d4c08..8356eed9f 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -41,7 +41,7 @@ class BaseGrammarObject: raise NotImplementedError() def is_terminated(self): - raise NotImplementedError() + return False def allocate_vocab_mask( self, vocab_size: int, batch_size: int, device diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index aeae266eb..811f108c7 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -28,6 +28,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.io_struct import ( BatchEmbeddingOut, BatchMultimodalDecodeReq, + BatchMultimodalOut, BatchStrOut, BatchTokenIDOut, ) @@ -60,6 +61,8 @@ class DecodeStatus: decode_ids: List[int] surr_offset: int read_offset: int + # Offset that's sent to tokenizer for incremental update. + sent_offset: int = 0 class DetokenizerManager: @@ -151,7 +154,7 @@ class DetokenizerManager: self.decode_status[rid] = s else: s = self.decode_status[rid] - s.decode_ids = recv_obj.decode_ids[i] + s.decode_ids.extend(recv_obj.decode_ids[i]) read_ids.append( self.trim_matched_stop( @@ -199,13 +202,15 @@ class DetokenizerManager: else: new_text = find_printable_text(new_text) - output_strs.append( - self.trim_matched_stop( - s.decoded_text + new_text, - recv_obj.finished_reasons[i], - recv_obj.no_stop_trim[i], - ) + output_str = self.trim_matched_stop( + s.decoded_text + new_text, + recv_obj.finished_reasons[i], + recv_obj.no_stop_trim[i], ) + # Incrementally send text. 
+ incremental_output = output_str[s.sent_offset :] + s.sent_offset = len(output_str) + output_strs.append(incremental_output) return BatchStrOut( rids=recv_obj.rids, @@ -232,7 +237,15 @@ class DetokenizerManager: ) def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): - raise NotImplementedError() + outputs = self.tokenizer.detokenize(recv_obj) + return BatchMultimodalOut( + rids=recv_obj.rids, + finished_reasons=recv_obj.finished_reasons, + outputs=outputs, + prompt_tokens=recv_obj.prompt_tokens, + completion_tokens=recv_obj.completion_tokens, + cached_tokens=recv_obj.cached_tokens, + ) class LimitedCapacityDict(OrderedDict): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ac4b4edcb..e9bf68b32 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMi from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator +from sglang.srt.metrics.collector import TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams @@ -436,6 +437,7 @@ class Req: self.sampling_params = sampling_params self.custom_logit_processor = custom_logit_processor self.return_hidden_states = return_hidden_states + self.lora_path = lora_path # Memory pool info self.req_pool_idx: Optional[int] = None @@ -487,6 +489,13 @@ class Req: # For retraction self.is_retracted = False + # Incremental streamining + self.send_token_offset: int = 0 + self.send_decode_id_offset: int = 0 + # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode + # because the decode server does not have the first output token logprobs + self.send_output_token_logprobs_offset: int = 0 + # Logprobs (arguments) self.return_logprob = return_logprob # Start index to compute logprob from. @@ -496,11 +505,9 @@ class Req: self.temp_scaled_logprobs = False self.top_p_normalized_logprobs = False - # Latency Breakdown - self.queue_time_start = None - self.queue_time_end = None - # Logprobs (return values) + # True means the input logprob has been already sent to detokenizer. + self.input_logprob_sent: bool = False self.input_token_logprobs_val: Optional[List[float]] = None self.input_token_logprobs_idx: Optional[List[int]] = None self.input_top_logprobs_val: Optional[List[float]] = None @@ -515,8 +522,10 @@ class Req: self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None if return_logprob: + # shape: (bs, 1) self.output_token_logprobs_val = [] self.output_token_logprobs_idx = [] + # shape: (bs, k) self.output_top_logprobs_val = [] self.output_top_logprobs_idx = [] self.output_token_ids_logprobs_val = [] @@ -543,7 +552,12 @@ class Req: # The number of verification forward passes in the speculative decoding. # This is used to compute the average acceptance length per request. 
self.spec_verify_ct = 0 - self.lora_path = lora_path + + # For metrics + self.time_stats: TimeStats = TimeStats() + self.has_log_time_stats: bool = False + self.queue_time_start = None + self.queue_time_end = None # For disaggregation self.bootstrap_host: str = bootstrap_host @@ -562,8 +576,8 @@ class Req: # This is because kv is not ready in `process_prefill_chunk`. # We use `tmp_end_idx` to store the end index of the kv cache to send. self.tmp_end_idx: int = -1 - self.metadata_buffer_index: int = -1 + # The first output_id transferred from prefill instance. self.transferred_output_id: Optional[int] = None @@ -656,6 +670,11 @@ class Req: ) return + if self.grammar is not None: + if self.grammar.is_terminated(): + self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1]) + return + last_token_id = self.output_ids[-1] if not self.sampling_params.ignore_eos: @@ -713,6 +732,18 @@ class Req: token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices) del self.kv_cache_cpu + def log_time_stats(self): + # If overlap schedule, we schedule one decode batch ahead so this gets called twice. + if self.has_log_time_stats is True: + return + + if self.bootstrap_room is not None: + prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})" + else: + prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})" + logger.info(f"{prefix}: {self.time_stats}") + self.has_log_time_stats = True + def __repr__(self): return ( f"Req(rid={self.rid}, " diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 68a9df309..3c974b94b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -530,10 +530,6 @@ class Scheduler( ) def init_metrics(self): - # The largest prefill length of a single request - self._largest_prefill_len: int = 0 - # The largest context length (prefill + generation) of a single request - self._largest_prefill_decode_len: int = 0 self.last_gen_throughput: float = 0.0 self.last_input_throughput: float = 0.0 self.step_time_dict = defaultdict(list) # Dict[batch size -> step time] @@ -1122,9 +1118,6 @@ class Scheduler( self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size() ) - self._largest_prefill_len = max( - self._largest_prefill_len, adder.log_input_tokens - ) num_new_seq = len(can_run_list) f = ( @@ -1601,14 +1594,9 @@ class Scheduler( elif batch.forward_mode.is_idle(): if self.enable_overlap: self.tp_worker.resolve_last_batch_result(launch_done) - if batch.next_batch_sampling_info: - batch.next_batch_sampling_info.update_regex_vocab_mask() - self.current_stream.synchronize() - batch.next_batch_sampling_info.sampling_info_done.set() + self.set_next_batch_sampling_info_done(batch) elif batch.forward_mode.is_dummy_first(): - batch.next_batch_sampling_info.update_regex_vocab_mask() - self.current_stream.synchronize() - batch.next_batch_sampling_info.sampling_info_done.set() + self.set_next_batch_sampling_info_done(batch) if self.return_health_check_ct: # Return some signal for the health check. 
@@ -1776,6 +1764,13 @@ class Scheduler( self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs]) self.grammar_queue = self.grammar_queue[num_ready_reqs:] + def set_next_batch_sampling_info_done(self, batch: ScheduleBatch): + if batch.next_batch_sampling_info: + if batch.next_batch_sampling_info.grammars is not None: + batch.next_batch_sampling_info.update_regex_vocab_mask() + self.current_stream.synchronize() + batch.next_batch_sampling_info.sampling_info_done.set() + def watchdog_thread(self): """A watch dog thread that will try to kill the server itself if one forward batch takes too long.""" self.watchdog_last_forward_ct = 0 diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index d3b7c6f8e..d2a450aec 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -1,8 +1,11 @@ from __future__ import annotations +import logging import threading +import time from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch @@ -15,6 +18,8 @@ if TYPE_CHECKING: Scheduler, ) +logger = logging.getLogger(__name__) + DEFAULT_FORCE_STREAM_INTERVAL = 50 @@ -83,6 +88,7 @@ class SchedulerOutputProcessorMixin: if req.finished(): self.tree_cache.cache_finished_req(req) + req.time_stats.completion_time = time.time() elif not batch.decoding_reqs or req not in batch.decoding_reqs: # This updates radix so others can match self.tree_cache.cache_unfinished_req(req) @@ -149,10 +155,7 @@ class SchedulerOutputProcessorMixin: ) logprob_pt += num_input_logprobs - if batch.next_batch_sampling_info: - batch.next_batch_sampling_info.update_regex_vocab_mask() - self.current_stream.synchronize() - batch.next_batch_sampling_info.sampling_info_done.set() + self.set_next_batch_sampling_info_done(batch) else: # embedding or reward model embeddings, bid = result.embeddings, result.bid @@ -233,6 +236,7 @@ class SchedulerOutputProcessorMixin: req.check_finished() if req.finished(): self.tree_cache.cache_finished_req(req) + req.time_stats.completion_time = time.time() if req.return_logprob and batch.spec_algorithm.is_none(): # speculative worker handles logprob in speculative decoding @@ -262,13 +266,8 @@ class SchedulerOutputProcessorMixin: req.grammar.accept_token(next_token_id) req.grammar.finished = req.finished() - if batch.next_batch_sampling_info: - batch.next_batch_sampling_info.update_regex_vocab_mask() - self.current_stream.synchronize() - batch.next_batch_sampling_info.sampling_info_done.set() - + self.set_next_batch_sampling_info_done(batch) self.stream_output(batch.reqs, batch.return_logprob) - self.token_to_kv_pool_allocator.free_group_end() self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30) @@ -530,16 +529,27 @@ class SchedulerOutputProcessorMixin: ) if should_output: + send_token_offset = req.send_token_offset + send_output_token_logprobs_offset = ( + req.send_output_token_logprobs_offset + ) rids.append(req.rid) finished_reasons.append( req.finished_reason.to_json() if req.finished_reason else None ) decoded_texts.append(req.decoded_text) decode_ids, read_offset = req.init_incremental_detokenize() - 
decode_ids_list.append(decode_ids) + + if self.model_config.is_multimodal_gen: + decode_ids_list.append(decode_ids) + else: + decode_ids_list.append(decode_ids[req.send_decode_id_offset :]) + + req.send_decode_id_offset = len(decode_ids) read_offsets.append(read_offset) if self.skip_tokenizer_init: - output_ids.append(req.output_ids) + output_ids.append(req.output_ids[send_token_offset:]) + req.send_token_offset = len(req.output_ids) skip_special_tokens.append(req.sampling_params.skip_special_tokens) spaces_between_special_tokens.append( req.sampling_params.spaces_between_special_tokens @@ -553,36 +563,90 @@ class SchedulerOutputProcessorMixin: spec_verify_ct.append(req.spec_verify_ct) if return_logprob: - input_token_logprobs_val.append(req.input_token_logprobs_val) - input_token_logprobs_idx.append(req.input_token_logprobs_idx) - output_token_logprobs_val.append(req.output_token_logprobs_val) - output_token_logprobs_idx.append(req.output_token_logprobs_idx) - input_top_logprobs_val.append(req.input_top_logprobs_val) - input_top_logprobs_idx.append(req.input_top_logprobs_idx) - output_top_logprobs_val.append(req.output_top_logprobs_val) - output_top_logprobs_idx.append(req.output_top_logprobs_idx) - input_token_ids_logprobs_val.append( - req.input_token_ids_logprobs_val - ) - input_token_ids_logprobs_idx.append( - req.input_token_ids_logprobs_idx - ) - output_token_ids_logprobs_val.append( - req.output_token_ids_logprobs_val - ) - output_token_ids_logprobs_idx.append( - req.output_token_ids_logprobs_idx - ) + if ( + req.return_logprob + and not req.input_logprob_sent + # Decode server does not send input logprobs + and self.disaggregation_mode != DisaggregationMode.DECODE + ): + input_token_logprobs_val.append(req.input_token_logprobs_val) + input_token_logprobs_idx.append(req.input_token_logprobs_idx) + input_top_logprobs_val.append(req.input_top_logprobs_val) + input_top_logprobs_idx.append(req.input_top_logprobs_idx) + input_token_ids_logprobs_val.append( + req.input_token_ids_logprobs_val + ) + input_token_ids_logprobs_idx.append( + req.input_token_ids_logprobs_idx + ) + req.input_logprob_sent = True + else: + input_token_logprobs_val.append([]) + input_token_logprobs_idx.append([]) + input_top_logprobs_val.append([]) + input_top_logprobs_idx.append([]) + input_token_ids_logprobs_val.append([]) + input_token_ids_logprobs_idx.append([]) + + if req.return_logprob: + output_token_logprobs_val.append( + req.output_token_logprobs_val[ + send_output_token_logprobs_offset: + ] + ) + output_token_logprobs_idx.append( + req.output_token_logprobs_idx[ + send_output_token_logprobs_offset: + ] + ) + output_top_logprobs_val.append( + req.output_top_logprobs_val[ + send_output_token_logprobs_offset: + ] + ) + output_top_logprobs_idx.append( + req.output_top_logprobs_idx[ + send_output_token_logprobs_offset: + ] + ) + output_token_ids_logprobs_val.append( + req.output_token_ids_logprobs_val[ + send_output_token_logprobs_offset: + ] + ) + output_token_ids_logprobs_idx.append( + req.output_token_ids_logprobs_idx[ + send_output_token_logprobs_offset: + ] + ) + req.send_output_token_logprobs_offset = len( + req.output_token_logprobs_val + ) + else: + output_token_logprobs_val.append([]) + output_token_logprobs_idx.append([]) + output_top_logprobs_val.append([]) + output_top_logprobs_idx.append([]) + output_token_ids_logprobs_val.append([]) + output_token_ids_logprobs_idx.append([]) if req.return_hidden_states: if output_hidden_states is None: output_hidden_states = [] 
output_hidden_states.append(req.hidden_states) + if ( + req.finished() + and self.tp_rank == 0 + and self.server_args.enable_request_time_stats_logging + ): + req.log_time_stats() + # Send to detokenizer if rids: if self.model_config.is_multimodal_gen: return + self.send_to_detokenizer.send_pyobj( BatchTokenIDOut( rids, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index db64dd0a2..b646fae1c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -125,10 +125,10 @@ logger = logging.getLogger(__name__) class ReqState: """Store the state a request.""" - out_list: List + out_list: List[Dict[Any, Any]] finished: bool event: asyncio.Event - obj: Any + obj: Union[GenerateReqInput, EmbeddingReqInput] # For metrics created_time: float @@ -139,6 +139,21 @@ class ReqState: # For streaming output last_output_offset: int = 0 + # For incremental state update. + text: str = "" + output_ids: List[int] = dataclasses.field(default_factory=list) + input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) + input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list) + output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) + output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list) + input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list) + input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) + output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list) + output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) + input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list) + input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) + output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list) + output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) class TokenizerManager: @@ -1065,9 +1080,11 @@ class TokenizerManager: if getattr(state.obj, "return_logprob", False): self.convert_logprob_style( meta_info, + state, state.obj.top_logprobs_num, state.obj.token_ids_logprob, - state.obj.return_text_in_logprobs, + state.obj.return_text_in_logprobs + and not self.server_args.skip_tokenizer_init, recv_obj, i, ) @@ -1084,18 +1101,19 @@ class TokenizerManager: meta_info["hidden_states"] = recv_obj.output_hidden_states[i] if isinstance(recv_obj, BatchStrOut): + state.text += recv_obj.output_strs[i] out_dict = { - "text": recv_obj.output_strs[i], + "text": state.text, "meta_info": meta_info, } elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: - output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] - state.last_output_offset = len(recv_obj.output_ids[i]) + state.output_ids.extend(recv_obj.output_ids[i]) + output_token_ids = state.output_ids[state.last_output_offset :] + state.last_output_offset = len(state.output_ids) else: - output_token_ids = recv_obj.output_ids[i] + state.output_ids.extend(recv_obj.output_ids[i]) + output_token_ids = state.output_ids out_dict = { "output_ids": output_token_ids, @@ -1130,45 +1148,85 @@ class TokenizerManager: def convert_logprob_style( self, meta_info: dict, + state: ReqState, top_logprobs_num: int, token_ids_logprob: List[int], return_text_in_logprobs: bool, recv_obj: BatchStrOut, recv_obj_index: int, ): + if 
len(recv_obj.input_token_logprobs_val) > 0: + state.input_token_logprobs_val.extend( + recv_obj.input_token_logprobs_val[recv_obj_index] + ) + state.input_token_logprobs_idx.extend( + recv_obj.input_token_logprobs_idx[recv_obj_index] + ) + state.output_token_logprobs_val.extend( + recv_obj.output_token_logprobs_val[recv_obj_index] + ) + state.output_token_logprobs_idx.extend( + recv_obj.output_token_logprobs_idx[recv_obj_index] + ) meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens( - recv_obj.input_token_logprobs_val[recv_obj_index], - recv_obj.input_token_logprobs_idx[recv_obj_index], + state.input_token_logprobs_val, + state.input_token_logprobs_idx, return_text_in_logprobs, ) meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens( - recv_obj.output_token_logprobs_val[recv_obj_index], - recv_obj.output_token_logprobs_idx[recv_obj_index], + state.output_token_logprobs_val, + state.output_token_logprobs_idx, return_text_in_logprobs, ) if top_logprobs_num > 0: + if len(recv_obj.input_top_logprobs_val) > 0: + state.input_top_logprobs_val.extend( + recv_obj.input_top_logprobs_val[recv_obj_index] + ) + state.input_top_logprobs_idx.extend( + recv_obj.input_top_logprobs_idx[recv_obj_index] + ) + state.output_top_logprobs_val.extend( + recv_obj.output_top_logprobs_val[recv_obj_index] + ) + state.output_top_logprobs_idx.extend( + recv_obj.output_top_logprobs_idx[recv_obj_index] + ) meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( - recv_obj.input_top_logprobs_val[recv_obj_index], - recv_obj.input_top_logprobs_idx[recv_obj_index], + state.input_top_logprobs_val, + state.input_top_logprobs_idx, return_text_in_logprobs, ) meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens( - recv_obj.output_top_logprobs_val[recv_obj_index], - recv_obj.output_top_logprobs_idx[recv_obj_index], + state.output_top_logprobs_val, + state.output_top_logprobs_idx, return_text_in_logprobs, ) if token_ids_logprob is not None: + if len(recv_obj.input_token_ids_logprobs_val) > 0: + state.input_token_ids_logprobs_val.extend( + recv_obj.input_token_ids_logprobs_val[recv_obj_index] + ) + state.input_token_ids_logprobs_idx.extend( + recv_obj.input_token_ids_logprobs_idx[recv_obj_index] + ) + state.output_token_ids_logprobs_val.extend( + recv_obj.output_token_ids_logprobs_val[recv_obj_index] + ) + state.output_token_ids_logprobs_idx.extend( + recv_obj.output_token_ids_logprobs_idx[recv_obj_index] + ) meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens( - recv_obj.input_token_ids_logprobs_val[recv_obj_index], - recv_obj.input_token_ids_logprobs_idx[recv_obj_index], + state.input_token_ids_logprobs_val, + state.input_token_ids_logprobs_idx, return_text_in_logprobs, ) meta_info["output_token_ids_logprobs"] = ( self.detokenize_top_logprobs_tokens( - recv_obj.output_token_ids_logprobs_val[recv_obj_index], - recv_obj.output_token_ids_logprobs_idx[recv_obj_index], + state.output_token_ids_logprobs_val, + state.output_token_ids_logprobs_idx, return_text_in_logprobs, ) ) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 4c6cd576f..783d864ea 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -127,10 +127,12 @@ class TpModelWorkerClient: batch_lists = [None] * 2 while True: - model_worker_batch, future_token_ids_ct = self.input_queue.get() + model_worker_batch, future_token_ids_ct, 
sync_event = self.input_queue.get() if not model_worker_batch: break + sync_event.wait() + # Keep a reference of model_worker_batch by storing it into a list. # Otherwise, the tensor members of model_worker_batch will be released # by pytorch and cause CUDA illegal memory access errors. @@ -214,10 +216,11 @@ class TpModelWorkerClient: ) # A cuda stream sync here to avoid the cuda illegal memory access error. - self.scheduler_stream.synchronize() + sync_event = torch.get_device_module(self.device).Event() + sync_event.record(self.scheduler_stream) # Push a new batch to the queue - self.input_queue.put((model_worker_batch, self.future_token_ids_ct)) + self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event)) # Allocate output future objects bs = len(model_worker_batch.seq_lens) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 6499d4c6b..7f169ef04 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -307,5 +307,5 @@ class SamplingBatchInfo: other_val = getattr(other, item, None) setattr(self, item, torch.cat([self_val, other_val])) - self.is_all_greedy |= other.is_all_greedy + self.is_all_greedy &= other.is_all_greedy self.need_min_p_sampling |= other.need_min_p_sampling diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6051d2409..5787ddfd2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -98,6 +98,7 @@ class ServerArgs: show_time_cost: bool = False enable_metrics: bool = False decode_log_interval: int = 40 + enable_request_time_stats_logging: bool = False # API related api_key: Optional[str] = None @@ -785,6 +786,12 @@ class ServerArgs: default=ServerArgs.decode_log_interval, help="The log interval of decode batch.", ) + parser.add_argument( + "--enable-request-time-stats-logging", + action="store_true", + default=ServerArgs.enable_request_time_stats_logging, + help="Enable per request time stats logging", + ) # API related parser.add_argument(
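
Note (illustrative only, not part of the patch): the recurring mechanism this diff introduces in several places (Req.send_token_offset and send_output_token_logprobs_offset in schedule_batch.py, DecodeStatus.sent_offset in detokenizer_manager.py, and the accumulating ReqState fields in tokenizer_manager.py) is the same offset-based delta send: the producer remembers how much of a growing list it has already shipped and sends only the new slice, while the consumer extends its accumulated copy. The sketch below shows that pattern in isolation with hypothetical names; it is not code from the patch.

```python
# Illustrative sketch of the offset-based incremental streaming pattern
# used by this patch. All names here are hypothetical.
from dataclasses import dataclass, field
from typing import List


@dataclass
class SenderState:
    """Producer side: remembers how much of the growing list was already sent."""

    output_ids: List[int] = field(default_factory=list)
    sent_offset: int = 0

    def take_increment(self) -> List[int]:
        # Ship only the tokens produced since the last flush, then advance the offset.
        increment = self.output_ids[self.sent_offset:]
        self.sent_offset = len(self.output_ids)
        return increment


@dataclass
class ReceiverState:
    """Consumer side: accumulates increments back into the full sequence."""

    output_ids: List[int] = field(default_factory=list)

    def merge(self, increment: List[int]) -> List[int]:
        self.output_ids.extend(increment)
        return self.output_ids


if __name__ == "__main__":
    sender, receiver = SenderState(), ReceiverState()

    # Step 1: three tokens generated, all three are shipped.
    sender.output_ids.extend([11, 12, 13])
    first = sender.take_increment()
    assert first == [11, 12, 13]

    # Step 2: one more token generated, only the delta is shipped.
    sender.output_ids.append(14)
    delta = sender.take_increment()
    assert delta == [14]

    # The receiver reconstructs the full sequence from the increments.
    receiver.merge(first)
    assert receiver.merge(delta) == [11, 12, 13, 14]
```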