From 7d671e4ad2977d8090f44be5e94f351a15f4c9bf Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 19 Nov 2024 22:07:58 -0800
Subject: [PATCH] Enable overlap by default (#2067)

---
 python/sglang/bench_latency.py                 | 22 +++-----------
 .../srt/constrained/outlines_backend.py        |  5 +++-
 python/sglang/srt/managers/schedule_batch.py   |  5 +---
 python/sglang/srt/managers/scheduler.py        | 24 +++++++++++++--
 .../srt/managers/tp_worker_overlap_thread.py   | 17 +++++++----
 .../sglang/srt/model_executor/model_runner.py  |  4 +--
 .../srt/sampling/sampling_batch_info.py        |  9 ++----
 python/sglang/srt/server_args.py               | 30 ++++++++++++-------
 python/sglang/test/test_utils.py               | 10 +++----
 test/srt/run_suite.py                          |  2 +-
 test/srt/test_bench_serving.py                 |  4 +--
 test/srt/test_json_constrained.py              |  9 +++---
 test/srt/test_moe_eval_accuracy_large.py       |  2 +-
 ...edule.py => test_non_overlap_scheduler.py}  |  8 ++---
 test/srt/test_radix_attention.py               |  4 +--
 test/srt/test_torch_compile.py                 |  7 +++--
 test/srt/test_torch_compile_moe.py             |  5 ++--
 17 files changed, 92 insertions(+), 75 deletions(-)
 rename test/srt/{test_overlap_schedule.py => test_non_overlap_scheduler.py} (79%)

diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 28962eb9f..13bc113e6 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -220,7 +220,8 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs
 
 
-def _extend(reqs, model_runner):
+@torch.no_grad
+def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
@@ -236,15 +237,8 @@ def _extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch
 
 
-def extend(reqs, model_runner):
-    # Disable inference mode for now when torch TP is applied. We can remove
-    # this workaround once DTensor adds support for inference mode.
-    use_inf_mode = not model_runner.torch_tp_applied
-    with torch.inference_mode(use_inf_mode):
-        return _extend(reqs, model_runner)
-
-
-def _decode(input_token_ids, batch, model_runner):
+@torch.no_grad
+def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
@@ -254,14 +248,6 @@ def _decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
 
 
-def decode(input_token_ids, batch, model_runner):
-    # Disable inference mode for now when torch TP is applied. We can remove
-    # this workaround once DTensor adds support for inference mode.
-    use_inf_mode = not model_runner.torch_tp_applied
-    with torch.inference_mode(use_inf_mode):
-        return _decode(input_token_ids, batch, model_runner)
-
-
 def correctness_test(
     server_args,
     port_args,
diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py
index 831c1d1a9..801c4457f 100644
--- a/python/sglang/srt/constrained/outlines_backend.py
+++ b/python/sglang/srt/constrained/outlines_backend.py
@@ -87,9 +87,12 @@ class OutlinesGrammar(BaseGrammarObject):
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
 
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
         vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
-        vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
 
     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 615301154..ca08a6af3 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -899,10 +899,7 @@ class ScheduleBatch:
         self.input_ids = self.output_ids
         self.output_ids = None
 
-        if self.sampling_info.penalizer_orchestrator:
-            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                self.input_ids
-            )
+        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
 
         # Alloc mem
         bs = len(self.reqs)
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index e555c0d94..a411e7af7 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -30,7 +30,7 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -102,7 +102,7 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.enable_overlap_schedule
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
@@ -159,6 +159,23 @@ class Scheduler:
                 trust_remote_code=server_args.trust_remote_code,
             )
 
