Add generator-style run_batch function (#2513)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -96,6 +96,7 @@ def run_program_batch(
|
||||
default_sampling_para,
|
||||
num_threads,
|
||||
progress_bar,
|
||||
generator_style=False,
|
||||
):
|
||||
if hasattr(backend, "endpoint"):
|
||||
backend = backend.endpoint
|
||||
@@ -109,6 +110,17 @@ def run_program_batch(
|
||||
num_threads = max(96, multiprocessing.cpu_count() * 16)
|
||||
num_threads = min(num_threads, len(batch_arguments))
|
||||
|
||||
if generator_style:
|
||||
return _run_program_batch_generator(
|
||||
program,
|
||||
backend,
|
||||
batch_arguments,
|
||||
default_sampling_para,
|
||||
num_threads,
|
||||
progress_bar,
|
||||
)
|
||||
|
||||
# Original code path when generator_style=False
|
||||
if num_threads == 1:
|
||||
rets = []
|
||||
if progress_bar:
|
||||
@@ -168,6 +180,64 @@ def run_program_batch(
|
||||
return rets
|
||||
|
||||
|
||||
def _run_program_batch_generator(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
):
    """Yield the result of each entry in ``batch_arguments`` one by one.

    Helper for ``run_program_batch`` when ``generator_style=True``. Work is
    submitted to the thread pool in fixed-size chunks: submitting everything
    at once can block in ``ThreadPoolExecutor.submit`` after a certain number
    of queued tasks, so the first ``yield`` would not happen until all tasks
    were done.

    Args:
        program: The program object to execute for each argument set.
        backend: Backend (endpoint) used to execute the program.
        batch_arguments: Sequence of per-call argument dicts, one per run.
        default_sampling_para: Default sampling parameters passed to each run.
        num_threads: Number of worker threads; ``1`` runs synchronously
            without a thread pool.
        progress_bar: Whether to display a tqdm progress bar.

    Yields:
        The ``run_program`` result for each element of ``batch_arguments``,
        in input order.
    """
    if num_threads == 1:
        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
        for arguments in iterator:
            yield run_program(
                program,
                backend,
                (),
                arguments,
                default_sampling_para,
                False,
                True,
            )
    else:
        pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None

        # Process in chunks to avoid overwhelming ThreadPoolExecutor.
        # Otherwise, ThreadPoolExecutor.submit will block after adding a
        # certain number of tasks, so we would never reach "yield" until
        # all tasks are done.
        chunk_size = 200

        try:
            with ThreadPoolExecutor(num_threads) as executor:
                for chunk_start in range(0, len(batch_arguments), chunk_size):
                    chunk_end = min(chunk_start + chunk_size, len(batch_arguments))
                    chunk_futures = []

                    # Submit one chunk of tasks.
                    for i in range(chunk_start, chunk_end):
                        future = executor.submit(
                            run_program,
                            program,
                            backend,
                            (),
                            batch_arguments[i],
                            default_sampling_para,
                            False,
                            True,
                        )
                        if pbar:
                            future.add_done_callback(lambda _: pbar.update())
                        chunk_futures.append(future)

                    # Yield results from this chunk as they complete,
                    # preserving submission order.
                    for future in chunk_futures:
                        yield future.result()
        finally:
            # Close the bar even if a worker raised or the consumer stopped
            # iterating the generator early; otherwise the tqdm bar is
            # leaked and the terminal is left in a dirty state.
            if pbar:
                pbar.close()
|
||||
|
||||
|
||||
def cache_program(program, backend):
|
||||
from sglang.lang.tracer import extract_prefix_by_tracing
|
||||
|
||||
|
||||
@@ -227,6 +227,7 @@ class SglFunction:
|
||||
backend=None,
|
||||
num_threads: Union[str, int] = "auto",
|
||||
progress_bar: bool = False,
|
||||
generator_style: bool = False,
|
||||
):
|
||||
from sglang.lang.interpreter import run_program_batch
|
||||
|
||||
@@ -277,6 +278,7 @@ class SglFunction:
|
||||
default_sampling_para,
|
||||
num_threads,
|
||||
progress_bar,
|
||||
generator_style=generator_style,
|
||||
)
|
||||
|
||||
def trace(self, *, backend=None, **kwargs):
|
||||
|
||||
@@ -509,13 +509,35 @@ def test_hellaswag_select():
|
||||
temperature=0,
|
||||
num_threads=64,
|
||||
progress_bar=True,
|
||||
generator_style=False,
|
||||
)
|
||||
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
|
||||
preds = []
|
||||
for i, ret in enumerate(rets):
|
||||
preds.append(choices[i].index(ret["answer"]))
|
||||
latency = time.time() - tic
|
||||
|
||||
# Compute accuracy
|
||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||
|
||||
# Test generator style of run_batch
|
||||
tic = time.time()
|
||||
rets = few_shot_hellaswag.run_batch(
|
||||
arguments,
|
||||
temperature=0,
|
||||
num_threads=64,
|
||||
progress_bar=True,
|
||||
generator_style=True,
|
||||
)
|
||||
preds_gen = []
|
||||
for i, ret in enumerate(rets):
|
||||
preds_gen.append(choices[i].index(ret["answer"]))
|
||||
latency_gen = time.time() - tic
|
||||
|
||||
# Compute accuracy
|
||||
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
|
||||
assert np.abs(accuracy_gen - accuracy) < 0.01
|
||||
assert np.abs(latency_gen - latency) < 1
|
||||
|
||||
return accuracy, latency
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user