diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py index 1e711e647..730c79495 100644 --- a/python/sglang/srt/layers/attention/flashmla_backend.py +++ b/python/sglang/srt/layers/attention/flashmla_backend.py @@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend): if forward_batch.forward_mode.is_decode_or_idle(): if spec_info is None: max_seqlen_pad = triton.cdiv( - forward_batch.seq_lens.max().item(), PAGE_SIZE + forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE ) block_kv_indices = torch.full( (bs, max_seqlen_pad), @@ -206,8 +206,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend): ): if forward_mode.is_decode_or_idle(): + assert seq_lens_cpu is not None seq_lens = seq_lens[:bs] - max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) + seq_lens_cpu = seq_lens_cpu[:bs] + max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE) create_flashmla_kv_indices_triton[(bs,)]( self.req_to_token, req_pool_indices[:bs],