From e5ce395a6cb87aab3ae86c7678e3fee0a2c75a28 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Tue, 18 Feb 2025 23:03:26 +0800 Subject: [PATCH] Fix draft decode max batch size (#3676) --- python/sglang/srt/layers/attention/flashinfer_backend.py | 2 +- python/sglang/srt/layers/attention/triton_backend.py | 2 +- .../sglang/srt/layers/attention/triton_ops/decode_attention.py | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index a3e194ccb..e39bdd2d6 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -1094,7 +1094,7 @@ class FlashInferMultiStepDraftBackend: self.topk = topk self.speculative_num_steps = speculative_num_steps self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices - max_bs = model_runner.req_to_token_pool.size + max_bs = model_runner.req_to_token_pool.size * self.topk self.kv_indptr = torch.zeros( ( self.speculative_num_steps, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 627a1db23..7bb6615ed 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -474,7 +474,7 @@ class TritonMultiStepDraftBackend: self.topk = topk self.speculative_num_steps = speculative_num_steps self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices - max_bs = model_runner.req_to_token_pool.size + max_bs = model_runner.req_to_token_pool.size * self.topk self.kv_indptr = torch.zeros( ( self.speculative_num_steps, diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index f2274322c..3b4853e40 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -635,6 +635,9 @@ def decode_attention_fwd( logit_cap=0.0, ): assert num_kv_splits == attn_logits.shape[2] + assert q.shape[0] <= kv_indptr.shape[0] - 1 + assert q.shape[0] <= attn_logits.shape[0] + kv_group_num = q.shape[1] // v_buffer.shape[1] if kv_group_num == 1: