From a47bf39123c4f5bffcf96a80640f234e3f637c4c Mon Sep 17 00:00:00 2001 From: justdoit <24875266+coolhok@users.noreply.github.com> Date: Sat, 11 Jan 2025 06:00:43 +0800 Subject: [PATCH] [Eagle2] Fix multiple concurrent request crashes (#2730) --- python/sglang/srt/speculative/eagle_utils.py | 17 ++- python/sglang/srt/speculative/eagle_worker.py | 2 + test/srt/test_eagle_infer.py | 119 ++++++++++++++++++ 3 files changed, 134 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index b804e7c6a..1a324000c 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -245,9 +245,10 @@ class EAGLEDraftInput(SpecInfo): ) # (b, topk) topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values - selected_input_index = ( - topk_cs_index.flatten() // self.topk - ) # shape: (b * topk) + selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange( + 0, batch.batch_size() * self.topk, step=self.topk, device="cuda" + ).repeat_interleave(self.topk) + batch.spec_info.hidden_states = batch.spec_info.hidden_states[ selected_input_index, : ] @@ -336,6 +337,7 @@ class EAGLEDraftInput(SpecInfo): triton.next_power_of_2(self.spec_steps + 1), ) + batch.seq_lens_sum = sum(batch.seq_lens) batch.input_ids = self.verified_id self.verified_id = new_verified_id @@ -439,7 +441,14 @@ class EAGLEDraftInput(SpecInfo): return kv_indices, cum_kv_seq_len, qo_indptr, None def merge_batch(self, spec_info: EAGLEDraftInput): - + if self.hidden_states is None: + self.hidden_states = spec_info.hidden_states + self.verified_id = spec_info.verified_id + self.sample_output = spec_info.sample_output + self.prev_mode = spec_info.prev_mode + return + if spec_info.hidden_states is None: + return self.hidden_states = torch.cat( [self.hidden_states, spec_info.hidden_states], axis=0 ) diff --git a/python/sglang/srt/speculative/eagle_worker.py 
b/python/sglang/srt/speculative/eagle_worker.py index 16d54c43b..0e53506a8 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -169,6 +169,8 @@ class EAGLEWorker(TpModelWorker): if not isinstance(reqs, List): reqs = [reqs] for req in reqs: + if req.rid not in self.finish_extend_len: + continue req_len = ( len(req.origin_input_ids) + len(req.output_ids) diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 94ebc79ca..92127b8ef 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -1,8 +1,18 @@ +import multiprocessing +import random +import time import unittest +import requests from transformers import AutoConfig, AutoTokenizer import sglang as sgl +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) class TestEAGLEEngine(unittest.TestCase): @@ -64,5 +74,114 @@ class TestEAGLEEngine(unittest.TestCase): assert tokenizer.eos_token_id not in tokens +prompts = [ + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]", + '[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]', + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]", + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwho are you?[/INST]", + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwhere are you from?[/INST]", +] + + +def process(server_url: str): + time.sleep(random.uniform(0, 2)) + for prompt in prompts: + url = server_url + data = { + "model": "base", + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 1024, + }, + } + response = requests.post(url, json=data) + assert response.status_code == 200 + + +def abort_process(server_url: str): + for prompt in prompts: + 
try: + time.sleep(1) + url = server_url + data = { + "model": "base", + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 1024, + }, + } + # set timeout to 1s to mock a client disconnect + requests.post(url, json=data, timeout=1) + except: + pass + + +class TestEAGLELaunchServer(unittest.TestCase): + @classmethod + def setUpClass(cls): + speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B" + cls.model = "meta-llama/Llama-2-7b-chat-hf" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model-path", + speculative_draft_model_path, + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "4", + "--speculative-num-draft-tokens", + "16", + "--served-model-name", + "base", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_eagle_server_concurrency(self): + concurrency = 4 + processes = [ + multiprocessing.Process( + target=process, + kwargs={"server_url": self.base_url + "/generate"}, + ) + for _ in range(concurrency) + ] + for worker in processes: + worker.start() + for p in processes: + p.join() + + def test_eagle_server_request_abort(self): + concurrency = 4 + processes = [ + multiprocessing.Process( + target=process, + kwargs={"server_url": self.base_url + "/generate"}, + ) + for _ in range(concurrency) + ] + [ + multiprocessing.Process( + target=abort_process, + kwargs={"server_url": self.base_url + "/generate"}, + ) + for _ in range(concurrency) + ] + for worker in processes: + worker.start() + for p in processes: + p.join() + + if __name__ == "__main__": unittest.main()