From 84810da4ae424d15ec06a3e32a03a839dd50a644 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Sun, 27 Apr 2025 20:58:53 -0700 Subject: [PATCH] Add Cutlass MLA attention backend (#5390) --- docs/backend/server_arguments.md | 2 +- .../layers/attention/cutlass_mla_backend.py | 278 ++++++++++++++++++ python/sglang/srt/layers/attention/utils.py | 2 +- python/sglang/srt/managers/schedule_batch.py | 1 + .../sglang/srt/model_executor/model_runner.py | 7 + python/sglang/srt/server_args.py | 15 +- sgl-kernel/python/sgl_kernel/attention.py | 3 + 7 files changed, 305 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/layers/attention/cutlass_mla_backend.py diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 2550513b0..13ed5074c 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -138,7 +138,7 @@ Please consult the documentation below to learn more about the parameters you ma ## Kernel backend -* `attention_backend`: This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend. +* `attention_backend`: This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, `cutlass_mla`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend. * `sampling_backend`: The backend for sampling. ## Constrained Decoding diff --git a/python/sglang/srt/layers/attention/cutlass_mla_backend.py b/python/sglang/srt/layers/attention/cutlass_mla_backend.py new file mode 100644 index 000000000..9ff5dfabf --- /dev/null +++ b/python/sglang/srt/layers/attention/cutlass_mla_backend.py @@ -0,0 +1,278 @@ +from __future__ import annotations + +""" +Support attention backend for Cutlass MLA. + +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +import torch +import triton + +from sglang.global_config import global_config +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend +from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import is_cuda + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + from sglang.srt.speculative.spec_info import SpecInfo + +_is_cuda = is_cuda() +if _is_cuda: + from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size + + +# Cutlass MLA only supports pagesize=128 +PAGE_SIZE = 128 + + +@dataclass +class CutlassMLADecodeMetadata: + workspace: Optional[torch.Tensor] = None + block_kv_indices: Optional[torch.Tensor] = None + + def __init__( + self, + workspace: Optional[torch.Tensor] = None, + block_kv_indices: Optional[torch.Tensor] = None, + ): + self.workspace = workspace + self.block_kv_indices = block_kv_indices + + +class CutlassMLABackend(FlashInferMLAAttnBackend): + """Cutlass attention kernels.""" + + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + kv_last_page_len_buf: Optional[torch.Tensor] = None, + ): + super().__init__( + model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf + ) + + self.num_q_heads = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.num_kv_heads = model_runner.model_config.get_num_kv_heads( + get_attention_tp_size() + ) + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.num_local_heads = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.forward_metadata: Union[CutlassMLADecodeMetadata] = None + self.kv_lora_rank = model_runner.model_config.kv_lora_rank + self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim + self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim + self.v_head_dim = model_runner.model_config.v_head_dim + self.scaling = model_runner.model_config.scaling + self.data_type = model_runner.kv_cache_dtype + self.q_data_type = model_runner.dtype + self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim + + def init_forward_metadata(self, forward_batch: ForwardBatch): + + bs = forward_batch.batch_size + spec_info = forward_batch.spec_info + if forward_batch.forward_mode.is_decode_or_idle(): + if spec_info is None: + max_seqlen_pad = triton.cdiv( + forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE + ) + block_kv_indices = torch.full( + (bs, max_seqlen_pad), + -1, + dtype=torch.int32, + device=forward_batch.seq_lens.device, + ) + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_seqlen_pad, + PAGE_SIZE, + ) + workspace_size = cutlass_mla_get_workspace_size( + max_seqlen_pad * PAGE_SIZE, bs + ) + workspace = torch.empty( + workspace_size, device="cuda", dtype=torch.uint8 + ) + self.forward_metadata = CutlassMLADecodeMetadata( + workspace, + block_kv_indices, + ) + else: + super().init_forward_metadata(forward_batch) + else: + super().init_forward_metadata(forward_batch) + + def init_cuda_graph_state( + self, + max_bs: int, + block_kv_indices: Optional[torch.Tensor] = None, + ): + if block_kv_indices is None: + cuda_graph_kv_indices = torch.full( + (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE), + 1, + dtype=torch.int32, + device="cuda", + ) + else: + cuda_graph_kv_indices = block_kv_indices + + workspace_size = cutlass_mla_get_workspace_size( + cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs + ) + self.cuda_graph_mla_workspace = torch.empty( + workspace_size, device="cuda", dtype=torch.uint8 + ) + self.cuda_graph_kv_indices = cuda_graph_kv_indices + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + ): + if forward_mode.is_decode_or_idle(): + if spec_info is None: + max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) + + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + None, + self.cuda_graph_kv_indices, + self.req_to_token.stride(0), + self.cuda_graph_kv_indices.stride(0), + PAGE_SIZE, + ) + workspace_size = cutlass_mla_get_workspace_size( + max_seqlen_pad * PAGE_SIZE, bs + ) + self.cuda_graph_mla_workspace = torch.empty( + workspace_size, device="cuda", dtype=torch.uint8 + ) + self.forward_metadata = CutlassMLADecodeMetadata( + self.cuda_graph_mla_workspace, + self.cuda_graph_kv_indices[:bs, :max_seqlen_pad], + ) + else: + super().init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + seq_lens_cpu: Optional[torch.Tensor], + ): + + if forward_mode.is_decode_or_idle(): + assert seq_lens_cpu is not None + seq_lens = seq_lens[:bs] + seq_lens_cpu = seq_lens_cpu[:bs] + max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE) + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens, + None, + self.cuda_graph_kv_indices, + self.req_to_token.stride(0), + self.cuda_graph_kv_indices.stride(0), + PAGE_SIZE, + ) + workspace_size = cutlass_mla_get_workspace_size( + max_seqlen_pad * PAGE_SIZE, bs + ) + self.cuda_graph_mla_workspace = torch.empty( + workspace_size, device="cuda", dtype=torch.uint8 + ) + self.forward_metadata.workspace = self.cuda_graph_mla_workspace + self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[ + :bs, :max_seqlen_pad + ] + else: + super().init_forward_metadata_replay_cuda_graph( + bs, + req_pool_indices, + seq_lens, + seq_lens_sum, + encoder_lens, + forward_mode, + spec_info, + seq_lens_cpu, + ) + + def get_cuda_graph_seq_len_fill_value(self): + return 1 + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + ): + cache_loc = forward_batch.out_cache_loc + + if k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + cache_loc, + k, + v, + ) + bs = forward_batch.batch_size + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim) + + o = cutlass_mla_decode( + q_nope_and_q_pe=reshape_q, + kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim), + seq_lens=forward_batch.seq_lens.to(torch.int32), + page_table=self.forward_metadata.block_kv_indices, + workspace=self.forward_metadata.workspace, + ) + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py index 29b64c24b..c87aa45d7 100644 --- a/python/sglang/srt/layers/attention/utils.py +++ b/python/sglang/srt/layers/attention/utils.py @@ -49,8 +49,8 @@ def create_flashmla_kv_indices_triton( kv_indices_ptr, req_to_token_ptr_stride: tl.constexpr, kv_indices_ptr_stride: tl.constexpr, + PAGED_SIZE: tl.constexpr = 64, ): - PAGED_SIZE: tl.constexpr = 64 BLOCK_SIZE: tl.constexpr = 4096 NUM_PAGE_PER_BLOCK: tl.constexpr = 64 pid = tl.program_id(axis=0) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 960b6e70b..314dbbd2e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1515,6 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ) or global_server_args_dict["attention_backend"] == "flashmla" or global_server_args_dict["attention_backend"] == "fa3" + or global_server_args_dict["attention_backend"] == "cutlass_mla" ): seq_lens_cpu = self.seq_lens.cpu() else: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 60ab0f36f..3052e924c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -271,6 +271,7 @@ class ModelRunner: "fa3", "triton", "flashmla", + "cutlass_mla", ]: logger.info( f"MLA optimization is turned on. Use {server_args.attention_backend} backend." @@ -926,6 +927,12 @@ class ModelRunner: ) self.attn_backend = FlashAttentionBackend(self) + elif self.server_args.attention_backend == "cutlass_mla": + from sglang.srt.layers.attention.cutlass_mla_backend import ( + CutlassMLABackend, + ) + + self.attn_backend = CutlassMLABackend(self) else: raise ValueError( f"Invalid attention backend: {self.server_args.attention_backend}" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0371e1c52..fa0d05ed2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -256,6 +256,12 @@ class ServerArgs: ) self.page_size = 64 + if self.attention_backend == "cutlass_mla": + logger.warning( + "Cutlass MLA only supports a page_size of 128, change page_size to 128." + ) + self.page_size = 128 + # Set cuda graph max batch size if self.cuda_graph_max_bs is None: # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues. @@ -823,7 +829,14 @@ class ServerArgs: parser.add_argument( "--attention-backend", type=str, - choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"], + choices=[ + "flashinfer", + "triton", + "torch_native", + "fa3", + "flashmla", + "cutlass_mla", + ], default=ServerArgs.attention_backend, help="Choose the kernels for attention layers.", ) diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py index 52b2d62af..749462ccd 100644 --- a/sgl-kernel/python/sgl_kernel/attention.py +++ b/sgl-kernel/python/sgl_kernel/attention.py @@ -78,6 +78,7 @@ def cutlass_mla_decode( assert len(page_table.shape) == 2 B_block_table, block_num = page_table.shape assert B_block_table == B_q + assert block_num > 0, f"block num must be greater than 0, got {block_num}" assert block_num % (128 / PAGE_SIZE) == 0 # TODO(kaixih@nvidia): support fp8 @@ -109,6 +110,8 @@ def cutlass_mla_decode( def cutlass_mla_get_workspace_size( max_seq_len: int, num_batches: int, sm_count: int = 0 ) -> int: + assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}" + assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}" return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default( max_seq_len, num_batches, sm_count )