diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 6cbca78e9..e5210e88c 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1807,7 +1807,7 @@ class FlashAttentionBackend(AttentionBackend): def get_cuda_graph_seq_len_fill_value(self): """Get the fill value for sequence length in CUDA graph.""" - return 0 + return 1 def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device): """Centralized utility to initialize local_attn_metadata if chunked attention is enabled.""" diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index f11de5641..316ad18b0 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -440,7 +440,7 @@ class FlashInferAttnBackend(AttentionBackend): raise ValueError("Invalid forward mode") def get_cuda_graph_seq_len_fill_value(self): - return 0 + return 1 def forward_extend( self, diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 19fa09818..c6beb5820 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -364,7 +364,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): raise ValueError(f"Invalid forward mode: {forward_mode=}") def get_cuda_graph_seq_len_fill_value(self): - return 0 + return 1 def forward_extend( self, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 2ab5ee385..1c534845d 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -612,7 +612,7 @@ class 
CudaGraphRunner: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] if bs != raw_bs: - self.seq_lens.fill_(1) + self.seq_lens.fill_(self.seq_len_fill_value) self.out_cache_loc.zero_() # Common inputs @@ -624,7 +624,7 @@ class CudaGraphRunner: if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: - self.seq_lens_cpu.fill_(1) + self.seq_lens_cpu.fill_(self.seq_len_fill_value) self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) if pp_proxy_tensors: @@ -652,7 +652,7 @@ class CudaGraphRunner: bs, self.req_pool_indices, self.seq_lens, - forward_batch.seq_lens_sum + (bs - raw_bs), + forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value, self.encoder_lens, forward_batch.forward_mode, forward_batch.spec_info, diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 001404bd8..5e70799e5 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -187,9 +187,8 @@ class EAGLEDraftCudaGraphRunner: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] if bs != raw_bs: - self.seq_lens.fill_(1) + self.seq_lens.fill_(self.seq_len_fill_value) self.out_cache_loc.zero_() - self.positions.zero_() num_tokens = bs * self.num_tokens_per_bs @@ -211,15 +210,15 @@ class EAGLEDraftCudaGraphRunner: forward_batch.req_pool_indices = self.req_pool_indices[:bs] forward_batch.positions = self.positions[:num_tokens] - # Special handle for seq_len_cpu used when flashinfer mla is used if forward_batch.seq_lens_cpu is not None and bs != raw_bs: - self.seq_lens_cpu.fill_(1) + self.seq_lens_cpu.fill_(self.seq_len_fill_value) self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs] self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph( forward_batch, bs ) + # TODO: The 
forward_batch.seq_lens_sum might need to be updated to reflect the padding in the cuda graph # Replay self.graphs[bs].replay() diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index b8fb11974..2d2fce197 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -207,9 +207,9 @@ class EAGLEDraftExtendCudaGraphRunner: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] if bs * self.num_tokens_per_bs != num_tokens: - self.seq_lens.fill_(1) - self.accept_length.fill_(1) + self.seq_lens.fill_(self.seq_len_fill_value) self.out_cache_loc.zero_() + self.accept_length.fill_(1) # Common inputs self.input_ids[:num_tokens].copy_(forward_batch.input_ids) @@ -223,18 +223,19 @@ if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: - self.seq_lens_cpu.fill_(1) + self.seq_lens_cpu.fill_(self.seq_len_fill_value) self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) if bs != raw_bs: + forward_batch.spec_info.positions = self.positions[:num_tokens] forward_batch.spec_info.accept_length = self.accept_length[:bs] - forward_batch.spec_info.positions = None self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph( bs=bs, req_pool_indices=self.req_pool_indices, seq_lens=self.seq_lens, - seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs), + seq_lens_sum=forward_batch.seq_lens_sum + + (bs - raw_bs) * self.seq_len_fill_value, encoder_lens=None, forward_mode=ForwardMode.DRAFT_EXTEND, spec_info=forward_batch.spec_info, diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 83bea359c..c9157c225 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -166,6 +166,10 @@ 
class EAGLEWorker(TpModelWorker): def init_attention_backend(self): # Create multi-step attn backends and cuda graph runners + + self.has_prefill_wrapper_verify = False + self.draft_extend_attn_backend = None + if self.server_args.attention_backend == "flashinfer": if not global_server_args_dict["use_mla_backend"]: from sglang.srt.layers.attention.flashinfer_backend import ( @@ -213,7 +217,6 @@ class EAGLEWorker(TpModelWorker): self.draft_model_runner, skip_prefill=False, ) - self.has_prefill_wrapper_verify = False elif self.server_args.attention_backend == "fa3": from sglang.srt.layers.attention.flashattention_backend import ( FlashAttentionBackend, @@ -229,7 +232,6 @@ class EAGLEWorker(TpModelWorker): self.draft_model_runner, skip_prefill=False, ) - self.has_prefill_wrapper_verify = False elif self.server_args.attention_backend == "flashmla": from sglang.srt.layers.attention.flashmla_backend import ( FlashMLAMultiStepDraftBackend, @@ -240,8 +242,6 @@ class EAGLEWorker(TpModelWorker): self.topk, self.speculative_num_steps, ) - self.draft_extend_attn_backend = None - self.has_prefill_wrapper_verify = False else: raise ValueError( f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"