From e8613df071fb126f97c0d1254977586f39362e08 Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Mon, 7 Oct 2024 21:26:56 -0700
Subject: [PATCH] [Engine] Fix generate hanging issue after the first call (#1606)

---
 python/sglang/srt/server.py |  5 +++--
 test/srt/test_srt_engine.py | 12 ++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index fb1bb8196..c6b2a345b 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -691,8 +691,9 @@ class Engine:
             lora_path=lora_path,
         )
 
-        # make it synchronous
-        return asyncio.run(generate_request(obj, None))
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(generate_request(obj, None))
 
     def shutdown(self):
         kill_child_process(os.getpid(), including_parent=False)
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index e9e6c9783..d1ecd61fc 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -28,6 +28,18 @@ class TestSRTBackend(unittest.TestCase):
         print(out2)
         assert out1 == out2, f"{out1} != {out2}"
 
+    def test_engine_multiple_generate(self):
+        # just to ensure there is no issue running multiple generate calls
+        prompt = "Today is a sunny day and I like"
+        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+        engine = sgl.Engine(model_path=model_path, random_seed=42)
+        engine.generate(prompt, sampling_params)
+        engine.generate(prompt, sampling_params)
+        engine.shutdown()
+
 
 if __name__ == "__main__":
     unittest.main()