From fba8eccd7ebe41bbdbf70ab3b6a2df1835f8b532 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 12 May 2025 00:17:33 -0700 Subject: [PATCH] Log if cuda graph is used & extend cuda graph capture to cuda-graph-max-bs (#6201) Co-authored-by: SangBin Cho --- python/sglang/bench_offline_throughput.py | 4 +- python/sglang/bench_one_batch.py | 4 +- python/sglang/bench_one_batch_server.py | 160 ++++++++++++++++-- python/sglang/bench_serving.py | 12 +- .../srt/constrained/base_grammar_backend.py | 6 + python/sglang/srt/disaggregation/prefill.py | 4 +- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/entrypoints/http_server.py | 2 +- python/sglang/srt/layers/attention/utils.py | 6 +- python/sglang/srt/managers/scheduler.py | 28 +-- .../scheduler_output_processor_mixin.py | 18 +- .../sglang/srt/managers/tokenizer_manager.py | 7 +- python/sglang/srt/managers/tp_worker.py | 21 ++- .../srt/managers/tp_worker_overlap_thread.py | 24 ++- .../srt/model_executor/cuda_graph_runner.py | 17 +- .../sglang/srt/model_executor/model_runner.py | 15 +- python/sglang/srt/server_args.py | 2 +- python/sglang/srt/speculative/eagle_worker.py | 27 +-- python/sglang/test/test_utils.py | 14 +- test/srt/run_suite.py | 2 +- test/srt/test_eagle_infer.py | 8 +- test/srt/test_fa3.py | 4 +- test/srt/test_full_deepseek_v3.py | 4 +- test/srt/test_mla_deepseek_v3.py | 8 +- test/srt/test_mla_flashinfer.py | 4 +- test/srt/test_mla_int8_deepseek_v3.py | 8 +- test/srt/test_srt_endpoint.py | 3 - 27 files changed, 293 insertions(+), 121 deletions(-) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index d88d535fc..5827e83b9 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -259,7 +259,9 @@ def throughput_test_once( measurement_results["total_input_tokens"] + measurement_results["total_output_tokens"] ) / latency - measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"] + measurement_results["last_gen_throughput"] = server_info["internal_states"][0][ + "last_gen_throughput" + ] return measurement_results diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 09da170a8..f8c67c8f4 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -246,7 +246,7 @@ def extend(reqs, model_runner): _maybe_prepare_dp_attn_batch(batch, model_runner) model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) - logits_output = model_runner.forward(forward_batch) + logits_output, _ = model_runner.forward(forward_batch) next_token_ids = model_runner.sample(logits_output, forward_batch) return next_token_ids, logits_output.next_token_logits, batch @@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner): _maybe_prepare_dp_attn_batch(batch, model_runner) model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) - logits_output = model_runner.forward(forward_batch) + logits_output, _ = model_runner.forward(forward_batch) next_token_ids = model_runner.sample(logits_output, forward_batch) return next_token_ids, logits_output.next_token_logits diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py index 1fc4ff58d..73ee8dc9f 100644 --- a/python/sglang/bench_one_batch_server.py +++ b/python/sglang/bench_one_batch_server.py @@ -25,6 +25,7 @@ import requests from 
sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import is_in_ci, write_github_step_summary @dataclasses.dataclass @@ -33,9 +34,13 @@ class BenchArgs: batch_size: Tuple[int] = (1,) input_len: Tuple[int] = (1024,) output_len: Tuple[int] = (16,) + temperature: float = 0.0 + return_logprob: bool = False + input_len_step_percentage: float = 0.0 result_filename: str = "result.jsonl" base_url: str = "" skip_warmup: bool = False + show_report: bool = False @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -49,11 +54,19 @@ class BenchArgs: parser.add_argument( "--output-len", type=int, nargs="+", default=BenchArgs.output_len ) + parser.add_argument("--temperature", type=float, default=BenchArgs.temperature) + parser.add_argument("--return-logprob", action="store_true") + parser.add_argument( + "--input-len-step-percentage", + type=float, + default=BenchArgs.input_len_step_percentage, + ) parser.add_argument( "--result-filename", type=str, default=BenchArgs.result_filename ) parser.add_argument("--base-url", type=str, default=BenchArgs.base_url) parser.add_argument("--skip-warmup", action="store_true") + parser.add_argument("--show-report", action="store_true") @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -99,36 +112,89 @@ def run_one_case( batch_size: int, input_len: int, output_len: int, + temperature: float, + return_logprob: bool, + input_len_step_percentage: float, run_name: str, result_filename: str, ): - input_ids = [ - [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))] - for _ in range(batch_size) + requests.post(url + "/flush_cache") + input_lens = [ + int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage)) + for i in range(batch_size) ] + input_ids = [ + [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))] + for i in range(batch_size) + ] + + use_structured_outputs = False + if use_structured_outputs: + texts = [] + for _ in range(batch_size): + texts.append( + "Human: What is the capital city of France? Can you give as much trivia as possible about that city? Answer in JSON.\n" * 50 + "Assistant:" ) + json_schema = "$$ANY$$" + else: + json_schema = None tic = time.time() response = requests.post( url + "/generate", json={ + # "text": texts, "input_ids": input_ids, "sampling_params": { - "temperature": 0, + "temperature": temperature, "max_new_tokens": output_len, "ignore_eos": True, + "json_schema": json_schema, }, + "return_logprob": return_logprob, + "stream": True, }, + stream=True, ) - latency = time.time() - tic - _ = response.json() - output_throughput = batch_size * output_len / latency + # The TTFT of the last request in the batch + ttft = 0.0 + for chunk in response.iter_lines(decode_unicode=False): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]": + break + data = json.loads(chunk[5:].strip("\n")) + if "error" in data: + raise RuntimeError(f"Request has failed. 
{data}.") + + assert ( + data["meta_info"]["finish_reason"] is None + or data["meta_info"]["finish_reason"]["type"] == "length" + ) + if data["meta_info"]["completion_tokens"] == 1: + ttft = time.time() - tic + + latency = time.time() - tic + input_throughput = batch_size * input_len / ttft + output_throughput = batch_size * output_len / (latency - ttft) overall_throughput = batch_size * (input_len + output_len) / latency + server_info = requests.get(url + "/get_server_info").json() + acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None) + last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"] + print(f"batch size: {batch_size}") + print(f"input_len: {input_len}") + print(f"output_len: {output_len}") print(f"latency: {latency:.2f} s") - print(f"output throughput: {output_throughput:.2f} token/s") - print(f"(input + output) throughput: {overall_throughput:.2f} token/s") + print(f"ttft: {ttft:.2f} s") + print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s") + print(f"Input throughput: {input_throughput:.2f} tok/s") + if output_len != 1: + print(f"output throughput: {output_throughput:.2f} tok/s") if result_filename: with open(result_filename, "a") as fout: @@ -140,9 +206,21 @@ def run_one_case( "latency": round(latency, 4), "output_throughput": round(output_throughput, 2), "overall_throughput": round(overall_throughput, 2), + "last_gen_throughput": round(last_gen_throughput, 2), } fout.write(json.dumps(res) + "\n") + return ( + batch_size, + latency, + ttft, + input_throughput, + output_throughput, + overall_throughput, + last_gen_throughput, + acc_length, + ) + def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): if bench_args.base_url: @@ -152,27 +230,38 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): # warmup if not bench_args.skip_warmup: + print("=" * 8 + " Warmup Begin " + "=" * 8) run_one_case( base_url, batch_size=16, input_len=1024, output_len=16, + temperature=bench_args.temperature, + return_logprob=bench_args.return_logprob, + input_len_step_percentage=bench_args.input_len_step_percentage, run_name="", result_filename="", ) + print("=" * 8 + " Warmup End " + "=" * 8 + "\n") # benchmark + result = [] try: for bs, il, ol in itertools.product( bench_args.batch_size, bench_args.input_len, bench_args.output_len ): - run_one_case( - base_url, - bs, - il, - ol, - bench_args.run_name, - bench_args.result_filename, + result.append( + run_one_case( + base_url, + bs, + il, + ol, + temperature=bench_args.temperature, + return_logprob=bench_args.return_logprob, + input_len_step_percentage=bench_args.input_len_step_percentage, + run_name=bench_args.run_name, + result_filename=bench_args.result_filename, + ) ) finally: if proc: @@ -180,6 +269,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): print(f"\nResults are saved to {bench_args.result_filename}") + if not bench_args.show_report: + return + + summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n" + summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n" + + for ( + batch_size, + latency, + ttft, + input_throughput, + output_throughput, + overall_throughput, + last_gen_throughput, + acc_length, + ) in result: + hourly_cost = 2 * server_args.tp_size # $2/hour for one H100 + input_util = 0.7 + 
accept_length = round(acc_length, 2) if acc_length is not None else "n/a" + line = ( + f"| {batch_size} | " + f"{latency:.2f} | " + f"{input_throughput:.2f} | " + f"{output_throughput:.2f} | " + f"{accept_length} | " + f"{1 / (output_throughput/batch_size) * 1000:.2f} | " + f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | " + f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n" + ) + summary += line + + # print metrics table + print(summary) + + if is_in_ci(): + write_github_step_summary( + f"### Test Nightly Benchmark (bench_one_batch) \n{summary}" + ) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index d6133d437..84d88e136 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -1103,7 +1103,7 @@ async def benchmark( lora_names: List[str], extra_request_body: Dict[str, Any], profile: bool, - pd_seperated: bool = False, + pd_separated: bool = False, flush_cache: bool = False, warmup_requests: int = 1, ): @@ -1239,12 +1239,14 @@ async def benchmark( if "sglang" in backend: server_info = requests.get(base_url + "/get_server_info") - if pd_seperated: - accept_length = server_info.json()["decode"][0].get( + if pd_separated: + accept_length = server_info.json()["decode"][0]["internal_states"][0].get( "avg_spec_accept_length", None ) else: - accept_length = server_info.json().get("avg_spec_accept_length", None) + accept_length = server_info.json()["internal_states"][0].get( + "avg_spec_accept_length", None + ) else: accept_length = None @@ -1541,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace): lora_names=args.lora_name, extra_request_body=extra_request_body, profile=args.profile, - pd_seperated=args.pd_seperated, + pd_separated=args.pd_separated, flush_cache=args.flush_cache, ) ) diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index 0316a8dfc..f097d4c08 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -37,6 +37,12 @@ class BaseGrammarObject: """ raise NotImplementedError() + def rollback(self, k: int): + raise NotImplementedError() + + def is_terminated(self): + raise NotImplementedError() + def allocate_vocab_mask( self, vocab_size: int, batch_size: int, device ) -> torch.Tensor: diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 6204faca2..abcc707df 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -277,19 +277,17 @@ class SchedulerDisaggregationPrefillMixin: next_token_ids, extend_input_len_per_req, extend_logprob_start_len_per_req, - bid, ) = ( result.logits_output, result.next_token_ids, result.extend_input_len_per_req, result.extend_logprob_start_len_per_req, - result.bid, ) # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue if self.enable_overlap: # wait - _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done) + _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done) else: next_token_ids = result.next_token_ids.tolist() diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 6a6961f2f..f7b1c23fe 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -330,7 +330,7 @@ class Engine(EngineBase): 
return { **dataclasses.asdict(self.tokenizer_manager.server_args), **self.scheduler_info, - **internal_states, + "internal_states": internal_states, "version": __version__, } diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index fb030d02e..6d0672ac2 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -222,7 +222,7 @@ async def get_server_info(): return { **dataclasses.asdict(_global_state.tokenizer_manager.server_args), **_global_state.scheduler_info, - **internal_states, + "internal_states": internal_states, "version": __version__, } diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py index c87aa45d7..816fbf08a 100644 --- a/python/sglang/srt/layers/attention/utils.py +++ b/python/sglang/srt/layers/attention/utils.py @@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton( num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) for i in range(num_loop): - offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + # index into req_to_token_ptr needs to be int64 + offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE mask = offset < kv_end - kv_start data = tl.load( req_to_token_ptr @@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton( num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) for i in range(num_pages_loop): + # index into req_to_token_ptr needs to be int64 paged_offset = ( - tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK + tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK ) * PAGED_SIZE paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4498f2cfc..fa1e20b0c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -160,6 +160,7 @@ class GenerationBatchResult: extend_input_len_per_req: List[int] extend_logprob_start_len_per_req: List[int] bid: int + can_run_cuda_graph: bool @dataclass @@ -323,13 +324,14 @@ class Scheduler( set_random_seed(self.random_seed) # Print debug info - logger.info( - f"max_total_num_tokens={self.max_total_num_tokens}, " - f"chunked_prefill_size={server_args.chunked_prefill_size}, " - f"max_prefill_tokens={self.max_prefill_tokens}, " - f"max_running_requests={self.max_running_requests}, " - f"context_len={self.model_config.context_len}" - ) + if tp_rank == 0: + logger.info( + f"max_total_num_tokens={self.max_total_num_tokens}, " + f"chunked_prefill_size={server_args.chunked_prefill_size}, " + f"max_prefill_tokens={self.max_prefill_tokens}, " + f"max_running_requests={self.max_running_requests}, " + f"context_len={self.model_config.context_len}" + ) # Init memory pool and cache self.init_memory_pool_and_cache() @@ -752,6 +754,7 @@ class Scheduler( extend_input_len_per_req=None, extend_logprob_start_len_per_req=None, bid=bids[next_mb_id], + can_run_cuda_graph=result.can_run_cuda_graph, ) self.process_batch_result(mbs[next_mb_id], output_result) last_mbs[next_mb_id] = mbs[next_mb_id] @@ -1159,7 +1162,9 @@ class Scheduler( self.metrics_collector.log_stats(self.stats) - def log_decode_stats(self, running_batch=None): + def log_decode_stats( + self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None + ): batch = running_batch or self.running_batch gap_latency = time.time() - self.last_decode_stats_tic @@ -1199,6 +1204,7 @@ class Scheduler( msg += f"pre-allocated usage: 
{self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, " msg += ( + f"cuda graph: {can_run_cuda_graph}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " f"#queue-req: {len(self.waiting_queue)}" ) @@ -1524,11 +1530,11 @@ class Scheduler( if self.spec_algorithm.is_none(): model_worker_batch = batch.get_model_worker_batch() if self.pp_group.is_last_rank: - logits_output, next_token_ids = ( + logits_output, next_token_ids, can_run_cuda_graph = ( self.tp_worker.forward_batch_generation(model_worker_batch) ) else: - pp_hidden_states_proxy_tensors, _ = ( + pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = ( self.tp_worker.forward_batch_generation(model_worker_batch) ) bid = model_worker_batch.bid @@ -1538,6 +1544,7 @@ class Scheduler( next_token_ids, bid, num_accepted_tokens, + can_run_cuda_graph, ) = self.draft_worker.forward_batch_speculative_generation(batch) self.spec_num_total_accepted_tokens += ( num_accepted_tokens + batch.batch_size() @@ -1571,6 +1578,7 @@ class Scheduler( extend_input_len_per_req=extend_input_len_per_req, extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, bid=bid, + can_run_cuda_graph=can_run_cuda_graph, ) else: # embedding or reward model model_worker_batch = batch.get_model_worker_batch() diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 859a5520b..d3b7c6f8e 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -38,20 +38,16 @@ class SchedulerOutputProcessorMixin: next_token_ids, extend_input_len_per_req, extend_logprob_start_len_per_req, - bid, ) = ( result.logits_output, result.next_token_ids, result.extend_input_len_per_req, result.extend_logprob_start_len_per_req, - result.bid, ) if self.enable_overlap: - logits_output, next_token_ids = ( - self.tp_worker.resolve_last_batch_result( - launch_done, - ) + logits_output, next_token_ids, _ = ( + self.tp_worker.resolve_last_batch_result(launch_done) ) else: # Move next_token_ids and logprobs to cpu @@ -189,16 +185,16 @@ class SchedulerOutputProcessorMixin: result: GenerationBatchResult, launch_done: Optional[threading.Event] = None, ): - logits_output, next_token_ids, bid = ( + logits_output, next_token_ids, can_run_cuda_graph = ( result.logits_output, result.next_token_ids, - result.bid, + result.can_run_cuda_graph, ) self.num_generated_tokens += len(batch.reqs) if self.enable_overlap: - logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result( - launch_done + logits_output, next_token_ids, can_run_cuda_graph = ( + self.tp_worker.resolve_last_batch_result(launch_done) ) next_token_logprobs = logits_output.next_token_logprobs elif batch.spec_algorithm.is_none(): @@ -280,7 +276,7 @@ class SchedulerOutputProcessorMixin: self.attn_tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0 ): - self.log_decode_stats(running_batch=batch) + self.log_decode_stats(can_run_cuda_graph, running_batch=batch) def add_input_logprob_return_values( self: Scheduler, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 05d1a54f4..db64dd0a2 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -923,12 +923,13 @@ class TokenizerManager: ): await self.send_to_scheduler.send_pyobj(obj) - async def get_internal_state(self) -> Dict[Any, Any]: + 
async def get_internal_state(self) -> List[Dict[Any, Any]]: req = GetInternalStateReq() - res: List[GetInternalStateReqOutput] = ( + responses: List[GetInternalStateReqOutput] = ( await self.get_internal_state_communicator(req) ) - return res[0].internal_state + # Many DP ranks + return [res.internal_state for res in responses] def get_log_request_metadata(self): max_length = None diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index faed34665..786a34a1e 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -20,7 +20,7 @@ from typing import Optional, Tuple, Union import torch from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group +from sglang.srt.distributed import get_pp_group, get_world_group from sglang.srt.hf_transformers_utils import ( get_processor, get_tokenizer, @@ -183,8 +183,11 @@ class TpModelWorker: def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, + launch_done: Optional[threading.Event] = None, skip_sample: bool = False, - ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]: + ) -> Tuple[ + Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool + ]: forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) pp_proxy_tensors = None @@ -196,11 +199,11 @@ class TpModelWorker: ) if self.pp_group.is_last_rank: - logits_output = self.model_runner.forward( + logits_output, can_run_cuda_graph = self.model_runner.forward( forward_batch, pp_proxy_tensors=pp_proxy_tensors ) - if model_worker_batch.launch_done is not None: - model_worker_batch.launch_done.set() + if launch_done is not None: + launch_done.set() if skip_sample: next_token_ids = None @@ -209,17 +212,17 @@ class TpModelWorker: logits_output, model_worker_batch ) - return logits_output, next_token_ids + return logits_output, next_token_ids, can_run_cuda_graph else: - pp_proxy_tensors = self.model_runner.forward( + pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward( forward_batch, pp_proxy_tensors=pp_proxy_tensors, ) - return pp_proxy_tensors.tensors, None + return pp_proxy_tensors.tensors, None, can_run_cuda_graph def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - logits_output = self.model_runner.forward(forward_batch) + logits_output, _ = self.model_runner.forward(forward_batch) embeddings = logits_output.embeddings return embeddings diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 8bfcfe02f..4c6cd576f 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -18,7 +18,7 @@ import logging import signal import threading from queue import Queue -from typing import Optional +from typing import Optional, Tuple import psutil import torch @@ -145,8 +145,10 @@ class TpModelWorkerClient: resolve_future_token_ids(input_ids, self.future_token_ids_map) # Run forward - logits_output, next_token_ids = self.worker.forward_batch_generation( - model_worker_batch + logits_output, next_token_ids, can_run_cuda_graph = ( + self.worker.forward_batch_generation( + model_worker_batch, model_worker_batch.launch_done + ) ) # Update the future token ids map @@ -171,14 +173,18 @@ class TpModelWorkerClient: next_token_ids = 
next_token_ids.to("cpu", non_blocking=True) copy_done.record() - self.output_queue.put((copy_done, logits_output, next_token_ids)) + self.output_queue.put( + (copy_done, logits_output, next_token_ids, can_run_cuda_graph) + ) def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None): """ This function is called to resolve the last batch result and wait for the current batch to be launched. Used in overlap mode. """ - copy_done, logits_output, next_token_ids = self.output_queue.get() + copy_done, logits_output, next_token_ids, can_run_cuda_graph = ( + self.output_queue.get() + ) if launch_done is not None: launch_done.wait() @@ -193,9 +199,11 @@ class TpModelWorkerClient: logits_output.input_token_logprobs.tolist() ) next_token_ids = next_token_ids.tolist() - return logits_output, next_token_ids + return logits_output, next_token_ids, can_run_cuda_graph - def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): + def forward_batch_generation( + self, model_worker_batch: ModelWorkerBatch + ) -> Tuple[None, torch.Tensor, bool]: # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch. sampling_info = model_worker_batch.sampling_info sampling_info.update_penalties() @@ -223,7 +231,7 @@ class TpModelWorkerClient: self.future_token_ids_ct = ( self.future_token_ids_ct + bs ) % self.future_token_ids_limit - return None, future_next_token_ids + return None, future_next_token_ids, False def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): success, message = self.worker.update_weights_from_disk(recv_req) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index ba3882ac6..025c75392 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -19,7 +19,7 @@ import bisect import inspect import os from contextlib import contextmanager -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Optional, Union import torch import tqdm @@ -40,15 +40,12 @@ from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.utils import ( get_available_gpu_memory, get_device_memory_capacity, - is_hip, rank0_log, ) if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner -_is_hip = is_hip() - def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): for sub in model._modules.values(): @@ -137,7 +134,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): ) gpu_mem = get_device_memory_capacity() - # Batch size of each rank will not become so large when DP is on if gpu_mem is not None and gpu_mem > 96 * 1024: capture_bs += list(range(160, 257, 8)) @@ -148,12 +144,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): model_runner.req_to_token_pool.size ] - capture_bs = list(sorted(set(capture_bs))) - - assert len(capture_bs) > 0 and capture_bs[0] > 0 - capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size] if server_args.cuda_graph_max_bs: capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs] + if max(capture_bs) < server_args.cuda_graph_max_bs: + capture_bs += list( + range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16) + ) + capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size] + capture_bs = list(sorted(set(capture_bs))) + assert len(capture_bs) > 0 and capture_bs[0] > 0 
compile_bs = ( [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] if server_args.enable_torch_compile diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index deabf8265..a102e63ae 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1085,32 +1085,33 @@ class ModelRunner: forward_batch: ForwardBatch, skip_attn_backend_init: bool = False, pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: can_run_cuda_graph = bool( forward_batch.forward_mode.is_cuda_graph() and self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch) ) if can_run_cuda_graph: - return self.cuda_graph_runner.replay( + ret = self.cuda_graph_runner.replay( forward_batch, skip_attn_backend_init=skip_attn_backend_init, pp_proxy_tensors=pp_proxy_tensors, ) - - if forward_batch.forward_mode.is_decode(): - return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors) + elif forward_batch.forward_mode.is_decode(): + ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors) elif forward_batch.forward_mode.is_extend(): - return self.forward_extend( + ret = self.forward_extend( forward_batch, skip_attn_backend_init=skip_attn_backend_init, pp_proxy_tensors=pp_proxy_tensors, ) elif forward_batch.forward_mode.is_idle(): - return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors) + ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors) else: raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") + return ret, can_run_cuda_graph + def _preprocess_logits( self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo ): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a780976e3..dafb5b5b3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1086,7 +1086,7 @@ class ServerArgs: "--cuda-graph-max-bs", type=int, default=ServerArgs.cuda_graph_max_bs, - help="Set the maximum batch size for cuda graph.", + help="Set the maximum batch size for cuda graph. 
It will extend the cuda graph capture batch size to this value.", ) parser.add_argument( "--cuda-graph-bs", diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 7c61307a4..7ea48102d 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -251,8 +251,8 @@ class EAGLEWorker(TpModelWorker): if batch.forward_mode.is_decode(): with self.draft_tp_context(self.draft_model_runner.tp_group): spec_info = self.draft(batch) - logits_output, verify_output, model_worker_batch = self.verify( - batch, spec_info + logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( + self.verify(batch, spec_info) ) # If it is None, it means all requests are finished @@ -264,21 +264,22 @@ class EAGLEWorker(TpModelWorker): verify_output.verified_id, model_worker_batch.bid, sum(verify_output.accept_length_per_req_cpu), + can_run_cuda_graph, ) elif batch.forward_mode.is_idle(): model_worker_batch = batch.get_model_worker_batch() - logits_output, next_token_ids = self.target_worker.forward_batch_generation( - model_worker_batch + logits_output, next_token_ids, _ = ( + self.target_worker.forward_batch_generation(model_worker_batch) ) - return logits_output, next_token_ids, model_worker_batch.bid, 0 + return logits_output, next_token_ids, model_worker_batch.bid, 0, False else: logits_output, next_token_ids, bid = self.forward_target_extend(batch) with self.draft_tp_context(self.draft_model_runner.tp_group): self.forward_draft_extend( batch, logits_output.hidden_states, next_token_ids ) - return logits_output, next_token_ids, bid, 0 + return logits_output, next_token_ids, bid, 0, False def forward_target_extend( self, batch: ScheduleBatch @@ -297,7 +298,7 @@ class EAGLEWorker(TpModelWorker): # We need the full hidden states to prefill the KV cache of the draft model. 
model_worker_batch = batch.get_model_worker_batch() model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL - logits_output, next_token_ids = self.target_worker.forward_batch_generation( + logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation( model_worker_batch ) return logits_output, next_token_ids, model_worker_batch.bid @@ -478,8 +479,10 @@ class EAGLEWorker(TpModelWorker): batch.forward_mode = ForwardMode.TARGET_VERIFY batch.spec_info = spec_info model_worker_batch = batch.get_model_worker_batch() - logits_output, _ = self.target_worker.forward_batch_generation( - model_worker_batch, skip_sample=True + logits_output, _, can_run_cuda_graph = ( + self.target_worker.forward_batch_generation( + model_worker_batch, skip_sample=True + ) ) self._detect_nan_if_needed(logits_output) spec_info.hidden_states = logits_output.hidden_states @@ -504,7 +507,7 @@ class EAGLEWorker(TpModelWorker): if batch.return_logprob: self.add_logprob_values(batch, res, logits_output) - return logits_output, res, model_worker_batch + return logits_output, res, model_worker_batch, can_run_cuda_graph def add_logprob_values( self, @@ -590,7 +593,7 @@ class EAGLEWorker(TpModelWorker): model_worker_batch, self.draft_model_runner ) forward_batch.return_logprob = False - logits_output = self.draft_model_runner.forward(forward_batch) + logits_output, _ = self.draft_model_runner.forward(forward_batch) self._detect_nan_if_needed(logits_output) assert isinstance(forward_batch.spec_info, EagleDraftInput) assert forward_batch.spec_info is batch.spec_info @@ -617,7 +620,7 @@ class EAGLEWorker(TpModelWorker): ) # Run - logits_output = self.draft_model_runner.forward(forward_batch) + logits_output, _ = self.draft_model_runner.forward(forward_batch) self._detect_nan_if_needed(logits_output) self.capture_for_decode(logits_output, forward_batch.spec_info) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 150f385c9..6cc0717b0 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -395,12 +395,12 @@ def popen_launch_server( other_args: list[str] = (), env: Optional[dict] = None, return_stdout_stderr: Optional[tuple] = None, - pd_seperated: bool = False, + pd_separated: bool = False, ): _, host, port = base_url.split(":") host = host[2:] - if pd_seperated: + if pd_separated: command = "sglang.launch_pd_server" else: command = "sglang.launch_server" @@ -414,7 +414,7 @@ def popen_launch_server( *[str(x) for x in other_args], ] - if pd_seperated: + if pd_separated: command.extend( [ "--lb-host", @@ -656,7 +656,7 @@ def get_benchmark_args( disable_stream=False, disable_ignore_eos=False, seed: int = 0, - pd_seperated: bool = False, + pd_separated: bool = False, ): return SimpleNamespace( backend="sglang", @@ -686,7 +686,7 @@ def get_benchmark_args( profile=None, lora_name=None, prompt_suffix="", - pd_seperated=pd_seperated, + pd_separated=pd_separated, ) @@ -750,7 +750,7 @@ def run_bench_serving_multi( other_server_args, benchmark_args, need_warmup=False, - pd_seperated=False, + pd_separated=False, ): # Launch the server process = popen_launch_server( @@ -758,7 +758,7 @@ def run_bench_serving_multi( base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=other_server_args, - pd_seperated=pd_seperated, + pd_separated=pd_separated, ) # run benchmark for all diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index a70679f50..2c910e57b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -101,8 +101,8 @@ suites = { 
# TestFile("test_deepep_intranode.py", 50), # TestFile("test_deepep_low_latency.py", 50), # TestFile("test_moe_deepep_eval_accuracy_large.py", 250), + # TestFile("test_disaggregation.py", 90), TestFile("test_local_attn.py", 250), - TestFile("test_disaggregation.py", 90), TestFile("test_full_deepseek_v3.py", 250), TestFile("test_pp_single_node.py", 150), ], diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 8cf89e14e..7f653777a 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -97,7 +97,9 @@ class TestEAGLEEngine(CustomTestCase): print(f"{engine.get_server_info()=}") - avg_spec_accept_length = engine.get_server_info()["avg_spec_accept_length"] + avg_spec_accept_length = engine.get_server_info()["internal_states"][0][ + "avg_spec_accept_length" + ] print(f"{avg_spec_accept_length=}") self.assertGreater(avg_spec_accept_length, 1.9) @@ -296,7 +298,9 @@ class TestEAGLEServer(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.20) server_info = requests.get(self.base_url + "/get_server_info").json() - avg_spec_accept_length = server_info["avg_spec_accept_length"] + avg_spec_accept_length = server_info["internal_states"][0][ + "avg_spec_accept_length" + ] print(f"{avg_spec_accept_length=}") speculative_eagle_topk = server_info["speculative_eagle_topk"] diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 6946dd4f3..c43196571 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -111,7 +111,9 @@ class BaseFlashAttentionTest(CustomTestCase): if self.speculative_decode: server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] print(f"{avg_spec_accept_length=}") self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold) diff --git a/test/srt/test_full_deepseek_v3.py b/test/srt/test_full_deepseek_v3.py index a223cdc3e..6d6f2eef4 100644 --- a/test/srt/test_full_deepseek_v3.py +++ b/test/srt/test_full_deepseek_v3.py @@ -118,7 +118,9 @@ class TestDeepseekV3MTP(CustomTestCase): print(f"{metrics=}") server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] print(f"{avg_spec_accept_length=}") if is_in_ci(): diff --git a/test/srt/test_mla_deepseek_v3.py b/test/srt/test_mla_deepseek_v3.py index 2502e2bd9..1990b0857 100644 --- a/test/srt/test_mla_deepseek_v3.py +++ b/test/srt/test_mla_deepseek_v3.py @@ -100,7 +100,9 @@ class TestDeepseekV3MTP(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.60) server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] print(f"{avg_spec_accept_length=}") self.assertGreater(avg_spec_accept_length, 2.5) @@ -159,7 +161,9 @@ class TestDeepseekV3MTPWithDraft(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.60) server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] print(f"{avg_spec_accept_length=}") self.assertGreater(avg_spec_accept_length, 2.5) diff --git 
a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py index be53a16f9..aa971a582 100644 --- a/test/srt/test_mla_flashinfer.py +++ b/test/srt/test_mla_flashinfer.py @@ -158,7 +158,9 @@ class TestFlashinferMLAMTP(CustomTestCase): server_info = requests.get(self.base_url + "/get_server_info") print(f"{server_info=}") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] print(f"{avg_spec_accept_length=}") self.assertGreater(avg_spec_accept_length, 2.5) diff --git a/test/srt/test_mla_int8_deepseek_v3.py b/test/srt/test_mla_int8_deepseek_v3.py index 5e6dc62a4..38207cdf6 100644 --- a/test/srt/test_mla_int8_deepseek_v3.py +++ b/test/srt/test_mla_int8_deepseek_v3.py @@ -105,7 +105,9 @@ class TestDeepseekV3MTPChannelInt8(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.60) server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] print(f"{avg_spec_accept_length=}") self.assertGreater(avg_spec_accept_length, 2.5) @@ -199,7 +201,9 @@ class TestDeepseekV3MTPBlockInt8(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.60) server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] print(f"{avg_spec_accept_length=}") self.assertGreater(avg_spec_accept_length, 2.5) diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 17e542156..401ad9202 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -492,9 +492,6 @@ class TestSRTEndpoint(CustomTestCase): max_total_num_tokens = response_json["max_total_num_tokens"] self.assertIsInstance(max_total_num_tokens, int) - attention_backend = response_json["attention_backend"] - self.assertIsInstance(attention_backend, str) - version = response_json["version"] self.assertIsInstance(version, str)
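Note for downstream consumers of the /get_server_info change: per-scheduler state that was previously flattened into the top-level response is now nested under an "internal_states" list, with one entry per data-parallel rank. Below is a minimal client-side sketch of reading the new shape; the URL is a placeholder, and avg_spec_accept_length is only populated when speculative decoding is enabled, which is why the benchmark scripts in this patch read it with .get():

    import requests

    base_url = "http://localhost:30000"  # placeholder; point this at a running server

    info = requests.get(base_url + "/get_server_info").json()

    # One entry per data-parallel rank; a single-DP deployment has exactly one.
    state = info["internal_states"][0]

    # Decode throughput (token/s) over the most recent logging window.
    print("last_gen_throughput:", state["last_gen_throughput"])

    # Present only when speculative decoding is enabled, hence the .get().
    print("avg_spec_accept_length:", state.get("avg_spec_accept_length", None))

The new bench_one_batch_server flags compose with this, e.g. "python3 -m sglang.bench_one_batch_server --batch-size 1 8 16 --input-len 1024 --output-len 16 --show-report" (server arguments such as the model path are omitted here).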