From 116685337e817e6e328ced94becdeb4979d83f36 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 17 Nov 2024 21:29:30 -0800 Subject: [PATCH] Fix cuda illegal memory access in overlap mode (#2070) --- python/sglang/srt/managers/schedule_batch.py | 3 --- python/sglang/srt/managers/scheduler.py | 3 +++ test/srt/test_srt_engine.py | 15 +++++++-------- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6171c93c0..20007d1dc 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1055,9 +1055,6 @@ class ScheduleBatch: ) def copy(self): - # We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors. - _ = self.seq_lens[0].item() - # Only contain fields that will be used by process_batch_result return ScheduleBatch( reqs=self.reqs, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2c78e70bf..125abaaf7 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -390,6 +390,9 @@ class Scheduler: batch = self.get_next_batch_to_run() self.cur_batch = batch if batch: + # We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors. + _ = batch.seq_lens[0].item() + result = self.run_batch(batch) result_queue.append((batch.copy(), result)) diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index 33232f50b..988d41ee6 100644 --- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -16,7 +16,6 @@ from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.server_args import ServerArgs from sglang.test.few_shot_gsm8k_engine import run_eval from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ) @@ -43,7 +42,7 @@ class TestSRTEngine(unittest.TestCase): print("==== Answer 2 ====") print(out2) - assert out1 == out2, f"{out1} != {out2}" + self.assertEqual(out1, out2) def test_2_engine_multiple_generate(self): # just to ensure there is no issue running multiple generate calls @@ -106,14 +105,14 @@ class TestSRTEngine(unittest.TestCase): def test_4_gsm8k(self): args = SimpleNamespace( - model_path=DEFAULT_MODEL_NAME_FOR_TEST, + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, local_data_path=None, num_shots=5, num_questions=200, ) metrics = run_eval(args) - assert metrics["accuracy"] > 0.7 + self.assertGreater(metrics["accuracy"], 0.3) def test_5_prompt_input_ids_consistency(self): prompt = "The capital of UK is" @@ -136,7 +135,7 @@ class TestSRTEngine(unittest.TestCase): print("==== Answer 2 ====") print(out2) - assert out1 == out2, f"{out1} != {out2}" + self.assertEqual(out1, out2) def test_6_engine_runtime_encode_consistency(self): prompt = "Today is a sunny day and I like" @@ -156,11 +155,11 @@ class TestSRTEngine(unittest.TestCase): def test_7_engine_offline_throughput(self): server_args = ServerArgs( - model_path=DEFAULT_MODEL_NAME_FOR_TEST, + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ) - bench_args = BenchArgs(num_prompts=100) + bench_args = BenchArgs(num_prompts=10) result = throughput_test(server_args=server_args, bench_args=bench_args) - self.assertTrue(result["total_throughput"] > 3000) + self.assertGreater(result["total_throughput"], 3500) if __name__ == "__main__":