From 322421fae36424cdcef16ecc913e7f6e92d4b7d2 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 5 Feb 2024 14:21:16 -0800 Subject: [PATCH] Add warmup to SRT server (#146) --- python/sglang/srt/server.py | 58 +++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 560c93b28..18d67cac7 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -389,7 +389,7 @@ def launch_server(server_args, pipe_finish_writer): assert proc_router.is_alive() and proc_detoken.is_alive() - def launch_server(): + def _launch_server(): # Launch api server uvicorn.run( app, @@ -400,26 +400,48 @@ def launch_server(server_args, pipe_finish_writer): loop="uvloop", ) - t = threading.Thread(target=launch_server) + t = threading.Thread(target=_launch_server) t.start() - if pipe_finish_writer: - url = server_args.url() - - success = False - for i in range(60): - time.sleep(1) - try: - res = requests.get(url + "/get_model_info", timeout=5) - success = True - break - except requests.exceptions.RequestException as e: - pass - - if success: - pipe_finish_writer.send("init ok") - else: + url = server_args.url() + for _ in range(60): + time.sleep(1) + try: + requests.get(url + "/get_model_info", timeout=5) + break + except requests.exceptions.RequestException as e: + pass + else: + if pipe_finish_writer is not None: pipe_finish_writer.send(str(e)) + else: + print(e, flush=True) + return + + # Warmup + try: + print("Warmup...", flush=True) + res = requests.post( + url + "/generate", + json={ + "text": "Say this is a warmup request.", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, + }, + timeout=60, + ) + print(f"Warmup done. model response: {res.json()['text']}") + except requests.exceptions.RequestException as e: + if pipe_finish_writer is not None: + pipe_finish_writer.send(str(e)) + else: + print(e, flush=True) + return + + if pipe_finish_writer is not None: + pipe_finish_writer.send("init ok") class Runtime: