Add generator-style run_batch function (#2513)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -96,6 +96,7 @@ def run_program_batch(
|
|||||||
default_sampling_para,
|
default_sampling_para,
|
||||||
num_threads,
|
num_threads,
|
||||||
progress_bar,
|
progress_bar,
|
||||||
|
generator_style=False,
|
||||||
):
|
):
|
||||||
if hasattr(backend, "endpoint"):
|
if hasattr(backend, "endpoint"):
|
||||||
backend = backend.endpoint
|
backend = backend.endpoint
|
||||||
@@ -109,6 +110,17 @@ def run_program_batch(
|
|||||||
num_threads = max(96, multiprocessing.cpu_count() * 16)
|
num_threads = max(96, multiprocessing.cpu_count() * 16)
|
||||||
num_threads = min(num_threads, len(batch_arguments))
|
num_threads = min(num_threads, len(batch_arguments))
|
||||||
|
|
||||||
|
if generator_style:
|
||||||
|
return _run_program_batch_generator(
|
||||||
|
program,
|
||||||
|
backend,
|
||||||
|
batch_arguments,
|
||||||
|
default_sampling_para,
|
||||||
|
num_threads,
|
||||||
|
progress_bar,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Original code path when generator_style=False
|
||||||
if num_threads == 1:
|
if num_threads == 1:
|
||||||
rets = []
|
rets = []
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
@@ -168,6 +180,64 @@ def run_program_batch(
|
|||||||
return rets
|
return rets
|
||||||
|
|
||||||
|
|
||||||
|
def _run_program_batch_generator(
|
||||||
|
program,
|
||||||
|
backend,
|
||||||
|
batch_arguments,
|
||||||
|
default_sampling_para,
|
||||||
|
num_threads,
|
||||||
|
progress_bar,
|
||||||
|
):
|
||||||
|
"""Helper function that yields results one by one using chunking to avoid overwhelming ThreadPoolExecutor."""
|
||||||
|
if num_threads == 1:
|
||||||
|
iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
|
||||||
|
for arguments in iterator:
|
||||||
|
yield run_program(
|
||||||
|
program,
|
||||||
|
backend,
|
||||||
|
(),
|
||||||
|
arguments,
|
||||||
|
default_sampling_para,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None
|
||||||
|
|
||||||
|
# Process in chunks to avoid overwhelming ThreadPoolExecutor
|
||||||
|
# Otherwise, ThreadPoolExecutor.submit will block after adding certain number of tasks
|
||||||
|
# so we will never reach "yield" until all tasks are done
|
||||||
|
chunk_size = 200
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(num_threads) as executor:
|
||||||
|
for chunk_start in range(0, len(batch_arguments), chunk_size):
|
||||||
|
chunk_end = min(chunk_start + chunk_size, len(batch_arguments))
|
||||||
|
chunk_futures = []
|
||||||
|
|
||||||
|
# Submit chunk of tasks
|
||||||
|
for i in range(chunk_start, chunk_end):
|
||||||
|
future = executor.submit(
|
||||||
|
run_program,
|
||||||
|
program,
|
||||||
|
backend,
|
||||||
|
(),
|
||||||
|
batch_arguments[i],
|
||||||
|
default_sampling_para,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
if pbar:
|
||||||
|
future.add_done_callback(lambda _: pbar.update())
|
||||||
|
chunk_futures.append(future)
|
||||||
|
|
||||||
|
# Yield results from this chunk as they complete
|
||||||
|
for future in chunk_futures:
|
||||||
|
yield future.result()
|
||||||
|
|
||||||
|
if pbar:
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
|
||||||
def cache_program(program, backend):
|
def cache_program(program, backend):
|
||||||
from sglang.lang.tracer import extract_prefix_by_tracing
|
from sglang.lang.tracer import extract_prefix_by_tracing
|
||||||
|
|
||||||
|
|||||||
@@ -227,6 +227,7 @@ class SglFunction:
|
|||||||
backend=None,
|
backend=None,
|
||||||
num_threads: Union[str, int] = "auto",
|
num_threads: Union[str, int] = "auto",
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
|
generator_style: bool = False,
|
||||||
):
|
):
|
||||||
from sglang.lang.interpreter import run_program_batch
|
from sglang.lang.interpreter import run_program_batch
|
||||||
|
|
||||||
@@ -277,6 +278,7 @@ class SglFunction:
|
|||||||
default_sampling_para,
|
default_sampling_para,
|
||||||
num_threads,
|
num_threads,
|
||||||
progress_bar,
|
progress_bar,
|
||||||
|
generator_style=generator_style,
|
||||||
)
|
)
|
||||||
|
|
||||||
def trace(self, *, backend=None, **kwargs):
|
def trace(self, *, backend=None, **kwargs):
|
||||||
|
|||||||
@@ -509,13 +509,35 @@ def test_hellaswag_select():
|
|||||||
temperature=0,
|
temperature=0,
|
||||||
num_threads=64,
|
num_threads=64,
|
||||||
progress_bar=True,
|
progress_bar=True,
|
||||||
|
generator_style=False,
|
||||||
)
|
)
|
||||||
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
|
preds = []
|
||||||
|
for i, ret in enumerate(rets):
|
||||||
|
preds.append(choices[i].index(ret["answer"]))
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
# Compute accuracy
|
# Compute accuracy
|
||||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||||
|
|
||||||
|
# Test generator style of run_batch
|
||||||
|
tic = time.time()
|
||||||
|
rets = few_shot_hellaswag.run_batch(
|
||||||
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
num_threads=64,
|
||||||
|
progress_bar=True,
|
||||||
|
generator_style=True,
|
||||||
|
)
|
||||||
|
preds_gen = []
|
||||||
|
for i, ret in enumerate(rets):
|
||||||
|
preds_gen.append(choices[i].index(ret["answer"]))
|
||||||
|
latency_gen = time.time() - tic
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
|
||||||
|
assert np.abs(accuracy_gen - accuracy) < 0.01
|
||||||
|
assert np.abs(latency_gen - latency) < 1
|
||||||
|
|
||||||
return accuracy, latency
|
return accuracy, latency
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user