Fix mixed chunked prefill (#1850)

This commit is contained in:
Lianmin Zheng
2024-10-30 21:20:41 -07:00
committed by GitHub
parent a7a0a6886b
commit f7102fbd2b
3 changed files with 80 additions and 23 deletions

View File

@@ -720,9 +720,11 @@ class Scheduler:
# Mixed-style chunked prefill # Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None: if self.is_mixed_chunk and self.running_batch is not None:
self.running_batch.prepare_for_decode(self.enable_overlap) self.running_batch.filter_batch()
new_batch.mix_with_running(self.running_batch) if not self.running_batch.is_empty():
new_batch.decoding_reqs = self.running_batch.reqs self.running_batch.prepare_for_decode(self.enable_overlap)
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch = None self.running_batch = None
else: else:
new_batch.decoding_reqs = None new_batch.decoding_reqs = None

View File

@@ -7,6 +7,7 @@ import random
import subprocess import subprocess
import threading import threading
import time import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
from types import SimpleNamespace from types import SimpleNamespace
from typing import Callable, List, Optional from typing import Callable, List, Optional
@@ -656,11 +657,12 @@ def read_output(output_lines):
time.sleep(0.1) time.sleep(0.1)
def run_mmlu_test( def run_and_check_memory_leak(
workload_func,
disable_radix_cache, disable_radix_cache,
enable_mixed_chunk=False, enable_mixed_chunk,
enable_overlap=False, enable_overlap,
chunked_prefill_size=32, chunked_prefill_size,
): ):
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)] other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
if disable_radix_cache: if disable_radix_cache:
@@ -690,21 +692,8 @@ def run_mmlu_test(
t = threading.Thread(target=read_output, args=(output_lines,)) t = threading.Thread(target=read_output, args=(output_lines,))
t.start() t.start()
# Run the eval # Run the workload
args = SimpleNamespace( workload_func(base_url, model)
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=128,
num_threads=128,
)
try:
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] >= 0.65
finally:
pass
# Clean up everything # Clean up everything
kill_child_process(process.pid, include_self=True) kill_child_process(process.pid, include_self=True)
@@ -727,4 +716,63 @@ def run_mmlu_test(
has_leak = True has_leak = True
assert has_new_server assert has_new_server
# assert not has_leak assert not has_leak
def run_mmlu_test(
disable_radix_cache=False,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
):
def workload_func(base_url, model):
# Run the eval
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=128,
num_threads=128,
)
try:
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] >= 0.65
finally:
pass
run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
def run_mulit_request_test(
disable_radix_cache=False,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
):
def workload_func(base_url, model):
def run_one(_):
prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""
response = requests.post(
f"{base_url}/generate",
json={
"text": prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 8,
},
},
)
ret = response.json()
with ThreadPoolExecutor(2) as executor:
list(executor.map(run_one, list(range(4))))
run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)

View File

@@ -8,6 +8,7 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
run_bench_serving, run_bench_serving,
run_mmlu_test, run_mmlu_test,
run_mulit_request_test,
) )
@@ -39,6 +40,12 @@ class TestChunkedPrefill(unittest.TestCase):
assert res["completed"] == 10 assert res["completed"] == 10
def test_mixed_chunked_prefill_multi_requests(self):
run_mulit_request_test(
enable_mixed_chunk=True,
chunked_prefill_size=2048,
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()