Fix mixed chunked prefill (#1850)
This commit is contained in:
@@ -720,9 +720,11 @@ class Scheduler:
|
||||
|
||||
# Mixed-style chunked prefill
|
||||
if self.is_mixed_chunk and self.running_batch is not None:
|
||||
self.running_batch.prepare_for_decode(self.enable_overlap)
|
||||
new_batch.mix_with_running(self.running_batch)
|
||||
new_batch.decoding_reqs = self.running_batch.reqs
|
||||
self.running_batch.filter_batch()
|
||||
if not self.running_batch.is_empty():
|
||||
self.running_batch.prepare_for_decode(self.enable_overlap)
|
||||
new_batch.mix_with_running(self.running_batch)
|
||||
new_batch.decoding_reqs = self.running_batch.reqs
|
||||
self.running_batch = None
|
||||
else:
|
||||
new_batch.decoding_reqs = None
|
||||
|
||||
@@ -7,6 +7,7 @@ import random
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from types import SimpleNamespace
|
||||
from typing import Callable, List, Optional
|
||||
@@ -656,11 +657,12 @@ def read_output(output_lines):
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def run_mmlu_test(
|
||||
def run_and_check_memory_leak(
|
||||
workload_func,
|
||||
disable_radix_cache,
|
||||
enable_mixed_chunk=False,
|
||||
enable_overlap=False,
|
||||
chunked_prefill_size=32,
|
||||
enable_mixed_chunk,
|
||||
enable_overlap,
|
||||
chunked_prefill_size,
|
||||
):
|
||||
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
|
||||
if disable_radix_cache:
|
||||
@@ -690,21 +692,8 @@ def run_mmlu_test(
|
||||
t = threading.Thread(target=read_output, args=(output_lines,))
|
||||
t.start()
|
||||
|
||||
# Run the eval
|
||||
args = SimpleNamespace(
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
eval_name="mmlu",
|
||||
num_examples=128,
|
||||
num_threads=128,
|
||||
)
|
||||
|
||||
try:
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
assert metrics["score"] >= 0.65
|
||||
finally:
|
||||
pass
|
||||
# Run the workload
|
||||
workload_func(base_url, model)
|
||||
|
||||
# Clean up everything
|
||||
kill_child_process(process.pid, include_self=True)
|
||||
@@ -727,4 +716,63 @@ def run_mmlu_test(
|
||||
has_leak = True
|
||||
|
||||
assert has_new_server
|
||||
# assert not has_leak
|
||||
assert not has_leak
|
||||
|
||||
|
||||
def run_mmlu_test(
|
||||
disable_radix_cache=False,
|
||||
enable_mixed_chunk=False,
|
||||
enable_overlap=False,
|
||||
chunked_prefill_size=32,
|
||||
):
|
||||
def workload_func(base_url, model):
|
||||
# Run the eval
|
||||
args = SimpleNamespace(
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
eval_name="mmlu",
|
||||
num_examples=128,
|
||||
num_threads=128,
|
||||
)
|
||||
|
||||
try:
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
assert metrics["score"] >= 0.65
|
||||
finally:
|
||||
pass
|
||||
|
||||
run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
|
||||
|
||||
|
||||
def run_mulit_request_test(
|
||||
disable_radix_cache=False,
|
||||
enable_mixed_chunk=False,
|
||||
enable_overlap=False,
|
||||
chunked_prefill_size=32,
|
||||
):
|
||||
|
||||
def workload_func(base_url, model):
|
||||
def run_one(_):
|
||||
prompt = """
|
||||
System: You are a helpful assistant.
|
||||
User: What is the capital of France?
|
||||
Assistant: The capital of France is
|
||||
"""
|
||||
|
||||
response = requests.post(
|
||||
f"{base_url}/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 8,
|
||||
},
|
||||
},
|
||||
)
|
||||
ret = response.json()
|
||||
|
||||
with ThreadPoolExecutor(2) as executor:
|
||||
list(executor.map(run_one, list(range(4))))
|
||||
|
||||
run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
|
||||
|
||||
Reference in New Issue
Block a user