From 57eec0bfbce964e347ef2affb999e03416f22325 Mon Sep 17 00:00:00 2001 From: lukec <118525388+sleepcoo@users.noreply.github.com> Date: Tue, 25 Mar 2025 12:06:58 +0800 Subject: [PATCH] fix FlashMLA cudagraph config (#4691) Co-authored-by: yinfan98 <1106310035@qq.com> --- python/sglang/srt/layers/attention/flashmla_backend.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py index 1e711e647..730c79495 100644 --- a/python/sglang/srt/layers/attention/flashmla_backend.py +++ b/python/sglang/srt/layers/attention/flashmla_backend.py @@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend): if forward_batch.forward_mode.is_decode_or_idle(): if spec_info is None: max_seqlen_pad = triton.cdiv( - forward_batch.seq_lens.max().item(), PAGE_SIZE + forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE ) block_kv_indices = torch.full( (bs, max_seqlen_pad), @@ -206,8 +206,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend): ): if forward_mode.is_decode_or_idle(): + assert seq_lens_cpu is not None seq_lens = seq_lens[:bs] - max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) + seq_lens_cpu = seq_lens_cpu[:bs] + max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE) create_flashmla_kv_indices_triton[(bs,)]( self.req_to_token, req_pool_indices[:bs],