From fb99aaa527199de19271668f0aa1e70b780f83fa Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 25 Oct 2024 18:51:59 -0700 Subject: [PATCH] [Fix] Fix --skip-tokenizer-init (#1798) --- .../srt/managers/detokenizer_manager.py | 7 ++-- python/sglang/srt/managers/io_struct.py | 2 ++ python/sglang/srt/managers/scheduler.py | 23 +++++++++--- .../sglang/srt/managers/tokenizer_manager.py | 7 ++-- test/srt/test_skip_tokenizer_init.py | 36 +++++++++++++------ 5 files changed, 50 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index d0d399363..caa5b611e 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -115,12 +115,9 @@ class DetokenizerManager: elif isinstance(recv_obj, GetMemPoolSizeReqOutput): self.send_to_tokenizer.send_pyobj(recv_obj) continue - elif self.tokenizer is None: - # If the tokenizer is skipped, no detokenization is needed - self.send_to_tokenizer.send_pyobj(recv_obj) - continue + else: + assert isinstance(recv_obj, BatchTokenIDOut) - assert isinstance(recv_obj, BatchTokenIDOut) bs = len(recv_obj.rids) # Initialize decode status diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 2cdc3f478..f29a7d3bc 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -294,6 +294,8 @@ class BatchTokenIDOut: decoded_texts: List[str] decode_ids: List[int] read_offsets: List[int] + # Only used when `--skip-tokenizer-init` + output_ids: Optional[List[int]] skip_special_tokens: List[bool] spaces_between_special_tokens: List[bool] meta_info: List[Dict] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b1fb96b2a..ce5ddd7c7 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -104,6 +104,7 @@ class Scheduler: 
self.lora_paths = server_args.lora_paths self.max_loras_per_batch = server_args.max_loras_per_batch self.enable_overlap = server_args.enable_overlap_schedule + self.skip_tokenizer_init = server_args.skip_tokenizer_init # Init inter-process communication context = zmq.Context(2) @@ -112,8 +113,18 @@ class Scheduler: self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}") - self.send_to_detokenizer = context.socket(zmq.PUSH) - self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}") + if server_args.skip_tokenizer_init: + # Directly send to the tokenizer/api + self.send_to_detokenizer = context.socket(zmq.PUSH) + self.send_to_detokenizer.connect( + f"ipc://{port_args.tokenizer_ipc_name}" + ) + else: + # Send to the detokenizer + self.send_to_detokenizer = context.socket(zmq.PUSH) + self.send_to_detokenizer.connect( + f"ipc://{port_args.detokenizer_ipc_name}" + ) else: self.recv_from_tokenizer = None self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) @@ -734,7 +745,7 @@ class Scheduler: ) else: logits_output = None - if self.tokenizer is not None: + if self.skip_tokenizer_init: next_token_ids = torch.full( (batch.batch_size(),), self.tokenizer.eos_token_id ) @@ -950,13 +961,14 @@ class Scheduler: def stream_output(self, reqs: List[Req]): """Stream the output to detokenizer.""" output_rids = [] - output_meta_info = [] + output_meta_info: List[dict] = [] output_finished_reason: List[BaseFinishReason] = [] if self.is_generation: output_vids = [] decoded_texts = [] output_read_ids = [] output_read_offsets = [] + output_ids = [] output_skip_special_tokens = [] output_spaces_between_special_tokens = [] output_no_stop_trim = [] @@ -977,6 +989,8 @@ class Scheduler: read_ids, read_offset = req.init_incremental_detokenize() output_read_ids.append(read_ids) output_read_offsets.append(read_offset) + if self.skip_tokenizer_init: + output_ids.append(req.output_ids) 
output_skip_special_tokens.append( req.sampling_params.skip_special_tokens ) @@ -1028,6 +1042,7 @@ class Scheduler: decoded_texts, output_read_ids, output_read_offsets, + output_ids, output_skip_special_tokens, output_spaces_between_special_tokens, output_meta_info, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 875239a94..585d5d8ce 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -571,7 +571,7 @@ class TokenizerManager: def create_abort_task(self, obj: GenerateReqInput): # Abort the request if the client is disconnected. async def abort_request(): - await asyncio.sleep(3) + await asyncio.sleep(1) if obj.is_single: self.abort_request(obj.rid) else: @@ -621,11 +621,8 @@ class TokenizerManager: "meta_info": recv_obj.meta_info[i], } elif isinstance(recv_obj, BatchTokenIDOut): - read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1] out_dict = { - "token_ids": recv_obj.decode_ids[ - read_start : recv_obj.read_offsets[i] - ], + "token_ids": recv_obj.output_ids[i], "meta_info": recv_obj.meta_info[i], } diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index b159bb557..3a8c34c16 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -29,21 +29,15 @@ class TestSkipTokenizerInit(unittest.TestCase): kill_child_process(cls.process.pid) def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): + max_new_tokens = 32 + input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is response = requests.post( self.base_url + "/generate", json={ - "input_ids": [ - 119689, - 50650, - 18291, - 30061, - 5316, - 26951, - 119690, - ], # The capital of France is + "input_ids": input_ids, "sampling_params": { "temperature": 0 if n == 1 else 0.5, - "max_new_tokens": 32, + "max_new_tokens": max_new_tokens, "n": n, "stop_token_ids": [119690], }, @@ 
-53,7 +47,27 @@ class TestSkipTokenizerInit(unittest.TestCase): "logprob_start_len": 0, }, ) - print(json.dumps(response.json())) + ret = response.json() + print(json.dumps(ret)) + + def assert_one_item(item): + assert len(item["token_ids"]) == item["meta_info"]["completion_tokens"] + assert len(item["token_ids"]) == max_new_tokens + assert item["meta_info"]["prompt_tokens"] == len(input_ids) + + if return_logprob: + assert len(item["meta_info"]["input_token_logprobs"]) == len( + input_ids + ), f'{len(item["meta_info"]["input_token_logprobs"])} vs. {len(input_ids)}' + assert len(item["meta_info"]["output_token_logprobs"]) == max_new_tokens + + if n == 1: + assert_one_item(ret) + else: + assert len(ret) == n + for i in range(n): + assert_one_item(ret[i]) + print("=" * 100) def test_simple_decode(self):