From 9cf0a5bada133d9f9f5bcc7f8f8cf0ba56848fb9 Mon Sep 17 00:00:00 2001
From: gryffindor-rr <107027757+gryffindor-rr@users.noreply.github.com>
Date: Sat, 10 Aug 2024 03:14:13 +0800
Subject: [PATCH] Add skip_tokenizer_init args. (#959)

Co-authored-by: lzhang
---
 python/sglang/srt/constrained/fsm_cache.py    | 14 +++-
 .../srt/managers/detokenizer_manager.py       | 18 +++--
 python/sglang/srt/managers/schedule_batch.py  | 17 ++--
 .../sglang/srt/managers/tokenizer_manager.py  | 74 ++++++++++++------
 python/sglang/srt/managers/tp_worker.py       | 50 +++++++-----
 python/sglang/srt/sampling_params.py          | 12 ++-
 python/sglang/srt/server.py                   | 19 +++--
 python/sglang/srt/server_args.py              |  6 ++
 python/sglang/srt/utils.py                    |  2 +
 test/srt/test_skip_tokenizer_srt.py           | 77 +++++++++++++++++++
 10 files changed, 218 insertions(+), 71 deletions(-)
 create mode 100644 test/srt/test_skip_tokenizer_srt.py

diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py
index 6df6bec51..fa41f90de 100644
--- a/python/sglang/srt/constrained/fsm_cache.py
+++ b/python/sglang/srt/constrained/fsm_cache.py
@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
 
 class FSMCache(BaseToolCache):
-    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        enable=True,
+        skip_tokenizer_init=False,
+    ):
         super().__init__(enable=enable)
 
-        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+        if (
+            skip_tokenizer_init
+            or tokenizer_path.endswith(".json")
+            or tokenizer_path.endswith(".model")
+        ):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
 
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 623ffe916..d765a365f 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -59,11 +59,14 @@ class DetokenizerManager:
         self.send_to_tokenizer = context.socket(zmq.PUSH)
         self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
 
-        self.tokenizer = get_tokenizer(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = None
+        else:
+            self.tokenizer = get_tokenizer(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+            )
 
         self.decode_status = {}
 
@@ -85,6 +88,11 @@ class DetokenizerManager:
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
 
+            if self.tokenizer is None:
+                # No tokenizer was initialized; forward the BatchTokenIDOut unchanged.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index d2101d2c0..2489abd5d 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -195,6 +195,8 @@ class Req:
         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
     def get_next_inc_detokenization(self):
+        if self.tokenizer is None:
+            return False, ""
         read_ids, read_offset = self.init_incremental_detokenize()
         surr_ids = read_ids[:read_offset]
 
@@ -225,16 +227,11 @@ class Req:
             return
 
         last_token_id = self.output_ids[-1]
-        if (
-            last_token_id == self.tokenizer.eos_token_id
-            and not self.sampling_params.ignore_eos
-        ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(
-                matched=self.tokenizer.eos_token_id
-            )
-            return
-
-        if last_token_id in self.sampling_params.stop_token_ids:
+        if self.tokenizer is None:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        else:
+            matched_eos = last_token_id == self.tokenizer.eos_token_id
+        if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
 
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 43c70ac7c..e2c825973 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -95,25 +95,28 @@ class TokenizerManager:
         else:
             self.context_len = get_context_length(self.hf_config)
 
-        if is_multimodal_model(self.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"
-            self.executor = concurrent.futures.ProcessPoolExecutor(
-                initializer=init_global_processor,
-                mp_context=mp.get_context("fork"),
-                initargs=(server_args,),
-            )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(self.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+                os.environ["TOKENIZERS_PARALLELISM"] = "false"
+                self.executor = concurrent.futures.ProcessPoolExecutor(
+                    initializer=init_global_processor,
+                    mp_context=mp.get_context("fork"),
+                    initargs=(server_args,),
+                )
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
 
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
 
@@ -171,6 +174,7 @@ class TokenizerManager:
         rid = obj.rid if not_use_index else obj.rid[index]
         input_text = obj.text if not_use_index else obj.text[index]
         if obj.input_ids is None:
+            assert self.tokenizer is not None
             input_ids = self.tokenizer.encode(input_text)
         else:
             input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
 
@@ -207,7 +211,20 @@ class TokenizerManager:
             else:
                 input_text = obj.text
                 rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
+            if self.tokenizer is not None:
+                input_ids = self.tokenizer.encode(input_text)
+            else:
+                assert obj.input_ids is not None
+                input_ids = obj.input_ids
+                if isinstance(obj.input_ids, list) and isinstance(
+                    obj.input_ids[0], list
+                ):
+                    # when obj["input_ids"] is List[List[int]]
+                    input_ids = obj.input_ids[index]
+                    rid = obj.rid[index]
+                else:
+                    input_ids = obj.input_ids
+                    rid = obj.rid[0]
         else:
             input_text = None
             if isinstance(obj.input_ids, list) and isinstance(
@@ -420,7 +437,7 @@ class TokenizerManager:
         # Log requests
         if self.server_args.log_requests and state.finished:
             if obj.text is None:
-                in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
+                in_obj = {"input_ids": obj.input_ids}
             else:
                 in_obj = {"text": obj.text}
             logger.info(f"in={in_obj}, out={out}")
@@ -488,11 +505,12 @@ class TokenizerManager:
 
     async def handle_loop(self):
         while True:
-            recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
+            recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
                 await self.recv_from_detokenizer.recv_pyobj()
             )
-            assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))
-
+            assert isinstance(
+                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
+            ), f"Unexpected obj received: {type(recv_obj)}"
             for i, rid in enumerate(recv_obj.rids):
                 state = self.rid_to_state.get(rid, None)
                 if state is None:
@@ -504,6 +522,15 @@ class TokenizerManager:
                         "text": recv_obj.output_strs[i],
                         "meta_info": recv_obj.meta_info[i],
                     }
+                elif isinstance(recv_obj, BatchTokenIDOut):
+                    read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
+                    out_dict = {
+                        "token_ids": recv_obj.decode_ids[
+                            read_start : recv_obj.read_offsets[i]
+                        ],
+                        "meta_info": recv_obj.meta_info[i],
+                    }
+
                 else:
                     assert isinstance(recv_obj, BatchEmbeddingOut)
                     out_dict = {
@@ -549,6 +576,7 @@ class TokenizerManager:
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
 
+        assert self.tokenizer is not None
         token_ids = [tid for _, tid in token_logprobs]
         token_texts = self.tokenizer.batch_decode(token_ids)
         return [
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index a7a78bde3..a73bddc6d 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -100,20 +100,22 @@ class ModelTpServer:
             nccl_port=nccl_port,
             server_args=server_args,
         )
-
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(server_args.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
             16384
@@ -182,13 +184,15 @@ class ModelTpServer:
         self.last_stats_tic = time.time()
 
         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+            )
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
@@ -466,7 +470,11 @@ class ModelTpServer:
 
             next_token_ids = next_token_ids.tolist()
         else:
-            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+            if self.tokenizer is None:
+                for i, req in enumerate(batch.reqs):
+                    next_token_ids.extend(req.sampling_params.stop_token_ids)
+            else:
+                next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
         # Check finish conditions
         pt = 0
diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py
index 39774d9ac..29067dc85 100644
--- a/python/sglang/srt/sampling_params.py
+++ b/python/sglang/srt/sampling_params.py
@@ -111,13 +111,19 @@ class SamplingParams:
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            self.stop_str_max_len = 0
+            if self.stop_token_ids is None:
+                self.stop_str_max_len = 0
+            else:
+                self.stop_str_max_len = 1
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]
 
             stop_str_max_len = 0
             for stop_str in self.stop_strs:
-                stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
-                stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                if tokenizer is not None:
+                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                else:
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
 
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index ed611242f..269aed66f 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -420,17 +420,22 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        json_data["text"] = "The capital city of France is"
+
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
                 url + request_name,
-                json={
-                    "text": "The capital city of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": max_new_tokens,
-                    },
-                },
+                json=json_data,
                 headers=headers,
                 timeout=600,
             )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index f42afdf8d..5cd8373e8 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -27,6 +27,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
     trust_remote_code: bool = True
@@ -151,6 +152,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip initializing the tokenizer and pass input_ids directly in generate requests.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index dd41156f3..2d20881c8 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -197,6 +197,8 @@ def allocate_init_ports(
 
 def get_int_token_logit_bias(tokenizer, vocab_size):
     """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
+    if tokenizer is None:
+        return [-1e5] * vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
diff --git a/test/srt/test_skip_tokenizer_srt.py b/test/srt/test_skip_tokenizer_srt.py
new file mode 100644
index 000000000..7f0a1fe1a
--- /dev/null
+++ b/test/srt/test_skip_tokenizer_srt.py
@@ -0,0 +1,77 @@
+import json
+import os
+import sys
+import unittest
+
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
+
+# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
+
+class TestSRTEndpoint(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = "http://127.0.0.1:8157"
+        cls.process = popen_launch_server(
+            cls.model, cls.base_url, timeout=300, other_args=["--skip-tokenizer-init"]
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def run_decode(
+        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
+    ):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": [
+                    119689,
+                    50650,
+                    18291,
+                    30061,
+                    5316,
+                    26951,
+                    119690,
+                ],  # The capital of France is
+                "sampling_params": {
+                    "temperature": 0 if n == 1 else 0.5,
+                    "max_new_tokens": 32,
+                    "n": n,
+                    "stop_token_ids": [119690],
+                },
+                "stream": False,
+                "return_logprob": return_logprob,
+                "top_logprobs_num": top_logprobs_num,
+                "return_text_in_logprobs": return_text,
+                "logprob_start_len": 0,
+            },
+        )
+        print(json.dumps(response.json()))
+        print("=" * 100)
+
+    def test_simple_decode(self):
+        self.run_decode()
+
+    def test_parallel_sample(self):
+        self.run_decode(n=3)
+
+    def test_logprob(self):
+        for top_logprobs_num in [0, 3]:
+            for return_text in [False, False]:  # text decoding needs a tokenizer, which is skipped here
+                self.run_decode(
+                    return_logprob=True,
+                    top_logprobs_num=top_logprobs_num,
+                    return_text=return_text,
+                )
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")
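
Usage note (illustrative, not part of the patch): with --skip-tokenizer-init, tokenization and detokenization move to the client, which sends input_ids and gets token IDs back from /generate. The sketch below shows one way a client might use the new flag; the model path, port, and the client-side tokenizer are placeholders, not values fixed by this patch.

# Client-side sketch. Assumes a server launched separately, e.g.:
#   python -m sglang.launch_server --model-path <model> --skip-tokenizer-init --port 30000
import requests
from transformers import AutoTokenizer

# The client does its own tokenization; "<model>" is a placeholder path.
tokenizer = AutoTokenizer.from_pretrained("<model>")
input_ids = tokenizer.encode("The capital of France is")

response = requests.post(
    "http://127.0.0.1:30000/generate",  # hypothetical address; match your launch args
    json={
        "input_ids": input_ids,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
            "stop_token_ids": [tokenizer.eos_token_id],
        },
    },
)
out = response.json()
# With skip_tokenizer_init the server returns token IDs ("token_ids" in the response),
# so detokenization also happens on the client.
print(tokenizer.decode(out["token_ids"]))

This mirrors the request shape used in test/srt/test_skip_tokenizer_srt.py above, which passes pre-tokenized input_ids and a stop_token_ids list in place of text and stop strings.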