From 60fdad7cf343333e956a3889c12956396a1516bf Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 6 Jun 2025 18:23:52 -0700 Subject: [PATCH] Sync the changes on cuda graph runners (#6932) --- python/sglang/srt/managers/io_struct.py | 4 +- .../srt/model_executor/cuda_graph_runner.py | 51 ++++++++----------- python/sglang/srt/server_args.py | 15 +++++- .../srt/speculative/build_eagle_tree.py | 16 +++--- .../eagle_draft_cuda_graph_runner.py | 11 ++-- .../eagle_draft_extend_cuda_graph_runner.py | 12 ++--- python/sglang/srt/speculative/eagle_utils.py | 24 +++++---- 7 files changed, 63 insertions(+), 70 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index f13b23b17..2e73bf2a4 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -20,7 +20,7 @@ import copy import uuid from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from sglang.srt.mm_utils import has_valid_data @@ -30,7 +30,7 @@ if TYPE_CHECKING: else: Image = Any -from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list +from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling.sampling_params import SamplingParams diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 990452539..d443baa1c 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -259,23 +259,8 @@ class CudaGraphRunner: } # Speculative_inference - if ( - model_runner.spec_algorithm.is_eagle3() - and not model_runner.is_draft_worker - ): - self.hidden_states = torch.zeros( - ( - self.max_num_token, - 3 * self.model_runner.model_config.hidden_size, - ), - dtype=self.model_runner.dtype, - ) + if model_runner.spec_algorithm.is_eagle3(): self.model_runner.model.set_eagle3_layers_to_capture() - elif model_runner.spec_algorithm.is_eagle(): - self.hidden_states = torch.zeros( - (self.max_num_token, self.model_runner.model_config.hidden_size), - dtype=self.model_runner.dtype, - ) if self.is_encoder_decoder: # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch @@ -284,6 +269,7 @@ class CudaGraphRunner: ) else: self.encoder_lens = None + if self.enable_dp_attention or self.enable_sp_layernorm: # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer self.gathered_buffer = torch.zeros( @@ -303,13 +289,7 @@ class CudaGraphRunner: self.capture() except RuntimeError as e: raise Exception( - f"Capture CUDA graph failed: {e}\n" - "Possible solutions:\n" - "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" - "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" - "3. disable torch compile by not using --enable-torch-compile\n" - "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n" - "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" ) @contextmanager @@ -439,6 +419,7 @@ class CudaGraphRunner: self.capture_hidden_mode = ( spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL ) + if self.model_runner.server_args.lora_paths is not None: # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a # different logic to handle lora, so we need to set `lora_paths` to a list of non-None @@ -467,9 +448,9 @@ class CudaGraphRunner: spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, capture_hidden_mode=self.capture_hidden_mode, - lora_paths=lora_paths, num_token_non_padded=self.num_token_non_padded, global_forward_mode=self.capture_forward_mode, + lora_paths=lora_paths, ) self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) @@ -497,7 +478,9 @@ class CudaGraphRunner: self.pp_size > 1 and "pp_proxy_tensors" in inspect.signature(forward).parameters ): - kwargs["pp_proxy_tensors"] = pp_proxy_tensors + kwargs["pp_proxy_tensors"] = PPProxyTensors( + {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()} + ) logits_output_or_pp_proxy_tensors = forward( input_ids, @@ -590,9 +573,6 @@ class CudaGraphRunner: if self.enable_dp_attention or self.enable_sp_layernorm: self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) - if hasattr(forward_batch.spec_info, "hidden_states"): - self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states - # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( bs, @@ -650,7 +630,7 @@ class CudaGraphRunner: else: spec_info = EagleVerifyInput( draft_token=None, - custom_mask=torch.zeros( + custom_mask=torch.ones( (num_tokens * self.model_runner.model_config.context_len), dtype=torch.bool, device="cuda", @@ -660,9 +640,20 @@ class CudaGraphRunner: retrive_next_token=None, retrive_next_sibling=None, retrive_cum_len=None, - draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, spec_steps=self.model_runner.server_args.speculative_num_steps, + topk=self.model_runner.server_args.speculative_eagle_topk, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, capture_hidden_mode=CaptureHiddenMode.FULL, ) return spec_info + + +CUDA_GRAPH_CAPTURE_FAILED_MSG = ( + "Possible solutions:\n" + "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" + "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" + "3. disable torch compile by not using --enable-torch-compile\n" + "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" +) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 97229b7f1..317df60a6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -447,7 +447,7 @@ class ServerArgs: self.speculative_num_steps, self.speculative_eagle_topk, self.speculative_num_draft_tokens, - ) = auto_choose_speculative_params(model_arch) + ) = auto_choose_speculative_params(self) if self.page_size > 1 and self.speculative_eagle_topk > 1: self.speculative_eagle_topk = 1 @@ -1655,12 +1655,23 @@ def get_model_arch(args: ServerArgs): return hf_config.architectures[0] -def auto_choose_speculative_params(arch: str): +def auto_choose_speculative_params(self: ServerArgs): """ Automatically choose the parameters for speculative decoding. You can tune them on your own models and prompts with scripts/playground/bench_speculative.py """ + kwargs = {} + + hf_config = get_config( + self.model_path, + trust_remote_code=self.trust_remote_code, + revision=self.revision, + model_override_args=json.loads(self.json_model_override_args), + **kwargs, + ) + arch = hf_config.architectures[0] + if arch in ["LlamaForCausalLM"]: # The default value for llama return (5, 4, 8) diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index c2840c4a0..1f0e0fcb7 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -4,7 +4,7 @@ from typing import List import torch -from sglang.srt.utils import is_cuda, is_hip +from sglang.srt.utils import is_cuda, is_hip, rank0_print if is_cuda() or is_hip(): from sgl_kernel import ( @@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient(): num_verify_tokens=num_draft_token, ) - first_rank_print("=========== build tree kernel efficient ==========") - # first_rank_print(f"{tree_mask=}", flush=True) - first_rank_print(f"{position=}", flush=True) - first_rank_print(f"{retrive_index=}", flush=True) - first_rank_print(f"{retrive_next_token=}", flush=True) - first_rank_print(f"{retrive_next_sibling=}", flush=True) - first_rank_print(f"{draft_tokens=}", flush=True) + rank0_print("=========== build tree kernel efficient ==========") + # rank0_print(f"{tree_mask=}", flush=True) + rank0_print(f"{position=}", flush=True) + rank0_print(f"{retrive_index=}", flush=True) + rank0_print(f"{retrive_next_token=}", flush=True) + rank0_print(f"{retrive_next_sibling=}", flush=True) + rank0_print(f"{draft_tokens=}", flush=True) assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14] assert retrive_index.tolist() == [ [0, 1, 2, 3, 4, 5, 6, 7], diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 736a0f074..8336af2aa 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable import torch from sglang.srt.model_executor.cuda_graph_runner import ( + CUDA_GRAPH_CAPTURE_FAILED_MSG, CudaGraphRunner, get_batch_sizes_to_capture, get_global_graph_memory_pool, @@ -73,7 +74,7 @@ class EAGLEDraftCudaGraphRunner: self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32) self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64) self.hidden_states = torch.zeros( - (self.max_bs, self.model_runner.model_config.hidden_size), + (self.max_num_token, self.model_runner.model_config.hidden_size), dtype=self.model_runner.dtype, ) @@ -82,13 +83,7 @@ class EAGLEDraftCudaGraphRunner: self.capture() except RuntimeError as e: raise Exception( - f"Capture CUDA graph failed: {e}\n" - "Possible solutions:\n" - "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" - "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" - "3. disable torch compile by not using --enable-torch-compile\n" - "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n" - "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" ) def can_run(self, forward_batch: ForwardBatch): diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index e817196e6..6894d4df2 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable import torch from sglang.srt.model_executor.cuda_graph_runner import ( + CUDA_GRAPH_CAPTURE_FAILED_MSG, CudaGraphRunner, LogitsProcessorOutput, get_batch_sizes_to_capture, @@ -89,13 +90,7 @@ class EAGLEDraftExtendCudaGraphRunner: self.capture() except RuntimeError as e: raise Exception( - f"Capture CUDA graph failed: {e}\n" - "Possible solutions:\n" - "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" - "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" - "3. disable torch compile by not using --enable-torch-compile\n" - "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n" - "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" ) def can_run(self, forward_batch: ForwardBatch): @@ -200,7 +195,6 @@ class EAGLEDraftExtendCudaGraphRunner: # in the batch, which will not be counted as num_seqs raw_bs = forward_batch.batch_size num_tokens = forward_batch.input_ids.shape[0] - assert raw_bs * self.num_tokens_per_bs == num_tokens index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] @@ -224,9 +218,9 @@ class EAGLEDraftExtendCudaGraphRunner: self.seq_lens_cpu.fill_(1) self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) - forward_batch.spec_info.positions = None if bs != raw_bs: forward_batch.spec_info.accept_length = self.accept_length[:bs] + forward_batch.spec_info.positions = None self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph( bs=bs, diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 2140d1ac1..23fa1a2ed 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -232,8 +232,9 @@ class EagleVerifyInput: retrive_next_token: torch.Tensor retrive_next_sibling: torch.Tensor retrive_cum_len: torch.Tensor - draft_token_num: int spec_steps: int + topk: int + draft_token_num: int capture_hidden_mode: CaptureHiddenMode grammar: BaseGrammarObject = None @@ -270,16 +271,17 @@ class EagleVerifyInput: ) return cls( - draft_tokens, - tree_mask, - position, - retrive_index, - retrive_next_token, - retrive_next_sibling, - None, - num_verify_tokens, - spec_steps, - CaptureHiddenMode.FULL, + draft_token=draft_tokens, + custom_mask=tree_mask, + positions=position, + retrive_index=retrive_index, + retrive_next_token=retrive_next_token, + retrive_next_sibling=retrive_next_sibling, + retrive_cum_len=None, + spec_steps=spec_steps, + topk=topk, + draft_token_num=num_verify_tokens, + capture_hidden_mode=CaptureHiddenMode.FULL, ) def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):