diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 0141d2113..5d78b97ce 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -86,16 +86,15 @@ class TpModelWorkerClient: @torch.inference_mode() def forward_thread_func_(self): while True: - tic1 = time.time() model_worker_batch, future_token_ids_ct = self.input_queue.get() # Resolve future tokens in the input - tic2 = time.time() - resolved_input_ids = model_worker_batch.input_ids - future_mask = resolved_input_ids < 0 - resolved_input_ids[future_mask] = self.future_token_ids_map[ - -resolved_input_ids[future_mask] - ] + input_ids = model_worker_batch.input_ids + input_ids[:] = torch.where( + input_ids < 0, + self.future_token_ids_map[torch.clamp(-input_ids, min=0)], + input_ids, + ) # Run forward logits_output, next_token_ids = self.worker.forward_batch_generation( @@ -119,15 +118,6 @@ class TpModelWorkerClient: assert logits_output.next_token_logprobs is None, "Not supported" self.output_queue.put((None, next_token_ids)) - if False: - tic3 = time.time() - self.acc_time_with_waiting += tic3 - tic1 - self.acc_time_without_waiting += tic3 - tic2 - if self.forward_queue.qsize() == 0: - logger.info( - f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}" - ) - def resulve_batch_result(self, bid: int): logits_output, next_token_ids = self.output_queue.get() return logits_output, next_token_ids diff --git a/test/killall_sglang.sh b/test/killall_sglang.sh index b26d86b6f..71dc6b12e 100644 --- a/test/killall_sglang.sh +++ b/test/killall_sglang.sh @@ -1 +1,2 @@ kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}') +kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')