diff --git a/python/sglang/srt/layers/attention/cutlass_mla_backend.py b/python/sglang/srt/layers/attention/cutlass_mla_backend.py
index 8b0923ce3..afa03434c 100644
--- a/python/sglang/srt/layers/attention/cutlass_mla_backend.py
+++ b/python/sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -157,7 +157,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
     ):
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
-                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+                max_seqlen_pad = self.cuda_graph_kv_indices.shape[1]
 
                 create_flashmla_kv_indices_triton[(bs,)](
                     self.req_to_token,
@@ -169,12 +169,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                     self.cuda_graph_kv_indices.stride(0),
                     PAGE_SIZE,
                 )
-                workspace_size = cutlass_mla_get_workspace_size(
-                    max_seqlen_pad * PAGE_SIZE, bs
-                )
-                self.cuda_graph_mla_workspace = torch.empty(
-                    workspace_size, device="cuda", dtype=torch.uint8
-                )
                 self.forward_metadata = CutlassMLADecodeMetadata(
                     self.cuda_graph_mla_workspace,
                     self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
@@ -205,8 +199,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             seq_lens = seq_lens[:bs]
-            seq_lens_cpu = seq_lens_cpu[:bs]
-            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+
             create_flashmla_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices[:bs],
@@ -217,16 +210,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                 self.cuda_graph_kv_indices.stride(0),
                 PAGE_SIZE,
             )
-            workspace_size = cutlass_mla_get_workspace_size(
-                max_seqlen_pad * PAGE_SIZE, bs
-            )
-            self.cuda_graph_mla_workspace = torch.empty(
-                workspace_size, device="cuda", dtype=torch.uint8
-            )
-            self.forward_metadata.workspace = self.cuda_graph_mla_workspace
-            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
-                :bs, :max_seqlen_pad
-            ]
         else:
             super().init_forward_metadata_replay_cuda_graph(
                 bs,
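For context: both deleted blocks above allocated `self.cuda_graph_mla_workspace` on every CUDA graph capture and replay, so with them gone the workspace must already exist by capture time. Below is a minimal sketch of the implied one-time allocation, assuming it moves into the backend's CUDA graph state setup in the same module; the hook name `init_cuda_graph_state`, its signature, and the use of `self.max_context_len` for worst-case sizing are assumptions for illustration, not part of this diff.

```python
import torch
import triton

# Hypothetical sketch -- not part of this diff. Allocate the kv-indices buffer
# and the MLA workspace once, sized for the worst case, so capture and replay
# reuse stable pointers instead of reallocating (a reallocation would leave the
# captured graph pointing at freed memory). PAGE_SIZE and
# cutlass_mla_get_workspace_size come from the module shown in the diff.
def init_cuda_graph_state(self, max_bs: int):
    # Widest paged-KV layout any request can need: one slot per page of the
    # longest supported context.
    max_seqlen_pad = triton.cdiv(self.max_context_len, PAGE_SIZE)
    self.cuda_graph_kv_indices = torch.full(
        (max_bs, max_seqlen_pad), -1, dtype=torch.int32, device="cuda"
    )
    # Workspace sized for the largest shapes any capture/replay will see.
    workspace_size = cutlass_mla_get_workspace_size(
        max_seqlen_pad * PAGE_SIZE, max_bs
    )
    self.cuda_graph_mla_workspace = torch.empty(
        workspace_size, device="cuda", dtype=torch.uint8
    )
```

This is also why the first hunk can read `max_seqlen_pad` back from `self.cuda_graph_kv_indices.shape[1]`: the padded width is fixed at allocation time rather than recomputed from the live `seq_lens` during capture.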