diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 274c4c311..915cb47d2 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -39,10 +39,12 @@ class LogitsProcessorOutput: # The logprobs of input tokens. shape: [#token, vocab_size] input_token_logprobs: torch.Tensor = None - # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - input_top_logprobs: List = None - # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - output_top_logprobs: List = None + # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] + input_top_logprobs_val: List = None + input_top_logprobs_idx: List = None + # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] + output_top_logprobs_val: List = None + output_top_logprobs_idx: List = None @dataclasses.dataclass @@ -125,12 +127,15 @@ class LogitsProcessor(nn.Module): indices = ret.indices.tolist() if logits_metadata.forward_mode.is_decode(): - output_top_logprobs = [] + output_top_logprobs_val = [] + output_top_logprobs_idx = [] for i, k in enumerate(logits_metadata.top_logprobs_nums): - output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k]))) - return None, output_top_logprobs + output_top_logprobs_val.append(values[i][:k]) + output_top_logprobs_idx.append(indices[i][:k]) + return None, None, output_top_logprobs_val, output_top_logprobs_idx else: - input_top_logprobs, output_top_logprobs = [], [] + input_top_logprobs_val, input_top_logprobs_idx = [], [] + output_top_logprobs_val, output_top_logprobs_idx = [], [] pt = 0 for k, pruned_len in zip( @@ -138,27 +143,36 @@ class LogitsProcessor(nn.Module): logits_metadata.extend_logprob_pruned_lens_cpu, ): if pruned_len <= 0: - input_top_logprobs.append([]) - output_top_logprobs.append([]) + input_top_logprobs_val.append([]) + input_top_logprobs_idx.append([]) + output_top_logprobs_val.append([]) + output_top_logprobs_idx.append([]) continue - input_top_logprobs.append( - [ - list(zip(values[pt + j][:k], indices[pt + j][:k])) - for j in range(pruned_len - 1) - ] + input_top_logprobs_val.append( + [values[pt + j][:k] for j in range(pruned_len - 1)] ) - output_top_logprobs.append( + input_top_logprobs_idx.append( + [indices[pt + j][:k] for j in range(pruned_len - 1)] + ) + output_top_logprobs_val.append( list( - zip( - values[pt + pruned_len - 1][:k], - indices[pt + pruned_len - 1][:k], - ) + values[pt + pruned_len - 1][:k], + ) + ) + output_top_logprobs_idx.append( + list( + indices[pt + pruned_len - 1][:k], ) ) pt += pruned_len - return input_top_logprobs, output_top_logprobs + return ( + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + ) def forward( self, @@ -193,29 +207,22 @@ class LogitsProcessor(nn.Module): if not logits_metadata.return_logprob: return LogitsProcessorOutput( next_token_logits=last_logits, - next_token_logprobs=None, - normalized_prompt_logprobs=None, - input_token_logprobs=None, - input_top_logprobs=None, - output_top_logprobs=None, ) else: last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) if logits_metadata.forward_mode.is_decode(): if logits_metadata.return_top_logprob: - output_top_logprobs = self.get_top_logprobs( - last_logprobs, logits_metadata - )[1] + output_top_logprobs_val, output_top_logprobs_idx = ( + 
self.get_top_logprobs(last_logprobs, logits_metadata)[2:4] + ) else: - output_top_logprobs = None + output_top_logprobs_val = output_top_logprobs_idx = None return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, - normalized_prompt_logprobs=None, - input_token_logprobs=None, - input_top_logprobs=None, - output_top_logprobs=output_top_logprobs, + output_top_logprobs_val=output_top_logprobs_val, + output_top_logprobs_idx=output_top_logprobs_idx, ) else: # Slice the requested tokens to compute logprob @@ -246,11 +253,16 @@ class LogitsProcessor(nn.Module): # Get the logprob of top-k tokens if logits_metadata.return_top_logprob: - input_top_logprobs, output_top_logprobs = self.get_top_logprobs( - all_logprobs, logits_metadata - ) + ( + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + ) = self.get_top_logprobs(all_logprobs, logits_metadata) else: - input_top_logprobs = output_top_logprobs = None + input_top_logprobs_val = input_top_logprobs_idx = ( + output_top_logprobs_val + ) = output_top_logprobs_idx = None # Compute the normalized logprobs for the requested tokens. # Note that we pad a zero at the end for easy batching. @@ -273,8 +285,10 @@ class LogitsProcessor(nn.Module): next_token_logprobs=last_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs, input_token_logprobs=input_token_logprobs, - input_top_logprobs=input_top_logprobs, - output_top_logprobs=output_top_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, + output_top_logprobs_val=output_top_logprobs_val, + output_top_logprobs_idx=output_top_logprobs_idx, ) def _get_logits( diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 120e990da..bc9e4a53b 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -17,7 +17,7 @@ import dataclasses import logging import signal from collections import OrderedDict -from typing import List, Union +from typing import Dict, List, Union import psutil import setproctitle @@ -76,17 +76,25 @@ class DetokenizerManager: self.decode_status = LimitedCapacityDict() - def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim): - if no_stop_trim: + def trim_matched_stop( + self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool + ): + if no_stop_trim or not finished_reason: return output - # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit - if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str): - pos = output.find(finished_reason.matched) + matched = finished_reason.get("matched", None) + if not matched: + return output + + # TODO(lmzheng): handle the case where multiple stop strs are hit + + # Trim stop str. + if isinstance(matched, str) and isinstance(output, str): + pos = output.find(matched) return output[:pos] if pos != -1 else output - if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance( - output, list - ): + + # Trim stop token. 
+ if isinstance(matched, int) and isinstance(output, list): assert len(output) > 0 return output[:-1] return output @@ -125,9 +133,9 @@ class DetokenizerManager: s.decode_ids = recv_obj.decode_ids[i] read_ids.append( - self.trim_eos( + self.trim_matched_stop( s.decode_ids[s.surr_offset :], - recv_obj.finished_reason[i], + recv_obj.finished_reasons[i], recv_obj.no_stop_trim[i], ) ) @@ -150,7 +158,7 @@ class DetokenizerManager: for i in range(bs): s = self.decode_status[recv_obj.rids[i]] new_text = read_texts[i][len(surr_texts[i]) :] - if recv_obj.finished_reason[i] is None: + if recv_obj.finished_reasons[i] is None: # Streaming chunk: update the decode status if len(new_text) > 0 and not new_text.endswith("�"): s.decoded_text = s.decoded_text + new_text @@ -161,9 +169,9 @@ class DetokenizerManager: new_text = find_printable_text(new_text) output_strs.append( - self.trim_eos( + self.trim_matched_stop( s.decoded_text + new_text, - recv_obj.finished_reason[i], + recv_obj.finished_reasons[i], recv_obj.no_stop_trim[i], ) ) @@ -171,9 +179,20 @@ class DetokenizerManager: self.send_to_tokenizer.send_pyobj( BatchStrOut( rids=recv_obj.rids, + finished_reasons=recv_obj.finished_reasons, output_strs=output_strs, - meta_info=recv_obj.meta_info, - finished_reason=recv_obj.finished_reason, + prompt_tokens=recv_obj.prompt_tokens, + completion_tokens=recv_obj.completion_tokens, + cached_tokens=recv_obj.cached_tokens, + input_token_logprobs_val=recv_obj.input_token_logprobs_val, + input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, + output_token_logprobs_val=recv_obj.output_token_logprobs_val, + output_token_logprobs_idx=recv_obj.output_token_logprobs_idx, + input_top_logprobs_val=recv_obj.input_top_logprobs_val, + input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, + output_top_logprobs_val=recv_obj.output_top_logprobs_val, + output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, + normalized_prompt_logprob=recv_obj.normalized_prompt_logprob, ) ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 27bf5a4bd..c5884b5f0 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput: class BatchTokenIDOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[BaseFinishReason] + # For incremental decoding # The version id to sync decode status with in detokenizer_manager vids: List[int] decoded_texts: List[str] @@ -315,35 +318,61 @@ class BatchTokenIDOut: read_offsets: List[int] # Only used when `--skip-tokenizer-init` output_ids: Optional[List[int]] + # Detokenization configs skip_special_tokens: List[bool] spaces_between_special_tokens: List[bool] - meta_info: List[Dict] - finished_reason: List[BaseFinishReason] no_stop_trim: List[bool] + # Token counts + prompt_tokens: List[int] + completion_tokens: List[int] + cached_tokens: List[int] + # Logprobs + input_token_logprobs_val: List[float] + input_token_logprobs_idx: List[int] + output_token_logprobs_val: List[float] + output_token_logprobs_idx: List[int] + input_top_logprobs_val: List[List] + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] + normalized_prompt_logprob: List[float] @dataclass class BatchStrOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[dict] # The output decoded strings output_strs: List[str] - # The meta info - meta_info: List[Dict] - # The finish reason - 
finished_reason: List[BaseFinishReason] + + # Token counts + prompt_tokens: List[int] + completion_tokens: List[int] + cached_tokens: List[int] + # Logprobs + input_token_logprobs_val: List[float] + input_token_logprobs_idx: List[int] + output_token_logprobs_val: List[float] + output_token_logprobs_idx: List[int] + input_top_logprobs_val: List[List] + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] + normalized_prompt_logprob: List[float] @dataclass class BatchEmbeddingOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[BaseFinishReason] # The output embedding embeddings: List[List[float]] - # The meta info - meta_info: List[Dict] - # The finish reason - finished_reason: List[BaseFinishReason] + # Token counts + prompt_tokens: List[int] @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5855d4248..bb9eb1816 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -200,6 +200,9 @@ class Req: origin_input_text: str, origin_input_ids: Tuple[int], sampling_params: SamplingParams, + return_logprob: bool = False, + top_logprobs_num: int = 0, + stream: bool = False, origin_input_ids_unpadded: Optional[Tuple[int]] = None, lora_path: Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, @@ -217,10 +220,11 @@ class Req: self.output_ids = [] # Each decode stage's output ids self.fill_ids = None # fill_ids = origin_input_ids + output_ids self.session_id = session_id + self.input_embeds = input_embeds + # Sampling info self.sampling_params = sampling_params self.lora_path = lora_path - self.input_embeds = input_embeds # Memory pool info self.req_pool_idx = None @@ -228,8 +232,8 @@ class Req: # Check finish self.tokenizer = None self.finished_reason = None - self.stream = False self.to_abort = False + self.stream = stream # For incremental decoding # ----- | --------- read_ids -------| @@ -241,13 +245,9 @@ class Req: # 2: read_offset # 3: last token self.vid = 0 # version id to sync decode status with in detokenizer_manager - self.decoded_text = "" self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm self.read_offset = None - - # The number of decoded tokens for token usage report. Note that - # this does not include the jump forward tokens. 
- self.completion_tokens_wo_jump_forward = 0 + self.decoded_text = "" # For multimodal inputs self.image_inputs: Optional[ImageInputs] = None @@ -256,22 +256,34 @@ class Req: self.prefix_indices = [] self.extend_input_len = 0 self.last_node = None + + # Chunked prefill self.is_being_chunked = 0 # For retraction self.is_retracted = False # Logprobs (arguments) - self.return_logprob = False + self.return_logprob = return_logprob self.logprob_start_len = 0 - self.top_logprobs_num = 0 + self.top_logprobs_num = top_logprobs_num # Logprobs (return value) self.normalized_prompt_logprob = None - self.input_token_logprobs = None - self.input_top_logprobs = None - self.output_token_logprobs = [] - self.output_top_logprobs = [] + self.input_token_logprobs_val = None + self.input_token_logprobs_idx = None + self.input_top_logprobs_val = None + self.input_top_logprobs_idx = None + + if return_logprob: + self.output_token_logprobs_val = [] + self.output_token_logprobs_idx = [] + self.output_top_logprobs_val = [] + self.output_top_logprobs_idx = [] + else: + self.output_token_logprobs_val = self.output_token_logprobs_idx = ( + self.output_top_logprobs_val + ) = self.output_top_logprobs_idx = None # Logprobs (internal values) # The tokens is prefilled but need to be considered as decode tokens @@ -295,8 +307,8 @@ class Req: else: self.image_inputs.merge(image_inputs) - # whether request reached finished condition def finished(self) -> bool: + # Whether request reached finished condition return self.finished_reason is not None def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): @@ -454,8 +466,10 @@ class Req: k = k + 1 else: break - self.output_token_logprobs = self.output_token_logprobs[:k] - self.output_top_logprobs = self.output_top_logprobs[:k] + self.output_token_logprobs_val = self.output_token_logprobs_val[:k] + self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k] + self.output_top_logprobs_val = self.output_top_logprobs_val[:k] + self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k] self.logprob_start_len = prompt_tokens + k self.last_update_decode_tokens = len(self.output_ids) - k diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a3316503b..4ece87868 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -515,6 +515,9 @@ class Scheduler: recv_req.input_text, recv_req.input_ids, recv_req.sampling_params, + return_logprob=recv_req.return_logprob, + top_logprobs_num=recv_req.top_logprobs_num, + stream=recv_req.stream, lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, ) @@ -558,9 +561,6 @@ class Scheduler: return # Copy more attributes - req.return_logprob = recv_req.return_logprob - req.top_logprobs_num = recv_req.top_logprobs_num - req.stream = recv_req.stream req.logprob_start_len = recv_req.logprob_start_len if req.logprob_start_len == -1: @@ -982,7 +982,6 @@ class Scheduler: continue if req.is_being_chunked <= 0: - req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) req.check_finished() @@ -1035,7 +1034,7 @@ class Scheduler: # being chunked reqs' prefill is not finished req.is_being_chunked -= 1 - self.stream_output(batch.reqs, skip_stream_req) + self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) def process_batch_result_decode(self, batch: ScheduleBatch, result): logits_output, next_token_ids, bid = result @@ -1065,7 +1064,6 @@ class Scheduler: 
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) continue - req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) req.check_finished() @@ -1073,11 +1071,15 @@ class Scheduler: self.tree_cache.cache_finished_req(req) if req.return_logprob: - req.output_token_logprobs.append( - (next_token_logprobs[i], next_token_id) - ) + req.output_token_logprobs_val.append(next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_id) if req.top_logprobs_num > 0: - req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) + req.output_top_logprobs_val.append( + logits_output.output_top_logprobs_val[i] + ) + req.output_top_logprobs_idx.append( + logits_output.output_top_logprobs_idx[i] + ) if req.grammar is not None: req.grammar.accept_token(next_token_id) @@ -1088,7 +1090,7 @@ class Scheduler: self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() - self.stream_output(batch.reqs) + self.stream_output(batch.reqs, batch.return_logprob) self.token_to_kv_pool.free_group_end() @@ -1108,9 +1110,8 @@ class Scheduler: output: LogitsProcessorOutput, ): """Attach logprobs to the return values.""" - req.output_token_logprobs.append( - (output.next_token_logprobs[i], next_token_ids[i]) - ) + req.output_token_logprobs_val.append(output.next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_ids[i]) # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len @@ -1118,173 +1119,195 @@ class Scheduler: if req.normalized_prompt_logprob is None: req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] - if req.input_token_logprobs is None: - input_token_logprobs = output.input_token_logprobs[ + if req.input_token_logprobs_val is None: + input_token_logprobs_val = output.input_token_logprobs[ pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens ] - input_token_ids = req.fill_ids[ + + input_token_logprobs_idx = req.fill_ids[ len(req.fill_ids) - num_input_logprobs + 1 : len(req.fill_ids) - req.last_update_decode_tokens ] - # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. - input_token_ids = [ + input_token_logprobs_idx = [ x if x < self.model_config.vocab_size - 1 else 0 - for x in input_token_ids + for x in input_token_logprobs_idx ] - req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids)) - if ( req.logprob_start_len == 0 ): # The first token does not have logprob, pad it. 
- req.input_token_logprobs = [ - (None, req.fill_ids[0]) - ] + req.input_token_logprobs + input_token_logprobs_val = [None] + input_token_logprobs_val + input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx + + req.input_token_logprobs_val = input_token_logprobs_val + req.input_token_logprobs_idx = input_token_logprobs_idx if req.last_update_decode_tokens != 0: # Some decode tokens are re-computed in an extend batch - req.output_token_logprobs.extend( - list( - zip( - output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens : pt - + num_input_logprobs - - 1 - ], - req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens : len(req.fill_ids) - ], - ) - ) + req.output_token_logprobs_val.extend( + output.input_token_logprobs[ + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens : pt + + num_input_logprobs + - 1 + ], + ) + req.output_token_logprobs_idx.extend( + req.fill_ids[ + len(req.fill_ids) + - req.last_update_decode_tokens : len(req.fill_ids) + ] ) if req.top_logprobs_num > 0: - if req.input_top_logprobs is None: - req.input_top_logprobs = output.input_top_logprobs[i] + if req.input_top_logprobs_val is None: + req.input_top_logprobs_val = output.input_top_logprobs_val[i] + req.input_top_logprobs_idx = output.input_top_logprobs_idx[i] if req.logprob_start_len == 0: - req.input_top_logprobs = [None] + req.input_top_logprobs + req.input_top_logprobs_val = [None] + req.input_top_logprobs_val + req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx if req.last_update_decode_tokens != 0: - req.output_top_logprobs.extend( - output.input_top_logprobs[i][-req.last_update_decode_tokens :] + req.output_top_logprobs_val.extend( + output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] ) - req.output_top_logprobs.append(output.output_top_logprobs[i]) + req.output_top_logprobs_idx.extend( + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] + ) + req.output_top_logprobs_val.append(output.output_top_logprobs_val[i]) + req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i]) return num_input_logprobs - def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None): + def stream_output( + self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None + ): """Stream the output to detokenizer.""" - output_rids = [] - output_meta_info: List[dict] = [] - output_finished_reason: List[BaseFinishReason] = [] + rids = [] + finished_reasons: List[BaseFinishReason] = [] + if self.is_generation: - output_vids = [] + vids = [] decoded_texts = [] - output_read_ids = [] - output_read_offsets = [] + decode_ids_list = [] + read_offsets = [] output_ids = [] - output_skip_special_tokens = [] - output_spaces_between_special_tokens = [] - output_no_stop_trim = [] - else: # embedding or reward model - output_embeddings = [] + skip_special_tokens = [] + spaces_between_special_tokens = [] + no_stop_trim = [] + prompt_tokens = [] + completion_tokens = [] + cached_tokens = [] - is_stream_iter = self.forward_ct_decode % self.stream_interval == 0 + if return_logprob: + input_token_logprobs_val = [] + input_token_logprobs_idx = [] + output_token_logprobs_val = [] + output_token_logprobs_idx = [] + input_top_logprobs_val = [] + input_top_logprobs_idx = [] + output_top_logprobs_val = [] + output_top_logprobs_idx = [] + normalized_prompt_logprob = [] + else: + input_token_logprobs_val = input_token_logprobs_idx = ( + output_token_logprobs_val + ) = output_token_logprobs_idx = 
input_top_logprobs_val = ( + input_top_logprobs_idx + ) = output_top_logprobs_val = output_top_logprobs_idx = ( + normalized_prompt_logprob + ) = None - for req in reqs: - if req is skip_req: - continue + for req in reqs: + if req is skip_req: + continue - # TODO(lianmin): revisit this for overlap + retract + stream - if req.finished() or ( - req.stream and (is_stream_iter or len(req.output_ids) == 1) - ): - output_rids.append(req.rid) - output_finished_reason.append(req.finished_reason) - if self.is_generation: - output_vids.append(req.vid) + # TODO(lianmin): revisit this for overlap + retract + stream + if ( + req.finished() + # If stream, follow the given stream_interval + or (req.stream and len(req.output_ids) % self.stream_interval == 0) + # If not stream, we still want to output some tokens to get the benefit of incremental decoding. + or (not req.stream and len(req.output_ids) % 50 == 0) + ): + rids.append(req.rid) + finished_reasons.append( + req.finished_reason.to_json() if req.finished_reason else None + ) + vids.append(req.vid) decoded_texts.append(req.decoded_text) - read_ids, read_offset = req.init_incremental_detokenize() - output_read_ids.append(read_ids) - output_read_offsets.append(read_offset) + decode_ids, read_offset = req.init_incremental_detokenize() + decode_ids_list.append(decode_ids) + read_offsets.append(read_offset) if self.skip_tokenizer_init: output_ids.append(req.output_ids) - output_skip_special_tokens.append( - req.sampling_params.skip_special_tokens - ) - output_spaces_between_special_tokens.append( + skip_special_tokens.append(req.sampling_params.skip_special_tokens) + spaces_between_special_tokens.append( req.sampling_params.spaces_between_special_tokens ) - output_no_stop_trim.append(req.sampling_params.no_stop_trim) + no_stop_trim.append(req.sampling_params.no_stop_trim) - meta_info = { - "prompt_tokens": len(req.origin_input_ids), - "completion_tokens": len(req.output_ids), - "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, - "cached_tokens": req.cached_tokens, - "finish_reason": ( - req.finished_reason.to_json() - if req.finished_reason is not None - else None - ), - } - if req.return_logprob: - ( - meta_info["input_token_logprobs"], - meta_info["output_token_logprobs"], - meta_info["input_top_logprobs"], - meta_info["output_top_logprobs"], - meta_info["normalized_prompt_logprob"], - ) = ( - req.input_token_logprobs, - req.output_token_logprobs, - req.input_top_logprobs, - req.output_top_logprobs, - req.normalized_prompt_logprob, - ) - output_meta_info.append(meta_info) - else: # embedding or reward model - output_embeddings.append(req.embedding) - meta_info = { - "prompt_tokens": len(req.origin_input_ids), - } - output_meta_info.append(meta_info) + prompt_tokens.append(len(req.origin_input_ids)) + completion_tokens.append(len(req.output_ids)) + cached_tokens.append(req.cached_tokens) - # Send to detokenizer - if output_rids: - if self.is_generation: + if return_logprob: + input_token_logprobs_val.append(req.input_token_logprobs_val) + input_token_logprobs_idx.append(req.input_token_logprobs_idx) + output_token_logprobs_val.append(req.output_token_logprobs_val) + output_token_logprobs_idx.append(req.output_token_logprobs_idx) + input_top_logprobs_val.append(req.input_top_logprobs_val) + input_top_logprobs_idx.append(req.input_top_logprobs_idx) + output_top_logprobs_val.append(req.output_top_logprobs_val) + output_top_logprobs_idx.append(req.output_top_logprobs_idx) + 
normalized_prompt_logprob.append(req.normalized_prompt_logprob) + + # Send to detokenizer + if rids: self.send_to_detokenizer.send_pyobj( BatchTokenIDOut( - output_rids, - output_vids, + rids, + finished_reasons, + vids, decoded_texts, - output_read_ids, - output_read_offsets, + decode_ids_list, + read_offsets, output_ids, - output_skip_special_tokens, - output_spaces_between_special_tokens, - output_meta_info, - output_finished_reason, - output_no_stop_trim, - ) - ) - else: # embedding or reward model - self.send_to_detokenizer.send_pyobj( - BatchEmbeddingOut( - output_rids, - output_embeddings, - output_meta_info, - output_finished_reason, + skip_special_tokens, + spaces_between_special_tokens, + no_stop_trim, + prompt_tokens, + completion_tokens, + cached_tokens, + input_token_logprobs_val, + input_token_logprobs_idx, + output_token_logprobs_val, + output_token_logprobs_idx, + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + normalized_prompt_logprob, ) ) + else: # embedding or reward model + embeddings = [] + prompt_tokens = [] + for req in reqs: + assert req.finished() + rids.append(req.rid) + finished_reasons.append(req.finished_reason.to_json()) + embeddings.append(req.embedding) + prompt_tokens.append(len(req.origin_input_ids)) + self.send_to_detokenizer.send_pyobj( + BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens) + ) def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): # Check if other DP workers have running batches diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 56e01528a..4788565ac 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -22,7 +22,7 @@ import signal import sys import time import uuid -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import fastapi import uvloop @@ -76,6 +76,7 @@ class ReqState: out_list: List finished: bool event: asyncio.Event + obj: Any # For metrics created_time: float @@ -283,7 +284,7 @@ class TokenizerManager: ): """Wait for the response of one request.""" event = asyncio.Event() - state = ReqState([], False, event, created_time=created_time) + state = ReqState([], False, event, obj, created_time=created_time) self.rid_to_state[obj.rid] = state while True: @@ -295,15 +296,7 @@ class TokenizerManager: raise ValueError(f"Abort request {obj.rid}") continue - if isinstance(obj, GenerateReqInput): - out = self.convert_logprob_style( - state.out_list[-1], - obj.return_logprob, - obj.top_logprobs_num, - obj.return_text_in_logprobs, - ) - else: # isinstance(obj, (EmbeddingReqInput,)) - out = state.out_list[-1] + out = state.out_list[-1] state.out_list = [] if state.finished: @@ -315,7 +308,13 @@ class TokenizerManager: break state.event.clear() - yield out + + if obj.stream: + yield out + else: + if request is not None and await request.is_disconnected(): + self.abort_request(obj.rid) + raise ValueError(f"Abort request {obj.rid}") async def _handle_batch_request( self, @@ -609,29 +608,55 @@ class TokenizerManager: if state is None: continue - recv_obj.meta_info[i]["id"] = rid + meta_info = { + "id": rid, + "finish_reason": recv_obj.finished_reasons[i], + "prompt_tokens": recv_obj.prompt_tokens[i], + } + + if getattr(state.obj, "return_logprob", False): + self.convert_logprob_style( + meta_info, + state.obj.top_logprobs_num, + state.obj.return_text_in_logprobs, + 
recv_obj, + i, + ) + if isinstance(recv_obj, BatchStrOut): out_dict = { "text": recv_obj.output_strs[i], - "meta_info": recv_obj.meta_info[i], + "meta_info": { + **meta_info, + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + }, } elif isinstance(recv_obj, BatchTokenIDOut): out_dict = { "token_ids": recv_obj.output_ids[i], - "meta_info": recv_obj.meta_info[i], + "meta_info": { + **meta_info, + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + }, } else: assert isinstance(recv_obj, BatchEmbeddingOut) out_dict = { "embedding": recv_obj.embeddings[i], - "meta_info": recv_obj.meta_info[i], + "meta_info": meta_info, } state.out_list.append(out_dict) - state.finished = recv_obj.finished_reason[i] is not None + state.finished = recv_obj.finished_reasons[i] is not None state.event.set() if self.enable_metrics: - completion_tokens = recv_obj.meta_info[i]["completion_tokens"] + completion_tokens = ( + recv_obj.completion_tokens[i] + if recv_obj.completion_tokens + else 0 + ) if state.first_token_time is None: state.first_token_time = time.time() @@ -647,7 +672,7 @@ class TokenizerManager: if state.finished: self.metrics_collector.inc_prompt_tokens( - recv_obj.meta_info[i]["prompt_tokens"] + recv_obj.prompt_tokens[i] ) self.metrics_collector.inc_generation_tokens( completion_tokens @@ -696,57 +721,73 @@ class TokenizerManager: def convert_logprob_style( self, - ret: dict, - return_logprob: bool, + meta_info: dict, top_logprobs_num: int, return_text_in_logprobs: bool, + recv_obj: BatchStrOut, + recv_obj_index: int, ): - if return_logprob: - ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs - ) - ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs - ) - - if top_logprobs_num > 0: - ret["meta_info"]["input_top_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - ret["meta_info"]["input_top_logprobs"], - return_text_in_logprobs, - ) - ) - ret["meta_info"]["output_top_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs - ) - ) - return ret - - def detokenize_logprob_tokens( - self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool - ): - # TODO(lianmin): This should run on DetokenizerManager - if not decode_to_text: - return [(logprob, token_id, None) for logprob, token_id in token_logprobs] - - assert self.tokenizer is not None - token_ids = [tid for _, tid in token_logprobs] - token_texts = self.tokenizer.batch_decode(token_ids) - return [ - (logprob, token_id, token_text) - for (logprob, token_id), token_text in zip(token_logprobs, token_texts) + meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens( + recv_obj.input_token_logprobs_val[recv_obj_index], + recv_obj.input_token_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens( + recv_obj.output_token_logprobs_val[recv_obj_index], + recv_obj.output_token_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[ + recv_obj_index ] - def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): + if top_logprobs_num > 0: + meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( + 
recv_obj.input_top_logprobs_val[recv_obj_index], + recv_obj.input_top_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens( + recv_obj.output_top_logprobs_val[recv_obj_index], + recv_obj.output_top_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + + def detokenize_logprob_tokens( + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, + ): + if not decode_to_text: + return [ + (logprob, token_id, None) + for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx) + ] + else: + assert self.tokenizer is not None + token_texts = self.tokenizer.batch_decode(token_logprobs_idx) + return list(zip(token_logprobs_val, token_logprobs_idx, token_texts)) + + def detokenize_top_logprobs_tokens( + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, + ): # TODO: The current implementation only batches the detokenization for top-k tokens per single position. # We should batch all top-k tokens in all positions. - for i, token_top_logprobs in enumerate(top_logprobs): - if token_top_logprobs: - top_logprobs[i] = self.detokenize_logprob_tokens( - token_top_logprobs, decode_to_text + ret = [] + for i in range(len(token_logprobs_val)): + if token_logprobs_val[i]: + ret.append( + self.detokenize_logprob_tokens( + token_logprobs_val[i], token_logprobs_idx[i], decode_to_text + ) ) - return top_logprobs + else: + ret.append(None) + return ret class SignalHandler: diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 27043cc9a..77efba892 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -400,9 +400,14 @@ class CudaGraphRunner: forward_mode=ForwardMode.DECODE, top_logprobs_nums=forward_batch.top_logprobs_nums, ) - logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( + ( + logits_output.output_top_logprobs_val, + logits_output.output_top_logprobs_idx, + ) = LogitsProcessor.get_top_logprobs( next_token_logprobs, logits_metadata - )[1] + )[ + 2:4 + ] else: logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits, diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 514bf31a6..32c6e08b6 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -720,13 +720,13 @@ def run_and_check_memory_leak( # Clean up everything kill_process_tree(process.pid) - kill_process_tree(process.pid) stdout.close() stderr.close() if os.path.exists(STDOUT_FILENAME): os.remove(STDOUT_FILENAME) if os.path.exists(STDERR_FILENAME): os.remove(STDERR_FILENAME) + kill_process_tree(process.pid) t.join() # Assert success @@ -734,7 +734,7 @@ def run_and_check_memory_leak( has_leak = False has_abort = False for line in output_lines: - if "The server is fired" in line: + if "Uvicorn running" in line: has_new_server = True if "leak" in line: has_leak = True diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 1a857d0da..adb5c18fb 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -95,15 +95,6 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase): self.assertIsInstance(js_obj["name"], str) self.assertIsInstance(js_obj["population"], int) - # Make sure jump forward is triggered - # NOTE: The overlap scheduler does not support jump forward 
so we only do this test - # when --disable-overlap-schedule is set. - if self.check_jump_forward: - self.assertGreater( - ret["meta_info"]["completion_tokens"], - ret["meta_info"]["completion_tokens_wo_jump_forward"], - ) - def test_json_generate(self): self.run_decode(json_schema=self.json_schema)
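
Reviewer note on the core pattern: this patch consistently replaces `List[Tuple[logprob, token_id]]` payloads with parallel `*_val` / `*_idx` lists across `LogitsProcessorOutput`, `Req`, `BatchTokenIDOut`, and `BatchStrOut` — an array-of-structs to struct-of-arrays change, presumably so the inter-process (ZMQ/pickle) payloads become flat primitive lists instead of per-position tuples. A minimal sketch of the two layouts; the values are made up for illustration, and `pairs` is a hypothetical rezip on the consuming side (which is effectively what `detokenize_logprob_tokens` does after this patch):

```python
from typing import List, Tuple

# Old layout (array of structs): one (logprob, token_id) tuple per position.
old_output_token_logprobs: List[Tuple[float, int]] = [(-0.11, 42), (-1.73, 7)]

# New layout (struct of arrays): two parallel lists of primitives.
output_token_logprobs_val: List[float] = [-0.11, -1.73]
output_token_logprobs_idx: List[int] = [42, 7]

# Consumers that still want pairs can rezip on the receiving side.
pairs = list(zip(output_token_logprobs_val, output_token_logprobs_idx))
assert pairs == old_output_token_logprobs
```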
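The second pattern is the removal of the per-request `meta_info: List[Dict]` field from the batch structs: token counts, finish reasons, and logprobs now travel as typed per-field lists, and `TokenizerManager` reassembles the `meta_info` dict at the API edge. A simplified sketch of that reassembly, assuming a hypothetical `FakeBatchStrOut` stand-in (the real `BatchStrOut` carries more fields, and the real code also handles input/top logprobs and optional text detokenization):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeBatchStrOut:
    # Stand-in with only the fields this sketch touches.
    rids: List[str]
    finished_reasons: List[Optional[dict]]
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    output_token_logprobs_val: Optional[List[List[float]]]
    output_token_logprobs_idx: Optional[List[List[int]]]


def build_meta_info(recv_obj: FakeBatchStrOut, i: int) -> dict:
    # Rebuild the per-request meta_info dict from the flat batched fields.
    meta_info = {
        "id": recv_obj.rids[i],
        "finish_reason": recv_obj.finished_reasons[i],
        "prompt_tokens": recv_obj.prompt_tokens[i],
        "completion_tokens": recv_obj.completion_tokens[i],
        "cached_tokens": recv_obj.cached_tokens[i],
    }
    if recv_obj.output_token_logprobs_val is not None:
        # Rezip the parallel lists into (logprob, token_id) pairs for the API.
        meta_info["output_token_logprobs"] = list(
            zip(
                recv_obj.output_token_logprobs_val[i],
                recv_obj.output_token_logprobs_idx[i],
            )
        )
    return meta_info
```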
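Finally, `Scheduler.stream_output` now decides per request when to flush to the detokenizer, instead of using one batch-wide `is_stream_iter` flag. A sketch of the new condition, with `should_flush` as a hypothetical helper (not in the patch) and `stream_interval` corresponding to the server's `--stream-interval` setting:

```python
def should_flush(
    finished: bool, stream: bool, num_output_ids: int, stream_interval: int
) -> bool:
    # Mirrors the new per-request flush condition in Scheduler.stream_output:
    # finished requests always flush; streaming requests flush every
    # `stream_interval` decode steps; non-streaming requests still flush
    # every 50 tokens so incremental detokenization can amortize its work.
    return (
        finished
        or (stream and num_output_ids % stream_interval == 0)
        or (not stream and num_output_ids % 50 == 0)
    )
```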