From 90532b762777302cd46a9a38b667570360661e23 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 18 Mar 2025 21:26:53 -0700 Subject: [PATCH] [Fix] Fix raw_bs bug when using flashinfer mla and eagle (#4557) --- .../srt/speculative/eagle_draft_cuda_graph_runner.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 88ee3a486..323f47de9 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -52,6 +52,9 @@ class EAGLEDraftCudaGraphRunner: self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[ 0 ].get_cuda_graph_seq_len_fill_value() + self.seq_lens_cpu = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) if self.enable_torch_compile: set_torch_compile_config() @@ -210,6 +213,12 @@ class EAGLEDraftCudaGraphRunner: forward_batch.req_pool_indices = self.req_pool_indices[:bs] forward_batch.positions = self.positions[:num_tokens] + # Special handle for seq_len_cpu used when flashinfer mla is used + if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs): + self.seq_lens_cpu.fill_(1) + self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu) + forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs] + self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph( forward_batch, bs ) @@ -224,5 +233,7 @@ class EAGLEDraftCudaGraphRunner: forward_batch.positions = self.positions[:raw_num_token] forward_batch.seq_lens = self.seq_lens[:raw_bs] forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs] + if forward_batch.decode_seq_lens_cpu is not None: + forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs] return out