Fix mixed chunked prefill (#1850)
This commit is contained in:
@@ -720,6 +720,8 @@ class Scheduler:
|
|||||||
|
|
||||||
# Mixed-style chunked prefill
|
# Mixed-style chunked prefill
|
||||||
if self.is_mixed_chunk and self.running_batch is not None:
|
if self.is_mixed_chunk and self.running_batch is not None:
|
||||||
|
self.running_batch.filter_batch()
|
||||||
|
if not self.running_batch.is_empty():
|
||||||
self.running_batch.prepare_for_decode(self.enable_overlap)
|
self.running_batch.prepare_for_decode(self.enable_overlap)
|
||||||
new_batch.mix_with_running(self.running_batch)
|
new_batch.mix_with_running(self.running_batch)
|
||||||
new_batch.decoding_reqs = self.running_batch.reqs
|
new_batch.decoding_reqs = self.running_batch.reqs
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import random
|
|||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, List, Optional
|
||||||
@@ -656,11 +657,12 @@ def read_output(output_lines):
|
|||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
def run_mmlu_test(
|
def run_and_check_memory_leak(
|
||||||
|
workload_func,
|
||||||
disable_radix_cache,
|
disable_radix_cache,
|
||||||
enable_mixed_chunk=False,
|
enable_mixed_chunk,
|
||||||
enable_overlap=False,
|
enable_overlap,
|
||||||
chunked_prefill_size=32,
|
chunked_prefill_size,
|
||||||
):
|
):
|
||||||
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
|
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
|
||||||
if disable_radix_cache:
|
if disable_radix_cache:
|
||||||
@@ -690,21 +692,8 @@ def run_mmlu_test(
|
|||||||
t = threading.Thread(target=read_output, args=(output_lines,))
|
t = threading.Thread(target=read_output, args=(output_lines,))
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
# Run the eval
|
# Run the workload
|
||||||
args = SimpleNamespace(
|
workload_func(base_url, model)
|
||||||
base_url=base_url,
|
|
||||||
model=model,
|
|
||||||
eval_name="mmlu",
|
|
||||||
num_examples=128,
|
|
||||||
num_threads=128,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
metrics = run_eval(args)
|
|
||||||
print(f"{metrics=}")
|
|
||||||
assert metrics["score"] >= 0.65
|
|
||||||
finally:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Clean up everything
|
# Clean up everything
|
||||||
kill_child_process(process.pid, include_self=True)
|
kill_child_process(process.pid, include_self=True)
|
||||||
@@ -727,4 +716,63 @@ def run_mmlu_test(
|
|||||||
has_leak = True
|
has_leak = True
|
||||||
|
|
||||||
assert has_new_server
|
assert has_new_server
|
||||||
# assert not has_leak
|
assert not has_leak
|
||||||
|
|
||||||
|
|
||||||
|
def run_mmlu_test(
    disable_radix_cache=False,
    enable_mixed_chunk=False,
    enable_overlap=False,
    chunked_prefill_size=32,
):
    """Run a small MMLU eval as the workload of a memory-leak check.

    The eval is wrapped in a callback and handed to
    ``run_and_check_memory_leak`` together with the four options below,
    which are forwarded unchanged.

    Args:
        disable_radix_cache: Forwarded to ``run_and_check_memory_leak``.
        enable_mixed_chunk: Forwarded to ``run_and_check_memory_leak``.
        enable_overlap: Forwarded to ``run_and_check_memory_leak``.
        chunked_prefill_size: Forwarded to ``run_and_check_memory_leak``.

    Raises:
        AssertionError: If the MMLU score is below 0.65.
    """

    def workload_func(base_url, model):
        # Run the eval
        args = SimpleNamespace(
            base_url=base_url,
            model=model,
            eval_name="mmlu",
            num_examples=128,
            num_threads=128,
        )

        metrics = run_eval(args)
        print(f"{metrics=}")
        # Accuracy floor: a lower score fails the workload.
        assert metrics["score"] >= 0.65

    run_and_check_memory_leak(
        workload_func,
        disable_radix_cache,
        enable_mixed_chunk,
        enable_overlap,
        chunked_prefill_size,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def run_mulit_request_test(
    disable_radix_cache=False,
    enable_mixed_chunk=False,
    enable_overlap=False,
    chunked_prefill_size=32,
):
    """Send a few concurrent generate requests as a memory-leak workload.

    The workload posts the same short prompt four times through a pool of
    two threads, then is handed to ``run_and_check_memory_leak`` together
    with the four options below, which are forwarded unchanged.

    Args:
        disable_radix_cache: Forwarded to ``run_and_check_memory_leak``.
        enable_mixed_chunk: Forwarded to ``run_and_check_memory_leak``.
        enable_overlap: Forwarded to ``run_and_check_memory_leak``.
        chunked_prefill_size: Forwarded to ``run_and_check_memory_leak``.
    """

    def workload_func(base_url, model):
        # `model` is part of the callback signature but is not needed here.
        def run_one(_):
            # NOTE(review): the leading whitespace of the prompt lines is not
            # recoverable from this view -- confirm against the repository.
            prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""

            response = requests.post(
                f"{base_url}/generate",
                json={
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 8,
                    },
                },
            )
            # Surface HTTP errors instead of silently treating a failed
            # request as a successful part of the workload.
            response.raise_for_status()
            # Decoding the body checks that the server answered with JSON.
            response.json()

        with ThreadPoolExecutor(2) as executor:
            # Drain the iterator so exceptions from workers propagate.
            list(executor.map(run_one, range(4)))

    run_and_check_memory_leak(
        workload_func,
        disable_radix_cache,
        enable_mixed_chunk,
        enable_overlap,
        chunked_prefill_size,
    )
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
run_bench_serving,
|
run_bench_serving,
|
||||||
run_mmlu_test,
|
run_mmlu_test,
|
||||||
|
run_mulit_request_test,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -39,6 +40,12 @@ class TestChunkedPrefill(unittest.TestCase):
|
|||||||
|
|
||||||
assert res["completed"] == 10
|
assert res["completed"] == 10
|
||||||
|
|
||||||
|
def test_mixed_chunked_prefill_multi_requests(self):
    # Exercise the multi-request workload with mixed chunked prefill on
    # and a 2048-token chunk size; other options keep their defaults.
    options = {
        "chunked_prefill_size": 2048,
        "enable_mixed_chunk": True,
    }
    run_mulit_request_test(**options)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user