From 9087694006c44a166b3332c6c5f6a15f84a6daae Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 22 Feb 2025 11:50:46 +0800 Subject: [PATCH] Improve: Use TypeBasedDispatcher in DetokenizerManager (#3117) --- .../srt/managers/detokenizer_manager.py | 213 +++++++++--------- 1 file changed, 112 insertions(+), 101 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 1092cb30e..aa5c6dba8 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -32,7 +32,11 @@ from sglang.srt.managers.io_struct import ( ) from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import configure_logger, get_zmq_socket -from sglang.utils import find_printable_text, get_exception_traceback +from sglang.utils import ( + TypeBasedDispatcher, + find_printable_text, + get_exception_traceback, +) logger = logging.getLogger(__name__) @@ -83,6 +87,13 @@ class DetokenizerManager: self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES) + self._request_dispatcher = TypeBasedDispatcher( + [ + (BatchEmbeddingOut, self.handle_batch_embedding_out), + (BatchTokenIDOut, self.handle_batch_token_id_out), + ] + ) + def trim_matched_stop( self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool ): @@ -111,115 +122,115 @@ class DetokenizerManager: while True: recv_obj = self.recv_from_scheduler.recv_pyobj() + output = self._request_dispatcher(recv_obj) + self.send_to_tokenizer.send_pyobj(output) - if isinstance(recv_obj, BatchEmbeddingOut): - # If it is embedding model, no detokenization is needed. - self.send_to_tokenizer.send_pyobj(recv_obj) - continue + def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut): + # If it is embedding model, no detokenization is needed. + return recv_obj + + def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut): + bs = len(recv_obj.rids) + + # Initialize decode status + read_ids, surr_ids = [], [] + for i in range(bs): + rid = recv_obj.rids[i] + vid = recv_obj.vids[i] + if rid not in self.decode_status or self.decode_status[rid].vid != vid: + s = DecodeStatus( + vid=vid, + decoded_text=recv_obj.decoded_texts[i], + decode_ids=recv_obj.decode_ids[i], + surr_offset=0, + read_offset=recv_obj.read_offsets[i], + ) + self.decode_status[rid] = s else: - assert isinstance(recv_obj, BatchTokenIDOut) + s = self.decode_status[rid] + s.decode_ids = recv_obj.decode_ids[i] - bs = len(recv_obj.rids) + read_ids.append( + self.trim_matched_stop( + s.decode_ids[s.surr_offset :], + recv_obj.finished_reasons[i], + recv_obj.no_stop_trim[i], + ) + ) + surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset]) - # Initialize decode status - read_ids, surr_ids = [], [] - for i in range(bs): - rid = recv_obj.rids[i] - vid = recv_obj.vids[i] - if rid not in self.decode_status or self.decode_status[rid].vid != vid: - s = DecodeStatus( - vid=vid, - decoded_text=recv_obj.decoded_texts[i], - decode_ids=recv_obj.decode_ids[i], - surr_offset=0, - read_offset=recv_obj.read_offsets[i], - ) - self.decode_status[rid] = s + # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request + surr_texts = self.tokenizer.batch_decode( + surr_ids, + skip_special_tokens=recv_obj.skip_special_tokens[0], + spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], + ) + read_texts = self.tokenizer.batch_decode( + read_ids, + skip_special_tokens=recv_obj.skip_special_tokens[0], + spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], + ) + + # Incremental decoding + output_strs = [] + finished_reqs = [] + for i in range(bs): + try: + s = self.decode_status[recv_obj.rids[i]] + except KeyError: + raise RuntimeError( + f"Decode status not found for request {recv_obj.rids[i]}. " + "It may be due to the request being evicted from the decode status due to memory pressure. " + "Please increase the maximum number of requests by setting " + "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. " + f"The current value is {DETOKENIZER_MAX_STATES}. " + "For more details, see: https://github.com/sgl-project/sglang/issues/2812" + ) + new_text = read_texts[i][len(surr_texts[i]) :] + if recv_obj.finished_reasons[i] is None: + # Streaming chunk: update the decode status + if len(new_text) > 0 and not new_text.endswith("�"): + s.decoded_text = s.decoded_text + new_text + s.surr_offset = s.read_offset + s.read_offset = len(s.decode_ids) + new_text = "" else: - s = self.decode_status[rid] - s.decode_ids = recv_obj.decode_ids[i] + new_text = find_printable_text(new_text) + else: + finished_reqs.append(recv_obj.rids[i]) - read_ids.append( - self.trim_matched_stop( - s.decode_ids[s.surr_offset :], - recv_obj.finished_reasons[i], - recv_obj.no_stop_trim[i], - ) - ) - surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset]) - - # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request - surr_texts = self.tokenizer.batch_decode( - surr_ids, - skip_special_tokens=recv_obj.skip_special_tokens[0], - spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], - ) - read_texts = self.tokenizer.batch_decode( - read_ids, - skip_special_tokens=recv_obj.skip_special_tokens[0], - spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], - ) - - # Incremental decoding - output_strs = [] - finished_reqs = [] - for i in range(bs): - try: - s = self.decode_status[recv_obj.rids[i]] - except KeyError: - raise RuntimeError( - f"Decode status not found for request {recv_obj.rids[i]}. " - "It may be due to the request being evicted from the decode status due to memory pressure. " - "Please increase the maximum number of requests by setting " - "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. " - f"The current value is {DETOKENIZER_MAX_STATES}. " - "For more details, see: https://github.com/sgl-project/sglang/issues/2812" - ) - new_text = read_texts[i][len(surr_texts[i]) :] - if recv_obj.finished_reasons[i] is None: - # Streaming chunk: update the decode status - if len(new_text) > 0 and not new_text.endswith("�"): - s.decoded_text = s.decoded_text + new_text - s.surr_offset = s.read_offset - s.read_offset = len(s.decode_ids) - new_text = "" - else: - new_text = find_printable_text(new_text) - else: - finished_reqs.append(recv_obj.rids[i]) - - output_strs.append( - self.trim_matched_stop( - s.decoded_text + new_text, - recv_obj.finished_reasons[i], - recv_obj.no_stop_trim[i], - ) - ) - - self.send_to_tokenizer.send_pyobj( - BatchStrOut( - rids=recv_obj.rids, - finished_reasons=recv_obj.finished_reasons, - output_strs=output_strs, - prompt_tokens=recv_obj.prompt_tokens, - completion_tokens=recv_obj.completion_tokens, - cached_tokens=recv_obj.cached_tokens, - spec_verify_ct=recv_obj.spec_verify_ct, - input_token_logprobs_val=recv_obj.input_token_logprobs_val, - input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, - output_token_logprobs_val=recv_obj.output_token_logprobs_val, - output_token_logprobs_idx=recv_obj.output_token_logprobs_idx, - input_top_logprobs_val=recv_obj.input_top_logprobs_val, - input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, - output_top_logprobs_val=recv_obj.output_top_logprobs_val, - output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, - output_hidden_states=recv_obj.output_hidden_states, + output_strs.append( + self.trim_matched_stop( + s.decoded_text + new_text, + recv_obj.finished_reasons[i], + recv_obj.no_stop_trim[i], ) ) - # remove decodestatus for completed requests - for rid in finished_reqs: - self.decode_status.pop(rid) + out = BatchStrOut( + rids=recv_obj.rids, + finished_reasons=recv_obj.finished_reasons, + output_strs=output_strs, + prompt_tokens=recv_obj.prompt_tokens, + completion_tokens=recv_obj.completion_tokens, + cached_tokens=recv_obj.cached_tokens, + spec_verify_ct=recv_obj.spec_verify_ct, + input_token_logprobs_val=recv_obj.input_token_logprobs_val, + input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, + output_token_logprobs_val=recv_obj.output_token_logprobs_val, + output_token_logprobs_idx=recv_obj.output_token_logprobs_idx, + input_top_logprobs_val=recv_obj.input_top_logprobs_val, + input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, + output_top_logprobs_val=recv_obj.output_top_logprobs_val, + output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, + output_hidden_states=recv_obj.output_hidden_states, + ) + + # remove decodestatus for completed requests + for rid in finished_reqs: + self.decode_status.pop(rid) + + return out class LimitedCapacityDict(OrderedDict):