diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index ce5fcc8d0..b6a9be71b 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1063,7 +1063,7 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
-            sampling_info=dataclasses.replace(self.sampling_info),
+            sampling_info=self.sampling_info,
         )
 
     def __str__(self):
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 7012ddf63..1df8499af 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -387,9 +387,6 @@ class Scheduler:
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
 
             if batch:
-                # We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
-                _ = batch.seq_lens[0].item()
-
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 6e5bce36a..69157f2a8 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -142,12 +142,12 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_event: Optional[threading.Event] = None,
+        launch_done: Optional[threading.Event] = None,
     ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        if launch_event:
-            launch_event.set()
+        if launch_done:
+            launch_done.set()
         next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
         return logits_output, next_token_ids
 
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 6b42d3974..805a687f7 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -96,19 +96,22 @@ class TpModelWorkerClient:
     @torch.no_grad()
     def forward_thread_func_(self):
         while True:
-            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            model_worker_batch, future_token_ids_ct, compute_info_done = (
+                self.input_queue.get()
+            )
             if not model_worker_batch:
                 break
-            self.launch_event = threading.Event()
-            copy_event = torch.cuda.Event()
+            self.launch_done = threading.Event()
+            copy_done = torch.cuda.Event()
 
             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
             resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
             # Run forward
+            compute_info_done.wait()
             logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch, self.launch_event
+                model_worker_batch, self.launch_done
             )
 
             # Update the future token ids map
@@ -133,15 +136,14 @@ class TpModelWorkerClient:
                 )
             )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
-            copy_event.record()
+            copy_done.record()
 
-            self.output_queue.put((copy_event, logits_output, next_token_ids))
+            self.output_queue.put((copy_done, logits_output, next_token_ids))
 
     def resolve_batch_result(self, bid: int):
-        copy_event, logits_output, next_token_ids = self.output_queue.get()
-        while not copy_event.query():
-            time.sleep(1e-5)
-        self.launch_event.wait()
+        copy_done, logits_output, next_token_ids = self.output_queue.get()
+        copy_done.synchronize()
+        self.launch_done.wait()
 
         if logits_output.next_token_logprobs is not None:
             logits_output.next_token_logprobs = (
@@ -162,7 +164,11 @@ class TpModelWorkerClient:
         model_worker_batch.sampling_info = dataclasses.replace(
             model_worker_batch.sampling_info
         )
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
+        compute_info_done = torch.cuda.Event()
+        compute_info_done.record()
+        self.input_queue.put(
+            (model_worker_batch, self.future_token_ids_ct, compute_info_done)
+        )
 
         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
diff --git a/test/srt/test_large_max_new_tokens.py b/test/srt/test_large_max_new_tokens.py
index 6ebe5e0d9..5ed2b06fc 100644
--- a/test/srt/test_large_max_new_tokens.py
+++ b/test/srt/test_large_max_new_tokens.py
@@ -38,7 +38,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             api_key=cls.api_key,
             other_args=(
                 "--max-total-token",
-                "1024",
+                "1536",
                 "--context-len",
                 "8192",
                 "--decode-log-interval",
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index 988d41ee6..5d7f95440 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -29,7 +29,7 @@ class TestSRTEngine(unittest.TestCase):
 
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
 
-        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
+        engine = sgl.Engine(model_path=model_path, random_seed=42)
         out1 = engine.generate(prompt, sampling_params)["text"]
         engine.shutdown()
 
@@ -51,7 +51,7 @@ class TestSRTEngine(unittest.TestCase):
 
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
 
-        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
+        engine = sgl.Engine(model_path=model_path, random_seed=42)
         engine.generate(prompt, sampling_params)
         engine.generate(prompt, sampling_params)
         engine.shutdown()
@@ -74,7 +74,6 @@ class TestSRTEngine(unittest.TestCase):
         # Create an LLM.
         llm = sgl.Engine(
             model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
-            log_level="error",
         )
 
         # 1. sync + non streaming
@@ -118,7 +117,9 @@ class TestSRTEngine(unittest.TestCase):
         prompt = "The capital of UK is"
         model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
 
-        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
+        engine = sgl.Engine(
+            model_path=model_path, random_seed=42, disable_radix_cache=True
+        )
 
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
         out1 = engine.generate(prompt, sampling_params)["text"]
@@ -141,9 +142,7 @@ class TestSRTEngine(unittest.TestCase):
         prompt = "Today is a sunny day and I like"
         model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
 
-        engine = sgl.Engine(
-            model_path=model_path, is_embedding=True, random_seed=42, log_level="error"
-        )
+        engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)
 
         out1 = torch.tensor(engine.encode(prompt)["embedding"])
         engine.shutdown()
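
Note on the synchronization pattern this patch moves to. The scheduler thread now records a CUDA event (compute_info_done) after the sampling info is materialized on its stream and ships it through input_queue; the forward thread calls wait() on that event before launching the forward pass, which appears to replace the old host-side sync (`_ = batch.seq_lens[0].item()`). On the output side, the consumer blocks on `copy_done.synchronize()` instead of busy-polling `copy_event.query()` in a sleep loop. Below is a minimal standalone sketch of the same producer/worker/consumer handshake, assuming a CUDA device is available; the queue names, worker_stream, and the toy computation are invented for illustration and are not part of the diff.

import queue
import threading

import torch

# Hypothetical queues/stream for the sketch; sglang uses its own equivalents.
task_queue: queue.Queue = queue.Queue()
done_queue: queue.Queue = queue.Queue()
worker_stream = torch.cuda.Stream()

def producer():
    # Enqueue GPU work on the current (default) stream, then record an event
    # that marks "the input is fully materialized on the GPU".
    x = torch.randn(1024, device="cuda")
    ready = torch.cuda.Event()
    ready.record()
    task_queue.put((x, ready))

def worker():
    x, ready = task_queue.get()
    with torch.cuda.stream(worker_stream):
        # Stream-level wait: no host blocking; worker_stream simply will not
        # run the kernels below before `ready` has fired.
        ready.wait()
        # Device-to-host copy; truly asynchronous only with pinned memory.
        y = (x * 2).to("cpu", non_blocking=True)
        copy_done = torch.cuda.Event()
        copy_done.record()
    done_queue.put((y, copy_done))

def consumer():
    y, copy_done = done_queue.get()
    # Host blocks here until the copy finished -- no query()/sleep polling.
    copy_done.synchronize()
    return y

threading.Thread(target=worker, daemon=True).start()
producer()
print(consumer()[:4])

The design benefit, as reflected in the diff, is that both synchronization points become event-based: the producer never round-trips to the host (no .item()) and the consumer never spins, so CPU-side scheduling stays overlapped with GPU execution.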