diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 0aa3a695e..2bedcf077 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -155,6 +155,9 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
     ):
         num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
+        # NOTE(alcanderian): Considering speculative decoding,
+        # num_kv_splits.shape[0] will be topk * real_num_token,
+        # and real_num_token is num_seq in the decoding phase.
         num_group = num_token // num_seq
         assert (