From b6aad70ab1160a521151a69202e717dbd652e331 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 6 Oct 2024 20:30:02 -0700
Subject: [PATCH] [Fix] Fix the case where prompt_len = 0 (#1593)

---
 python/sglang/srt/managers/scheduler.py         | 6 ++++--
 python/sglang/srt/managers/tokenizer_manager.py | 2 +-
 python/sglang/srt/utils.py                      | 2 +-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index c667020fa..46789568c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -626,9 +626,11 @@ class Scheduler:
             else:
                 logits_output = None
                 if self.tokenizer is not None:
-                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+                    next_token_ids = torch.full(
+                        (batch.batch_size(),), self.tokenizer.eos_token_id
+                    )
                 else:
-                    next_token_ids = [0] * len(batch.reqs)
+                    next_token_ids = torch.full((batch.batch_size(),), 0)
             return logits_output, next_token_ids
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 27cac65c3..b25290c0a 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -526,7 +526,7 @@ class TokenizerManager:
         async with self.model_update_lock:
             # wait for the previous generation requests to finish
             while len(self.rid_to_state) > 0:
-                await asyncio.sleep(0)
+                await asyncio.sleep(0.001)
             self.send_to_scheduler.send_pyobj(obj)
             self.model_update_result = asyncio.Future()
             result = await self.model_update_result
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index bc1366b10..c543bb9a2 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -624,6 +624,6 @@ def broadcast_pyobj(
             tensor_data = torch.empty(size, dtype=torch.uint8)
             dist.broadcast(tensor_data, src=0, group=dist_group)

-        serialized_data = bytes(tensor_data.tolist())
+        serialized_data = bytes(tensor_data.cpu().numpy())
         data = pickle.loads(serialized_data)
         return data