From f7102fbd2b5d1e6bc0373e54b5bead7370dab160 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 30 Oct 2024 21:20:41 -0700
Subject: [PATCH] Fix mixed chunked prefill (#1850)

---
Reviewer notes (commentary area; ignored by `git am`, not part of the commit):

* scheduler.py: in the mixed-chunk branch the running batch now goes through
  filter_batch() first, and prepare_for_decode() / mix_with_running() / the
  decoding_reqs assignment only happen when the batch is not empty afterwards.
  self.running_batch is still reset to None in either case.
* test_utils.py: the server launch / log scan body of run_mmlu_test becomes
  run_and_check_memory_leak(workload_func, ...), which calls
  workload_func(base_url, model) and now enforces `assert not has_leak`
  (previously commented out). run_mmlu_test is rebuilt on top of it, and
  disable_radix_cache gains a default of False. run_mulit_request_test is new:
  it issues 4 POSTs to /generate through a ThreadPoolExecutor with 2 workers.
  NOTE(review): the helper name is spelled "mulit" in the commit; it is kept
  verbatim here because it is imported by test_chunked_prefill.py.
* test_chunked_prefill.py: adds test_mixed_chunked_prefill_multi_requests,
  which runs the new helper with enable_mixed_chunk=True and
  chunked_prefill_size=2048.
* NOTE(review): this patch text was recovered from a whitespace-collapsed copy.
  Indentation and blank context lines were reconstructed from the hunk line
  counts and Python syntax; the leading whitespace inside the `prompt`
  triple-quoted string is an assumption -- confirm against the upstream commit.

 python/sglang/srt/managers/scheduler.py |  8 ++-
 python/sglang/test/test_utils.py        | 88 +++++++++++++++++++------
 test/srt/test_chunked_prefill.py        |  7 ++
 3 files changed, 80 insertions(+), 23 deletions(-)

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 47f0b7d44..7c7780a64 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -720,9 +720,11 @@ class Scheduler:
 
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.prepare_for_decode(self.enable_overlap)
-            new_batch.mix_with_running(self.running_batch)
-            new_batch.decoding_reqs = self.running_batch.reqs
+            self.running_batch.filter_batch()
+            if not self.running_batch.is_empty():
+                self.running_batch.prepare_for_decode(self.enable_overlap)
+                new_batch.mix_with_running(self.running_batch)
+                new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
         else:
             new_batch.decoding_reqs = None
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index d6a4c1a29..8a486131f 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -7,6 +7,7 @@ import random
 import subprocess
 import threading
 import time
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from types import SimpleNamespace
 from typing import Callable, List, Optional
@@ -656,11 +657,12 @@ def read_output(output_lines):
         time.sleep(0.1)
 
 
-def run_mmlu_test(
+def run_and_check_memory_leak(
+    workload_func,
     disable_radix_cache,
-    enable_mixed_chunk=False,
-    enable_overlap=False,
-    chunked_prefill_size=32,
+    enable_mixed_chunk,
+    enable_overlap,
+    chunked_prefill_size,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
     if disable_radix_cache:
@@ -690,21 +692,8 @@ def run_mmlu_test(
     t = threading.Thread(target=read_output, args=(output_lines,))
     t.start()
 
-    # Run the eval
-    args = SimpleNamespace(
-        base_url=base_url,
-        model=model,
-        eval_name="mmlu",
-        num_examples=128,
-        num_threads=128,
-    )
-
-    try:
-        metrics = run_eval(args)
-        print(f"{metrics=}")
-        assert metrics["score"] >= 0.65
-    finally:
-        pass
+    # Run the workload
+    workload_func(base_url, model)
 
     # Clean up everything
     kill_child_process(process.pid, include_self=True)
@@ -727,4 +716,63 @@ def run_mmlu_test(
             has_leak = True
 
     assert has_new_server
-    # assert not has_leak
+    assert not has_leak
+
+
+def run_mmlu_test(
+    disable_radix_cache=False,
+    enable_mixed_chunk=False,
+    enable_overlap=False,
+    chunked_prefill_size=32,
+):
+    def workload_func(base_url, model):
+        # Run the eval
+        args = SimpleNamespace(
+            base_url=base_url,
+            model=model,
+            eval_name="mmlu",
+            num_examples=128,
+            num_threads=128,
+        )
+
+        try:
+            metrics = run_eval(args)
+            print(f"{metrics=}")
+            assert metrics["score"] >= 0.65
+        finally:
+            pass
+
+    run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
+
+
+def run_mulit_request_test(
+    disable_radix_cache=False,
+    enable_mixed_chunk=False,
+    enable_overlap=False,
+    chunked_prefill_size=32,
+):
+
+    def workload_func(base_url, model):
+        def run_one(_):
+            prompt = """
+            System: You are a helpful assistant.
+            User: What is the capital of France?
+            Assistant: The capital of France is
+            """
+
+            response = requests.post(
+                f"{base_url}/generate",
+                json={
+                    "text": prompt,
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 8,
+                    },
+                },
+            )
+            ret = response.json()
+
+        with ThreadPoolExecutor(2) as executor:
+            list(executor.map(run_one, list(range(4))))
+
+    run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py
index 5f618585f..0930603fe 100644
--- a/test/srt/test_chunked_prefill.py
+++ b/test/srt/test_chunked_prefill.py
@@ -8,6 +8,7 @@ from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
     run_bench_serving,
     run_mmlu_test,
+    run_mulit_request_test,
 )
 
 
@@ -39,6 +40,12 @@ class TestChunkedPrefill(unittest.TestCase):
 
         assert res["completed"] == 10
 
+    def test_mixed_chunked_prefill_multi_requests(self):
+        run_mulit_request_test(
+            enable_mixed_chunk=True,
+            chunked_prefill_size=2048,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()