+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+        if (
+            server_args.attention_backend == "triton"
+            or server_args.enable_double_sparsity
+            or (
+                self.model_config.attention_arch == AttentionArch.MLA
+                and not self.server_args.disable_mla
+            )
+        ):
+            self.enable_overlap = False
+            logger.info(
+                "Overlap scheduler is disabled if using triton attention backend."
+ ) + # Launch a tensor parallel worker if self.enable_overlap: TpWorkerClass = TpModelWorkerClient @@ -903,6 +920,7 @@ class Scheduler: self.process_batch_result_prefill(batch, result) elif batch.forward_mode.is_dummy_first(): batch.next_batch_sampling_info.update_regex_vocab_mask() + torch.cuda.current_stream().synchronize() batch.next_batch_sampling_info.sampling_info_done.set() def process_batch_result_prefill(self, batch: ScheduleBatch, result): @@ -958,6 +976,7 @@ class Scheduler: if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() + torch.cuda.current_stream().synchronize() batch.next_batch_sampling_info.sampling_info_done.set() else: # embedding or reward model @@ -1031,6 +1050,7 @@ class Scheduler: if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() + torch.cuda.current_stream().synchronize() batch.next_batch_sampling_info.sampling_info_done.set() self.stream_output(batch.reqs) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 253900f35..5435e4bf9 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -157,14 +157,19 @@ class TpModelWorkerClient: def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): # A cuda stream sync here to avoid the cuda illegal memory access error. - _ = model_worker_batch.seq_lens[0].item() + torch.cuda.current_stream().synchronize() + + # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch. + sampling_info = model_worker_batch.sampling_info + sampling_info.update_penalties() + model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace( + sampling_info, + sampling_info_done=threading.Event(), + scaling_penalties=sampling_info.scaling_penalties, + linear_penalties=sampling_info.linear_penalties, + ) # Push a new batch to the queue - model_worker_batch.sampling_info = dataclasses.replace( - model_worker_batch.sampling_info, - sampling_info_done=threading.Event(), - ) - self.cur_sampling_info = model_worker_batch.sampling_info self.input_queue.put((model_worker_batch, self.future_token_ids_ct)) # Allocate output future objects diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index efd4fc214..3144efe84 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -116,7 +116,7 @@ class ModelRunner: ) if self.is_multimodal: - logger.warning( + logger.info( "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." ) server_args.chunked_prefill_size = None @@ -636,13 +636,11 @@ class ModelRunner: self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch ) -> torch.Tensor: sampling_info = forward_batch.sampling_info - if sampling_info.sampling_info_done: # Overlap mode: the function update_regex_vocab_mask was executed # in process_batch_result of the last batch. if sampling_info.grammars: sampling_info.sampling_info_done.wait() - sampling_info.update_penalties() else: # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass. 
             sampling_info.update_regex_vocab_mask()
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 6be15e6ac..3948ed069 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -132,9 +132,6 @@ class SamplingBatchInfo:
         return len(self.temperatures)
 
     def update_penalties(self):
-        if not self.penalizer_orchestrator:
-            return
-
         self.scaling_penalties = None
         self.linear_penalties = None
 
@@ -176,8 +173,7 @@ class SamplingBatchInfo:
                 grammar.fill_vocab_mask(self.vocab_mask, i)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        if self.penalizer_orchestrator:
