From 1acbaf1b5aed65f8232a689042801c569d6e2661 Mon Sep 17 00:00:00 2001
From: Xingyao Wang
Date: Mon, 6 Jan 2025 18:04:55 -0500
Subject: [PATCH] Add generator-style run_batch function (#2513)

Co-authored-by: openhands
---
 python/sglang/lang/interpreter.py   | 70 +++++++++++++++++++++++++++++
 python/sglang/lang/ir.py            |  2 +
 python/sglang/test/test_programs.py | 24 +++++++++-
 3 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index 55a20336b..6d1ca71ad 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -96,6 +96,7 @@ def run_program_batch(
     default_sampling_para,
     num_threads,
     progress_bar,
+    generator_style=False,
 ):
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
@@ -109,6 +110,17 @@ def run_program_batch(
             num_threads = max(96, multiprocessing.cpu_count() * 16)
     num_threads = min(num_threads, len(batch_arguments))
 
+    if generator_style:
+        return _run_program_batch_generator(
+            program,
+            backend,
+            batch_arguments,
+            default_sampling_para,
+            num_threads,
+            progress_bar,
+        )
+
+    # Original code path when generator_style=False
     if num_threads == 1:
         rets = []
         if progress_bar:
@@ -168,6 +180,64 @@ def run_program_batch(
     return rets
 
 
+def _run_program_batch_generator(
+    program,
+    backend,
+    batch_arguments,
+    default_sampling_para,
+    num_threads,
+    progress_bar,
+):
+    """Helper function that yields results one by one using chunking to avoid overwhelming ThreadPoolExecutor."""
+    if num_threads == 1:
+        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
+        for arguments in iterator:
+            yield run_program(
+                program,
+                backend,
+                (),
+                arguments,
+                default_sampling_para,
+                False,
+                True,
+            )
+    else:
+        pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None
+
+        # Process in chunks to avoid overwhelming ThreadPoolExecutor
+        # Otherwise, ThreadPoolExecutor.submit will block after adding certain number of tasks
+        # so we will never reach "yield" until all tasks are done
+        chunk_size = 200
+
+        with ThreadPoolExecutor(num_threads) as executor:
+            for chunk_start in range(0, len(batch_arguments), chunk_size):
+                chunk_end = min(chunk_start + chunk_size, len(batch_arguments))
+                chunk_futures = []
+
+                # Submit chunk of tasks
+                for i in range(chunk_start, chunk_end):
+                    future = executor.submit(
+                        run_program,
+                        program,
+                        backend,
+                        (),
+                        batch_arguments[i],
+                        default_sampling_para,
+                        False,
+                        True,
+                    )
+                    if pbar:
+                        future.add_done_callback(lambda _: pbar.update())
+                    chunk_futures.append(future)
+
+                # Yield results from this chunk as they complete
+                for future in chunk_futures:
+                    yield future.result()
+
+        if pbar:
+            pbar.close()
+
+
 def cache_program(program, backend):
     from sglang.lang.tracer import extract_prefix_by_tracing
 
diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py
index d3c010108..1ae5ac106 100644
--- a/python/sglang/lang/ir.py
+++ b/python/sglang/lang/ir.py
@@ -227,6 +227,7 @@ class SglFunction:
         backend=None,
         num_threads: Union[str, int] = "auto",
         progress_bar: bool = False,
+        generator_style: bool = False,
     ):
         from sglang.lang.interpreter import run_program_batch
 
@@ -277,6 +278,7 @@ class SglFunction:
             default_sampling_para,
             num_threads,
             progress_bar,
+            generator_style=generator_style,
         )
 
     def trace(self, *, backend=None, **kwargs):
diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py
index a251e0aca..411a20b92 100644
--- a/python/sglang/test/test_programs.py
+++ b/python/sglang/test/test_programs.py
@@ -509,13 +509,35 @@ def test_hellaswag_select():
         temperature=0,
         num_threads=64,
         progress_bar=True,
+        generator_style=False,
     )
-    preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
+    preds = []
+    for i, ret in enumerate(rets):
+        preds.append(choices[i].index(ret["answer"]))
     latency = time.time() - tic
 
     # Compute accuracy
     accuracy = np.mean(np.array(preds) == np.array(labels))
 
+    # Test generator style of run_batch
+    tic = time.time()
+    rets = few_shot_hellaswag.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=64,
+        progress_bar=True,
+        generator_style=True,
+    )
+    preds_gen = []
+    for i, ret in enumerate(rets):
+        preds_gen.append(choices[i].index(ret["answer"]))
+    latency_gen = time.time() - tic
+
+    # Compute accuracy
+    accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
+    assert np.abs(accuracy_gen - accuracy) < 0.01
+    assert np.abs(latency_gen - latency) < 1
+
     return accuracy, latency