diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py
index 9d5f99197..04169e808 100644
--- a/python/sglang/srt/managers/policy_scheduler.py
+++ b/python/sglang/srt/managers/policy_scheduler.py
@@ -111,11 +111,14 @@ class PrefillAdder:
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.rem_total_tokens = rem_total_tokens
-        self.rem_input_tokens = rem_input_tokens
+        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= mixed_with_decode_tokens
 
         self.can_run_list = []
         self.new_inflight_req = None
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index f6706781d..42c291bb1 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -329,6 +329,9 @@ class ScheduleBatch:
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
 
+    # For mixed chunked prefill
+    prefix_lens_cpu: List[int] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
@@ -462,9 +465,33 @@ class ScheduleBatch:
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
 
         self.batch_sampling_params(vocab_size)
 
+    def mix_with_running(self, running_batch: "ScheduleBatch"):
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
+        prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
+
+        for req in running_batch.reqs:
+            req.fill_ids = req.origin_input_ids + req.output_ids
+            req.extend_input_len = 1
+
+        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
+        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+        self.merge(running_batch)
+        self.input_ids = input_ids
+        self.out_cache_loc = out_cache_loc
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens_cpu = prefix_lens_cpu
+
     def check_decode_mem(self):
         bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 945a4c95e..b6cfa68bd 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -174,6 +174,9 @@ class ModelTpServer:
         # Chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )
 
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
@@ -366,11 +369,14 @@ class ModelTpServer:
         # Get priority queue
         prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
 
+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
         adder = PrefillAdder(
             self.tree_cache,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
+            num_mixed_running,
         )
 
         if self.running_batch is not None:
@@ -416,15 +422,27 @@ class ModelTpServer:
             )
         else:
             tree_cache_hit_rate = 0.0
-        logger.info(
-            f"[gpu={self.gpu_id}] Prefill batch. "
-            f"#new-seq: {len(can_run_list)}, "
-            f"#new-token: {adder.log_input_tokens}, "
-            f"#cached-token: {adder.log_hit_tokens}, "
-            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-            f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
-        )
+
+        if num_mixed_running > 0:
+            logger.info(
+                f"[gpu={self.gpu_id}] Prefill batch "
+                f"(mixed #running-req: {num_mixed_running}). "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
+        else:
+            logger.info(
+                f"[gpu={self.gpu_id}] Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_bs}, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
 
         # Return the new batch
         new_batch = ScheduleBatch.init_new(
@@ -440,6 +458,13 @@ class ModelTpServer:
         # Build batch tensors
         batch.prepare_for_extend(self.model_config.vocab_size)
 
+        decoding_reqs = []
+        if self.is_mixed_chunk and self.running_batch is not None:
+            self.running_batch.prepare_for_decode()
+            batch.mix_with_running(self.running_batch)
+            decoding_reqs = self.running_batch.reqs
+            self.running_batch = None
+
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
@@ -481,7 +506,8 @@ class ModelTpServer:
 
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                else:
+                elif req not in decoding_reqs:
+                    # To reduce overhead, only cache prefill reqs
                     self.tree_cache.cache_unfinished_req(req)
 
                 if req is self.current_inflight_req:
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index ce5ea25ea..3cf68eab2 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -88,11 +88,11 @@ class InputMetadata:
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
             (
-                (r.image_offset - len(r.prefix_indices))
+                (r.image_offset - batch.prefix_lens_cpu[i])
                 if r.image_offset is not None
                 else 0
             )
-            for r in reqs
+            for i, r in enumerate(reqs)
         ]
 
     def compute_positions(self, batch: ScheduleBatch):
@@ -109,8 +109,8 @@ class InputMetadata:
                 self.positions = torch.tensor(
                     np.concatenate(
                         [
-                            np.arange(len(req.prefix_indices), len(req.fill_ids))
-                            for req in batch.reqs
+                            np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                            for i, req in enumerate(batch.reqs)
                         ],
                         axis=0,
                     ),
@@ -123,7 +123,7 @@ class InputMetadata:
                     np.concatenate(
                         [
                             np.arange(
-                                len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                                batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                                 len(req.fill_ids) + position_ids_offsets_cpu[i],
                             )
                             for i, req in enumerate(batch.reqs)
@@ -141,12 +141,13 @@ class InputMetadata:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
         else:
             extend_lens_cpu = [
-                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+                len(r.fill_ids) - batch.prefix_lens_cpu[i]
+                for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
+            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
 
     @classmethod
     def from_schedule_batch(
@@ -180,14 +181,8 @@ class InputMetadata:
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)
 
-        prefix_lens = None
-        if forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(
-                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
-            )
-
         if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch, prefix_lens)
+            ret.init_triton_args(batch)
 
         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
@@ -198,30 +193,35 @@ class InputMetadata:
             ):
                 flashinfer_use_ragged = True
             ret.init_flashinfer_handlers(
-                model_runner, prefix_lens, flashinfer_use_ragged
+                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
             )
 
         return ret
 
-    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+    def init_triton_args(self, batch: ScheduleBatch):
         """Init auxiliary variables for triton attention backend."""
         self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
 
         if self.forward_mode == ForwardMode.DECODE:
            self.triton_max_extend_len = None
         else:
+            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-            extend_seq_lens = self.seq_lens - prefix_lens
+            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
             self.triton_max_extend_len = int(torch.max(extend_seq_lens))
 
     def init_flashinfer_handlers(
         self,
         model_runner,
-        prefix_lens,
+        prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
+        if self.forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+        else:
+            prefix_lens = None
+
         update_flashinfer_indices(
             self.forward_mode,
             model_runner,
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 4f06f7630..6bbf3050a 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -445,15 +445,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
         sys.exit(1)
 
-    # Print warnings here
-    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
-        logger.warning(
-            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
-            "This combination is an experimental feature and we noticed it can lead to "
-            "wrong generation results. If you want to use chunked prefill, it is recommended "
-            "not using `--disable-radix-cache`."
-        )
-
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 738ab7d1a..99ecff6a5 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -80,6 +80,7 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
+    enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False
@@ -396,6 +397,11 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--enable-mixed-chunk",
+            action="store_true",
+            help="Enable mixing prefill and decode in a chunked prefill batch.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
diff --git a/python/sglang/test/simple_eval_common.py b/python/sglang/test/simple_eval_common.py
index 4cfd3515f..d97d84de9 100644
--- a/python/sglang/test/simple_eval_common.py
+++ b/python/sglang/test/simple_eval_common.py
@@ -1,13 +1,12 @@
 # Adapted from https://github.com/openai/simple-evals/
 
-import base64
 import os
 import resource
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 from multiprocessing.pool import ThreadPool
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import httpx
 import jinja2
@@ -44,8 +43,8 @@ class EvalResult:
     Result of running an evaluation (usually consisting of many samples)
     """
 
-    score: float | None  # top-line metric
-    metrics: Dict[str, float] | None  # other metrics
+    score: Optional[float]  # top-line metric
+    metrics: Optional[Dict[str, float]]  # other metrics
     htmls: List[str]  # strings of valid HTML
     convos: List[MessageList]  # sampled conversations
@@ -56,10 +55,10 @@ class SingleEvalResult:
     Result of evaluating a single sample
     """
 
-    score: float | None
+    score: Optional[float]
     metrics: Dict[str, float] = field(default_factory=dict)
-    html: str | None = None
-    convo: MessageList | None = None  # sampled conversation
+    html: Optional[str] = None
+    convo: Optional[MessageList] = None  # sampled conversation
 
 
 class Eval:
@@ -89,8 +88,8 @@ class ChatCompletionSampler(SamplerBase):
     def __init__(
         self,
         base_url: str = None,
-        model: str | None = None,
-        system_message: str | None = None,
+        model: Optional[str] = None,
+        system_message: Optional[str] = None,
         temperature: float = 0.0,
         max_tokens: int = 2048,
     ):
@@ -272,7 +271,7 @@ def _compute_stat(values: list, stat: str):
 def aggregate_results(
     single_eval_results: List[SingleEvalResult],
     default_stats: Tuple[str] = ("mean", "std"),
-    name2stats: Dict[str, Tuple[str]] | None = None,
+    name2stats: Optional[Dict[str, Tuple[str]]] = None,
 ) -> EvalResult:
     """
     Aggregate results from multiple evaluations into a single EvalResult.
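Editor's aside, not part of the patch: the typing changes in the eval files above swap the PEP 604 `X | None` unions for `typing.Optional`, presumably so these helpers still import on Python versions older than 3.10, where the `|` operator on plain classes is not available at runtime. A minimal sketch of the difference (the class name below is hypothetical, not from the codebase):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ResultSketch:
    # Works on Python 3.8/3.9 as well as 3.10+.
    score: Optional[float]
    # The annotation below is evaluated at class-definition time and raises
    # "TypeError: unsupported operand type(s) for |" before Python 3.10:
    # score: float | None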
diff --git a/python/sglang/test/simple_eval_gpqa.py b/python/sglang/test/simple_eval_gpqa.py
index 46055caa5..ec2abb4ad 100644
--- a/python/sglang/test/simple_eval_gpqa.py
+++ b/python/sglang/test/simple_eval_gpqa.py
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2311.12022
 
 import random
 import re
+from typing import Optional
 
 import pandas
 
@@ -28,7 +29,7 @@ class GPQAEval(Eval):
     def __init__(
         self,
         filename: str,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
         n_repeats: int = 1,
     ):
diff --git a/python/sglang/test/simple_eval_humaneval.py b/python/sglang/test/simple_eval_humaneval.py
index efb0d0bd6..b0ad79d41 100644
--- a/python/sglang/test/simple_eval_humaneval.py
+++ b/python/sglang/test/simple_eval_humaneval.py
@@ -9,7 +9,7 @@ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
 
 import random
 import re
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import tqdm
@@ -61,7 +61,7 @@ def evaluate_functional_correctness(
 class HumanEval(Eval):
     def __init__(
         self,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
         num_samples_per_task: int = 5,
         ks_passes: List[int] = [1, 2, 5],
diff --git a/python/sglang/test/simple_eval_math.py b/python/sglang/test/simple_eval_math.py
index 4ddb650d9..74c49abe5 100644
--- a/python/sglang/test/simple_eval_math.py
+++ b/python/sglang/test/simple_eval_math.py
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2103.03874
 
 import random
 import re
+from typing import Optional
 
 import pandas
 
@@ -36,7 +37,7 @@ class MathEval(Eval):
         self,
         filename: str,
         equality_checker: SamplerBase,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
     ):
         df = pandas.read_csv(filename)
diff --git a/python/sglang/test/simple_eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py
index 3c0287510..36a5c7fe3 100644
--- a/python/sglang/test/simple_eval_mmlu.py
+++ b/python/sglang/test/simple_eval_mmlu.py
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2009.03300
 
 import random
 import re
+from typing import Optional
 
 import pandas
 
@@ -84,7 +85,7 @@ subject2category = {
 
 
 class MMLUEval(Eval):
-    def __init__(self, filename: str, num_examples: int | None, num_threads: int):
+    def __init__(self, filename: str, num_examples: Optional[int], num_threads: int):
         df = pandas.read_csv(filename)
         examples = [row.to_dict() for _, row in df.iterrows()]
         if num_examples:
diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py
index 94c424762..8d81dc0c3 100644
--- a/test/srt/test_chunked_prefill.py
+++ b/test/srt/test_chunked_prefill.py
@@ -11,11 +11,14 @@ from sglang.test.test_utils import (
 
 
 class TestChunkedPrefill(unittest.TestCase):
-    def run_mmlu(self, disable_radix_cache):
+    def run_mmlu(self, disable_radix_cache, enable_mixed_chunk):
         other_args = ["--chunked-prefill-size", "32"]
         if disable_radix_cache:
             other_args += ["--disable-radix-cache"]
+        if enable_mixed_chunk:
+            other_args += ["--enable-mixed-chunk"]
+
         model = DEFAULT_MODEL_NAME_FOR_TEST
         base_url = DEFAULT_URL_FOR_UNIT_TEST
         process = popen_launch_server(
@@ -40,10 +43,16 @@ class TestChunkedPrefill(unittest.TestCase):
         kill_child_process(process.pid)
 
     def test_chunked_prefill(self):
-        self.run_mmlu(disable_radix_cache=False)
+        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=False)
+
+    def test_mixed_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=True)
 
     def test_chunked_prefill_without_radix_cache(self):
-        self.run_mmlu(disable_radix_cache=True)
+        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=False)
+
+    def test_mixed_chunked_prefill_without_radix_cache(self):
+        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True)
 
 
 if __name__ == "__main__":
diff --git a/test/srt/test_eval_accuracy_large_chunked_prefill.py b/test/srt/test_eval_accuracy_large_chunked_prefill.py
index 040a2db75..bf4d071b8 100644
--- a/test/srt/test_eval_accuracy_large_chunked_prefill.py
+++ b/test/srt/test_eval_accuracy_large_chunked_prefill.py
@@ -6,7 +6,6 @@ from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_URL_FOR_ACCURACY_TEST,
-    DEFAULT_URL_FOR_UNIT_TEST,
     popen_launch_server,
 )
diff --git a/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py b/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py
new file mode 100644
index 000000000..b4d7602c4
--- /dev/null
+++ b/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py
@@ -0,0 +1,73 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_URL_FOR_ACCURACY_TEST,
+    popen_launch_server,
+)
+
+
+class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_ACCURACY_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=300,
+            other_args=[
+                "--log-level-http",
+                "warning",
+                "--chunked-prefill-size",
+                "256",
+                "--enable-mixed-chunk",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=3000,
+            num_threads=1024,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.71, f"{metrics}"
+
+    def test_human_eval(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="humaneval",
+            num_examples=None,
+            num_threads=1024,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.64, f"{metrics}"
+
+    def test_mgsm_en(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=None,
+            num_threads=1024,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.84, f"{metrics}"
+
+
+if __name__ == "__main__":
+    unittest.main()
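Editor's sketch of the scheduling idea behind this patch (illustrative only; `remaining_budgets` below is a hypothetical helper, not an SGLang API): when decode requests are mixed into a prefill chunk, each running request contributes exactly one new token to the batch, so `PrefillAdder` reserves that many tokens up front by shrinking its total, input, and chunk budgets, as done in `PrefillAdder.__init__` above.

from typing import Optional, Tuple


def remaining_budgets(
    rem_total_tokens: int,
    rem_input_tokens: int,
    rem_chunk_tokens: Optional[int],
    num_mixed_running: int,
) -> Tuple[int, int, Optional[int]]:
    """Reserve one token per running decode request before admitting prefills."""
    rem_total_tokens -= num_mixed_running
    rem_input_tokens -= num_mixed_running
    if rem_chunk_tokens is not None:
        rem_chunk_tokens -= num_mixed_running
    return rem_total_tokens, rem_input_tokens, rem_chunk_tokens


# Example: a 512-token chunk shared with 8 running decode requests leaves
# 504 tokens of chunk budget for newly admitted prefill input in this step.
print(remaining_budgets(100_000, 16_384, 512, 8))  # -> (99992, 16376, 504)

The reserved tokens are then consumed in `ScheduleBatch.mix_with_running`, which appends the running requests to the prefill batch with `extend_input_len = 1`, so each contributes exactly one token to the mixed extend batch.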