-            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
         for item in [
             "temperatures",
@@ -216,8 +212,7 @@ class SamplingBatchInfo:
         return None
 
     def merge_batch(self, other: "SamplingBatchInfo"):
-        if self.penalizer_orchestrator:
-            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
         for item in [
             "temperatures",
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index de1d4ee68..e1cbbd29f 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -123,7 +123,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
-    enable_overlap_schedule: bool = False
+    disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
     enable_torch_compile: bool = False
@@ -172,9 +172,7 @@ class ServerArgs:
         if gpu_mem < 25000:
             self.chunked_prefill_size //= 4  # make it 2048
             self.cuda_graph_max_bs = 4
-            logger.warning(
-                "Automatically adjust --chunked-prefill-size for small GPUs."
-            )
+            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
 
         if not is_flashinfer_available():
             self.attention_backend = "triton"
@@ -192,15 +190,22 @@ class ServerArgs:
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            self.enable_overlap_schedule = False
-            logger.warning(
+            self.disable_overlap_schedule = True
+            logger.info(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size."
+                "Data parallel size is adjusted to be the same as tensor parallel size. "
+                "Overlap schedule is disabled."
             )
 
-        if self.enable_overlap_schedule:
+        if self.enable_mixed_chunk:
+            logger.info(
+                "Overlap schedule is disabled because mixed-style chunked prefill is enabled."
+            )
+            self.disable_overlap_schedule = True
+
+        if not self.disable_overlap_schedule:
             self.disable_jump_forward = True
 
     @staticmethod
@@ -624,9 +629,9 @@ class ServerArgs:
             help="Disable the NaN detection for better performance.",
         )
         parser.add_argument(
-            "--enable-overlap-schedule",
+            "--disable-overlap-schedule",
             action="store_true",
-            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
         )
         parser.add_argument(
             "--enable-mixed-chunk",
@@ -692,6 +697,11 @@ class ServerArgs:
         )
 
         # Deprecated arguments
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action=DeprecatedAction,
+            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
+        )
         parser.add_argument(
             "--disable-flashinfer",
             action=DeprecatedAction,
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 12cbbd883..07f666e30 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -670,7 +670,7 @@ def run_and_check_memory_leak(
     workload_func,
     disable_radix_cache,
    enable_mixed_chunk,
-    enable_overlap,
+    disable_overlap,
     chunked_prefill_size,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
@@ -678,8 +678,8 @@ def run_and_check_memory_leak(
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
         other_args += ["--enable-mixed-chunk"]
-    if enable_overlap:
-        other_args += ["--enable-overlap-schedule"]
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]
 
     model = DEFAULT_MODEL_NAME_FOR_TEST
     port = random.randint(4000, 5000)
@@ -731,7 +731,7 @@ def run_and_check_memory_leak(
 def run_mmlu_test(
     disable_radix_cache=False,
     enable_mixed_chunk=False,
-    enable_overlap=False,
+    disable_overlap=False,
     chunked_prefill_size=32,
 ):
     def workload_func(base_url, model):
@@ -754,7 +754,7 @@ def run_mmlu_test(
         workload_func,
         disable_radix_cache,
         enable_mixed_chunk,
-        enable_overlap,
+        disable_overlap,
         chunked_prefill_size,
     )
 
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 7f343a15a..b857aec51 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -17,8 +17,8 @@ suites = {
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_metrics.py",
+        "test_non_overlap_scheduler.py",
         "test_openai_server.py",
-        "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
         "test_retract_decode.py",
diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py
index ff4758633..4584a86e0 100644
--- a/test/srt/test_bench_serving.py
+++ b/test/srt/test_bench_serving.py
@@ -97,8 +97,8 @@ class TestBenchServing(unittest.TestCase):
 
         if is_in_ci():
             self.assertLess(res["median_e2e_latency_ms"], 12000)
-            self.assertLess(res["median_ttft_ms"], 80)
-            self.assertLess(res["median_itl_ms"], 11)
+            self.assertLess(res["median_ttft_ms"], 86)
+            self.assertLess(res["median_itl_ms"], 10)
 
     def test_moe_offline_throughput_default(self):
         res = run_bench_serving(
diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py
index 41d9b0c90..2d08d6684 100644
--- a/test/srt/test_json_constrained.py
+++ b/test/srt/test_json_constrained.py
@@ -78,10 +78,11 @@ class TestJSONConstrained(unittest.TestCase):
         self.assertIsInstance(js_obj["population"], int)
 
         # Make sure jump forward is triggered
-        self.assertGreater(
-            ret["meta_info"]["completion_tokens"],
-            ret["meta_info"]["completion_tokens_wo_jump_forward"],
-        )
+        # NOTE: This is skipped because overlap scheduler does not support jump forward
+        # self.assertGreater(
+        #     ret["meta_info"]["completion_tokens"],
+        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
+        # )
 
     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)
diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py
index 993e85a81..9880a8162 100644
--- a/test/srt/test_moe_eval_accuracy_large.py
+++ b/test/srt/test_moe_eval_accuracy_large.py
@@ -59,7 +59,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        self.assertGreater(metrics["score"], 0.41)
+        self.assertGreater(metrics["score"], 0.40)
 
     def test_mgsm_en(self):
         args = SimpleNamespace(
diff --git a/test/srt/test_overlap_schedule.py b/test/srt/test_non_overlap_scheduler.py
similarity index 79%
rename from test/srt/test_overlap_schedule.py
rename to test/srt/test_non_overlap_scheduler.py
index 367d2acc8..341207148 100644
--- a/test/srt/test_overlap_schedule.py
+++ b/test/srt/test_non_overlap_scheduler.py
@@ -12,22 +12,22 @@ from sglang.test.test_utils import run_mmlu_test
 
 
 class TestOverlapSchedule(unittest.TestCase):
     def test_no_radix_attention_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=True, chunked_prefill_size=32, enable_overlap=True
+            disable_radix_cache=True, chunked_prefill_size=32, disable_overlap=True
         )
 
     def test_no_radix_attention_no_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=True, chunked_prefill_size=-1, enable_overlap=True
+            disable_radix_cache=True, chunked_prefill_size=-1, disable_overlap=True
         )
 
     def test_radix_attention_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=False, chunked_prefill_size=32, enable_overlap=True
+            disable_radix_cache=False, chunked_prefill_size=32, disable_overlap=True
         )
 
     def test_radix_attention_no_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=False, chunked_prefill_size=-1, enable_overlap=True
+            disable_radix_cache=False, chunked_prefill_size=-1, disable_overlap=True
         )
diff --git a/test/srt/test_radix_attention.py b/test/srt/test_radix_attention.py
index f9da49a1d..cdba7573d 100644
--- a/test/srt/test_radix_attention.py
+++ b/test/srt/test_radix_attention.py
@@ -107,7 +107,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
         )
 
 
-class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
+class TestRadixCacheNonOverlapLPM(TestRadixCacheFCFS):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -117,7 +117,7 @@ class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
-                "--enable-overlap-schedule",
+                "--disable-overlap-schedule",
                 "--chunked-prefill-size",
                 "128",
                 "--max-total-tokens",
diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py
index ddb92a57f..76945f963 100644
--- a/test/srt/test_torch_compile.py
+++ b/test/srt/test_torch_compile.py
@@ -1,3 +1,4 @@
+import time
 import unittest
 from types import SimpleNamespace
 
@@ -56,14 +57,14 @@ class TestTorchCompile(unittest.TestCase):
         return response.json()
 
     def test_throughput(self):
-        import time
+        # Warmup
+        res = self.run_decode(16)
 
         max_tokens = 256
-
         tic = time.time()
         res = self.run_decode(max_tokens)
         tok = time.time()
-        print(res["text"])
+        print(f"{res=}")
         throughput = max_tokens / (tok - tic)
         print(f"Throughput: {throughput} tokens/s")
         self.assertGreaterEqual(throughput, 152)
diff --git a/test/srt/test_torch_compile_moe.py b/test/srt/test_torch_compile_moe.py
index d19ab2bbd..e744e6686 100644
--- a/test/srt/test_torch_compile_moe.py
+++ b/test/srt/test_torch_compile_moe.py
@@ -1,3 +1,4 @@
+import time
 import unittest
 from types import SimpleNamespace
 
@@ -56,10 +57,10 @@ class TestTorchCompile(unittest.TestCase):
         return response.json()
 
     def test_throughput(self):
-        import time
+        # Warmup
+        res = self.run_decode(16)
 
         max_tokens = 256
-
         tic = time.time()
         res = self.run_decode(max_tokens)
         tok = time.time()