From b6944f97a616e596f4195ba28f8e0e72b82d25dd Mon Sep 17 00:00:00 2001 From: lukec <118525388+sleepcoo@users.noreply.github.com> Date: Wed, 19 Mar 2025 23:25:34 +0800 Subject: [PATCH] Support FlashMLA backend cuda graph (#4514) Co-authored-by: yinfan98 <1106310035@qq.com> Co-authored-by: Hongbosherlock Co-authored-by: ispobock --- .../srt/layers/attention/flashmla_backend.py | 214 +++++++++++++++--- python/sglang/srt/layers/attention/utils.py | 1 - python/sglang/srt/server_args.py | 5 +- 3 files changed, 188 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py index 61a67a8d8..1e711e647 100644 --- a/python/sglang/srt/layers/attention/flashmla_backend.py +++ b/python/sglang/srt/layers/attention/flashmla_backend.py @@ -1,16 +1,13 @@ from __future__ import annotations """ -Support attention backend for flashMLA. +Support attention backend for FlashMLA. -Current initial integration of FlashMLA shows normal accuracy, but performance is slightly lacking. #TODO -Support FlashMLA decode with cudagraph Enable speculative sampling in FlashMLA -Integrate FA3 prefill """ - +from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union import torch @@ -28,10 +25,30 @@ if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + from sglang.srt.speculative.spec_info import SpecInfo # FlashMLA only supports pagesize=64 PAGE_SIZE = 64 +# TODO The current setup is hard-coded and will be changed after integrating with MTP. +Q_LEN = 1 + + +@dataclass +class FlashMLADecodeMetadata: + flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + num_splits: Optional[torch.Tensor] = None + block_kv_indices: Optional[torch.Tensor] = None + + def __init__( + self, + flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + num_splits: Optional[torch.Tensor] = None, + block_kv_indices: Optional[torch.Tensor] = None, + ): + self.flashmla_metadata = flashmla_metadata + self.num_splits = num_splits + self.block_kv_indices = block_kv_indices class FlashMLABackend(FlashInferMLAAttnBackend): @@ -58,6 +75,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend): self.num_local_heads = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() ) + self.forward_metadata: Union[FlashMLADecodeMetadata] = None self.kv_lora_rank = model_runner.model_config.kv_lora_rank self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim @@ -67,6 +85,163 @@ class FlashMLABackend(FlashInferMLAAttnBackend): self.q_data_type = model_runner.dtype self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim + def init_forward_metadata(self, forward_batch: ForwardBatch): + + bs = forward_batch.batch_size + spec_info = forward_batch.spec_info + if forward_batch.forward_mode.is_decode_or_idle(): + if spec_info is None: + max_seqlen_pad = triton.cdiv( + forward_batch.seq_lens.max().item(), PAGE_SIZE + ) + block_kv_indices = torch.full( + (bs, max_seqlen_pad), + -1, + dtype=torch.int32, + device=forward_batch.seq_lens.device, + ) + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_seqlen_pad, + ) + mla_metadata, num_splits = get_mla_metadata( + forward_batch.seq_lens.to(torch.int32), + Q_LEN * self.num_q_heads // self.num_kv_heads, + self.num_kv_heads, + ) + self.forward_metadata = FlashMLADecodeMetadata( + mla_metadata, + num_splits, + block_kv_indices, + ) + else: + super().init_forward_metadata(forward_batch) + else: + super().init_forward_metadata(forward_batch) + + def init_cuda_graph_state( + self, + max_bs: int, + block_kv_indices: Optional[torch.Tensor] = None, + ): + if block_kv_indices is None: + cuda_graph_kv_indices = torch.full( + (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE), + 1, + dtype=torch.int32, + device="cuda", + ) + else: + cuda_graph_kv_indices = block_kv_indices + + self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata( + torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device), + Q_LEN * self.num_q_heads // self.num_kv_heads, + self.num_kv_heads, + ) + self.cuda_graph_kv_indices = cuda_graph_kv_indices + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + ): + if forward_mode.is_decode_or_idle(): + if spec_info is None: + max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) + + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + None, + self.cuda_graph_kv_indices, + self.req_to_token.stride(0), + self.cuda_graph_kv_indices.stride(0), + ) + mla_metadata, num_splits = get_mla_metadata( + seq_lens.to(torch.int32), + Q_LEN * self.num_q_heads // self.num_kv_heads, + self.num_kv_heads, + ) + self.cuda_graph_mla_metadata.copy_(mla_metadata) + self.cuda_graph_num_splits[: bs + 1].copy_(num_splits) + self.forward_metadata = FlashMLADecodeMetadata( + self.cuda_graph_mla_metadata, + self.cuda_graph_num_splits[: bs + 1], + self.cuda_graph_kv_indices[:bs, :max_seqlen_pad], + ) + + else: + super().init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + seq_lens_cpu: Optional[torch.Tensor], + ): + + if forward_mode.is_decode_or_idle(): + seq_lens = seq_lens[:bs] + max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens, + None, + self.cuda_graph_kv_indices, + self.req_to_token.stride(0), + self.cuda_graph_kv_indices.stride(0), + ) + mla_metadata, num_splits = get_mla_metadata( + seq_lens.to(torch.int32), + Q_LEN * self.num_q_heads // self.num_kv_heads, + self.num_kv_heads, + ) + self.cuda_graph_mla_metadata.copy_(mla_metadata) + self.cuda_graph_num_splits[: bs + 1].copy_(num_splits) + self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata + self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1] + self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[ + :bs, :max_seqlen_pad + ] + + else: + super().init_forward_metadata_replay_cuda_graph( + bs, + req_pool_indices, + seq_lens, + seq_lens_sum, + encoder_lens, + forward_mode, + spec_info, + seq_lens_cpu, + ) + def forward_decode( self, q: torch.Tensor, @@ -88,39 +263,18 @@ class FlashMLABackend(FlashInferMLAAttnBackend): v, ) bs = forward_batch.batch_size - - max_seqlen_pad = triton.cdiv(forward_batch.seq_lens.max().item(), PAGE_SIZE) - flashmla_index = torch.full( - (bs, max_seqlen_pad), -1, dtype=torch.int32, device=q.device - ) - create_flashmla_kv_indices_triton[(bs,)]( - self.indices_updater_decode.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - None, - flashmla_index, - self.indices_updater_decode.req_to_token.size(1), - flashmla_index.size(1), - max_seqlen_pad, - ) - - mla_metadata, mla_splits = get_mla_metadata( - forward_batch.seq_lens.to(torch.int32), - 1 * self.num_q_heads // self.num_kv_heads, - self.num_kv_heads, - ) - k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim) o, _ = flash_mla_with_kvcache( q=reshape_q, k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim), - block_table=flashmla_index, + block_table=self.forward_metadata.block_kv_indices, cache_seqlens=forward_batch.seq_lens.to(torch.int32), head_dim_v=self.kv_lora_rank, # TODO Retrieve from config. - tile_scheduler_metadata=mla_metadata, - num_splits=mla_splits, + tile_scheduler_metadata=self.forward_metadata.flashmla_metadata, + num_splits=self.forward_metadata.num_splits, softmax_scale=layer.scaling, causal=False, ) diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py index 7fdee7bfe..29b64c24b 100644 --- a/python/sglang/srt/layers/attention/utils.py +++ b/python/sglang/srt/layers/attention/utils.py @@ -49,7 +49,6 @@ def create_flashmla_kv_indices_triton( kv_indices_ptr, req_to_token_ptr_stride: tl.constexpr, kv_indices_ptr_stride: tl.constexpr, - max_pagesize: tl.constexpr, ): PAGED_SIZE: tl.constexpr = 64 BLOCK_SIZE: tl.constexpr = 4096 diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 144b8b6ce..5c1584c8c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -232,7 +232,10 @@ class ServerArgs: assert self.chunked_prefill_size % self.page_size == 0 if self.enable_flashmla is True: - assert self.page_size == 64, "FlashMLA only support page_size=64" + logger.warning( + "FlashMLA only supports a page_size of 64, change page_size to 64." + ) + self.page_size = 64 # Set cuda graph max batch size if self.cuda_graph_max_bs is None: # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.