From fc91d08a8f0ef6458bc640bee4de7cf06234f098 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Wed, 5 Mar 2025 11:20:41 -0800
Subject: [PATCH] [Revision] Add fast decode plan for flashinfer mla (#4012)

---
 docs/backend/server_arguments.md              |   2 +-
 docs/references/deepseek.md                   |   2 +-
 .../srt/layers/attention/base_attn_backend.py |   1 +
 .../layers/attention/flashinfer_backend.py    |   2 +
 .../attention/flashinfer_mla_backend.py       | 148 ++++++++++++++----
 .../srt/layers/attention/triton_backend.py    |   2 +
 python/sglang/srt/managers/schedule_batch.py  |   9 ++
 .../srt/model_executor/cuda_graph_runner.py   |   8 +
 .../srt/model_executor/forward_batch_info.py  |   5 +
 9 files changed, 145 insertions(+), 34 deletions(-)

diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 9777b9b4c..10872f0ed 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -184,5 +184,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
-* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models.
+* `enable_flashinfer_mla`: Use the attention backend with the flashinfer MLA wrapper for deepseek models. When this argument is provided, the `attention_backend` argument is overridden.
 * `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on.
diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md
index 058fd6ae0..b79675a1c 100644
--- a/docs/references/deepseek.md
+++ b/docs/references/deepseek.md
@@ -83,7 +83,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
 
 - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
 
-- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). (In Experiment Stage)
+- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). In long-input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels are used when flashinfer mla is turned off.
 
 - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
 
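The heart of the code changes below is that, after CUDA graph capture, the decode wrapper's `plan` method is rebound with `functools.partial` to a `fast_mla_decode_plan` that refreshes planning metadata from pre-allocated CPU-side buffers (`qo_indptr_cpu`, `kv_indptr_cpu`, `kv_len_arr_cpu`) and skips the stream synchronization of the full planning path. The snippet below is a minimal, self-contained sketch of that rebinding pattern only; `DummyMLAWrapper` and `fast_plan` are hypothetical stand-ins for illustration, not the real flashinfer API.

# Simplified stand-ins for illustration; not the real flashinfer classes.
from functools import partial

import torch


class DummyMLAWrapper:
    """Models only the planning state of a paged-attention wrapper."""

    def __init__(self, max_bs: int):
        # Buffers allocated once, before CUDA graph capture.
        self.qo_indptr_cpu = torch.arange(0, max_bs + 1, dtype=torch.int32)
        self.kv_indptr_cpu = torch.zeros(max_bs + 1, dtype=torch.int32)

    def plan(self, kv_len_arr_cpu: torch.Tensor) -> None:
        # Full plan: rebuilds indptr metadata and, in the real wrapper,
        # synchronizes with the device before kernel planning.
        bs = kv_len_arr_cpu.numel()
        self.kv_indptr_cpu[1 : bs + 1] = torch.cumsum(kv_len_arr_cpu, dim=0)


def fast_plan(self: DummyMLAWrapper, kv_len_arr_cpu: torch.Tensor) -> None:
    # Replay-time plan: only refresh the CPU-side metadata in place,
    # skipping the synchronization the full plan would perform.
    bs = kv_len_arr_cpu.numel()
    self.kv_indptr_cpu[1 : bs + 1] = torch.cumsum(kv_len_arr_cpu, dim=0)


wrapper = DummyMLAWrapper(max_bs=4)
wrapper.plan(torch.tensor([3, 5], dtype=torch.int32))  # capture path: full plan
wrapper.plan = partial(fast_plan, wrapper)             # rebind, as the patch does
wrapper.plan(torch.tensor([4, 6], dtype=torch.int32))  # replay path: fast plan

In the backend itself, `init_forward_metadata_replay_cuda_graph` receives the new `seq_lens_cpu` tensor (populated from `decode_seq_lens` in the scheduler batch) and updates `kv_indptr_cpu` with a CPU-side `torch.cumsum`, so graph replay never has to read sequence lengths back from the GPU.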
diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py index e0e688ce5..8364a82ca 100644 --- a/python/sglang/srt/layers/attention/base_attn_backend.py +++ b/python/sglang/srt/layers/attention/base_attn_backend.py @@ -45,6 +45,7 @@ class AttentionBackend(ABC): encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index b48fecb26..acec75241 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -353,6 +353,7 @@ class FlashInferAttnBackend(AttentionBackend): encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], ): if forward_mode.is_decode_or_idle(): self.indices_updater_decode.update( @@ -1058,6 +1059,7 @@ class FlashInferMultiStepDraftBackend: encoder_lens=None, forward_mode=ForwardMode.DECODE, spec_info=forward_batch.spec_info, + seq_lens_cpu=None, ) self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn) diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 96f0b4f83..9e81acc6f 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -10,6 +10,7 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html """ from dataclasses import dataclass +from functools import partial from typing import TYPE_CHECKING, Optional, Union import torch @@ -62,6 +63,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): # Parse constants self.max_context_len = model_runner.model_config.context_len + self.device = model_runner.device global_config.enable_flashinfer_mla = True @@ -84,10 +86,6 @@ class FlashInferMLAAttnBackend(AttentionBackend): (max_bs + 1,), dtype=torch.int32, device=model_runner.device ) - self.kv_last_page_len = torch.ones( - (max_bs,), dtype=torch.int32, device=model_runner.device - ) - self.q_indptr_decode = torch.arange( 0, max_bs + 1, dtype=torch.int32, device=model_runner.device ) @@ -125,6 +123,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): forward_batch.seq_lens, forward_batch.seq_lens_sum, decode_wrapper=self.decode_wrapper, + init_metadata_replay=False, ) self.forward_metadata = DecodeMetadata(self.decode_wrapper) else: @@ -160,13 +159,20 @@ class FlashInferMLAAttnBackend(AttentionBackend): cuda_graph_kv_indices = kv_indices_buf self.cuda_graph_kv_indices = cuda_graph_kv_indices - self.cuda_graph_custom_mask = torch.zeros( - (max_bs * self.max_context_len), - dtype=torch.uint8, - device="cuda", + self.cuda_graph_qo_indptr = self.q_indptr_decode.clone() + self.cuda_graph_kv_indptr = self.kv_indptr.clone() + self.cuda_graph_kv_lens = torch.ones( + (max_bs,), dtype=torch.int32, device=self.device ) - self.cuda_graph_qk_indptr = self.kv_indptr.clone() - self.cuda_graph_qo_indptr = self.kv_indptr.clone() + + # For fast decode plan in graph replaying + self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu") + self.cuda_graph_kv_indptr_cpu = 
self.cuda_graph_kv_indptr.to("cpu") + self.fast_decode_kwargs = { + "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu, + "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu, + "kv_indices": self.cuda_graph_kv_indices, + } def init_forward_metadata_capture_cuda_graph( self, @@ -182,10 +188,10 @@ class FlashInferMLAAttnBackend(AttentionBackend): decode_wrapper = BatchMLAPagedAttentionWrapper( self.workspace_buffer, use_cuda_graph=True, - qo_indptr=self.qo_indptr[: num_tokens + 1], - kv_indptr=self.kv_indptr[: num_tokens + 1], + qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1], + kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1], kv_indices=self.cuda_graph_kv_indices, - kv_len_arr=self.kv_last_page_len[:num_tokens], + kv_len_arr=self.cuda_graph_kv_lens[:num_tokens], backend="auto", ) @@ -195,9 +201,11 @@ class FlashInferMLAAttnBackend(AttentionBackend): seq_lens, seq_lens_sum, decode_wrapper=decode_wrapper, + init_metadata_replay=False, ) self.decode_cuda_graph_metadata[bs] = decode_wrapper self.forward_metadata = DecodeMetadata(decode_wrapper) + decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper) else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -210,13 +218,28 @@ class FlashInferMLAAttnBackend(AttentionBackend): encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, spec_info: Optional[SpecInfo], + seq_lens_cpu: Optional[torch.Tensor], ): if forward_mode.is_decode_or_idle(): + kv_len_arr_cpu = seq_lens_cpu[:bs] + self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum( + kv_len_arr_cpu, dim=0 + ) + self.fast_decode_kwargs.update( + { + "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1], + "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1], + "kv_len_arr_cpu": kv_len_arr_cpu, + } + ) + self.indices_updater_decode.update( req_pool_indices[:bs], seq_lens[:bs], seq_lens_sum, decode_wrapper=self.decode_cuda_graph_metadata[bs], + init_metadata_replay=True, + **self.fast_decode_kwargs, ) else: raise ValueError(f"Invalid forward mode: {forward_mode=}") @@ -316,7 +339,6 @@ class FlashInferMLAIndicesUpdaterDecode: # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr - self.kv_last_page_len = attn_backend.kv_last_page_len self.req_to_token = model_runner.req_to_token_pool.req_to_token self.q_indptr = attn_backend.q_indptr_decode @@ -326,6 +348,8 @@ class FlashInferMLAIndicesUpdaterDecode: seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrapper: BatchMLAPagedAttentionWrapper, + init_metadata_replay: bool = False, + **fast_decode_kwargs, ): decode_wrapper = decode_wrapper or self.decode_wrapper self.call_begin_forward( @@ -335,6 +359,8 @@ class FlashInferMLAIndicesUpdaterDecode: seq_lens_sum, self.q_indptr, self.kv_indptr, + init_metadata_replay, + **fast_decode_kwargs, ) def call_begin_forward( @@ -345,14 +371,19 @@ class FlashInferMLAIndicesUpdaterDecode: paged_kernel_lens_sum: int, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, + init_metadata_replay: bool = False, + **fast_decode_kwargs, ): bs = len(req_pool_indices) q_indptr = q_indptr[: bs + 1] kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + kv_indices = ( + torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda") + if not init_metadata_replay + else fast_decode_kwargs["kv_indices"] ) + kv_lens = paged_kernel_lens.to(torch.int32) sm_scale = self.scaling @@ -365,21 +396,36 @@ class FlashInferMLAIndicesUpdaterDecode: kv_indices, 
self.req_to_token.shape[1], ) - - wrapper.plan( - q_indptr, - kv_indptr, - kv_indices, - kv_lens, - self.num_local_heads, - self.kv_lora_rank, - self.qk_rope_head_dim, - 1, - False, - sm_scale, - self.data_type, - self.data_type, - ) + if not init_metadata_replay: + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens, + self.num_local_heads, + self.kv_lora_rank, + self.qk_rope_head_dim, + 1, + False, + sm_scale, + self.data_type, + self.data_type, + ) + else: + wrapper.plan( + fast_decode_kwargs["qo_indptr_cpu"], + fast_decode_kwargs["kv_indptr_cpu"], + kv_indices, + fast_decode_kwargs["kv_len_arr_cpu"], + self.num_local_heads, + self.kv_lora_rank, + self.qk_rope_head_dim, + 1, + False, + sm_scale, + self.data_type, + self.data_type, + ) class FlashInferMLAIndicesUpdaterPrefill: @@ -399,7 +445,6 @@ class FlashInferMLAIndicesUpdaterPrefill: # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr - self.kv_last_page_len = attn_backend.kv_last_page_len self.qo_indptr = attn_backend.qo_indptr self.req_to_token = model_runner.req_to_token_pool.req_to_token self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged @@ -496,3 +541,42 @@ class FlashInferMLAIndicesUpdaterPrefill: self.q_data_type, self.data_type, ) + + +def fast_mla_decode_plan( + self, + qo_indptr_cpu: torch.Tensor, + kv_indptr_cpu: torch.Tensor, + kv_indices: torch.Tensor, + kv_len_arr_cpu: torch.Tensor, + num_heads: int, + head_dim_ckv: int, + head_dim_kpe: int, + page_size: int, + causal: bool, + sm_scale: float, + q_data_type: torch.dtype, + kv_data_type: torch.dtype, +) -> None: + """A faster version of BatchMLAPagedAttentionWrapper::plan, + for skipping the stream synchronization in original plan function during + cuda graph replaying. + """ + self._causal = causal + self._page_size = page_size + self._sm_scale = sm_scale + + with self.device as device: + stream = torch.cuda.current_stream(device).cuda_stream + self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_cpu, + kv_indptr_cpu, + kv_len_arr_cpu, + num_heads, + head_dim_ckv, + causal, + stream, + ) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index a8cf3abb8..d07baf102 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -312,6 +312,7 @@ class TritonAttnBackend(AttentionBackend): encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], ): # NOTE: encoder_lens expected to be zeros or None if forward_mode.is_decode_or_idle(): @@ -587,6 +588,7 @@ class TritonMultiStepDraftBackend: encoder_lens=None, forward_mode=ForwardMode.DECODE, spec_info=forward_batch.spec_info, + seq_lens_cpu=None, ) self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3f09915ba..5cd2de6b8 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1187,8 +1187,13 @@ class ScheduleBatch: def get_model_worker_batch(self) -> ModelWorkerBatch: if self.forward_mode.is_decode_or_idle(): + if global_server_args_dict["enable_flashinfer_mla"]: + decode_seq_lens = self.seq_lens.cpu() + else: + decode_seq_lens = None extend_seq_lens = extend_prefix_lens = 
extend_logprob_start_lens = None else: + decode_seq_lens = None extend_seq_lens = self.extend_lens extend_prefix_lens = self.prefix_lens extend_logprob_start_lens = self.extend_logprob_start_lens @@ -1215,6 +1220,7 @@ class ScheduleBatch: global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, + decode_seq_lens=decode_seq_lens, extend_num_tokens=self.extend_num_tokens, extend_seq_lens=extend_seq_lens, extend_prefix_lens=extend_prefix_lens, @@ -1291,6 +1297,9 @@ class ModelWorkerBatch: global_num_tokens_for_logprob: Optional[List[int]] can_run_dp_cuda_graph: bool + # For decode + decode_seq_lens: Optional[torch.Tensor] + # For extend extend_num_tokens: Optional[int] extend_seq_lens: Optional[List[int]] diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 31dc258dd..5bf4faa21 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -200,6 +200,9 @@ class CudaGraphRunner: ) # FIXME(lsyin): leave it here for now, I don't know whether it is necessary self.encoder_len_fill_value = 0 + self.seq_lens_cpu = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) if self.enable_torch_compile: set_torch_compile_config() @@ -448,6 +451,10 @@ class CudaGraphRunner: self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) self.positions[:raw_num_token].copy_(forward_batch.positions) + if forward_batch.decode_seq_lens_cpu is not None: + if bs != raw_bs: + self.seq_lens_cpu.fill_(1) + self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu) if self.is_encoder_decoder: self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) @@ -466,6 +473,7 @@ class CudaGraphRunner: self.encoder_lens, forward_batch.forward_mode, forward_batch.spec_info, + seq_lens_cpu=self.seq_lens_cpu, ) # Replay diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 7f09cb529..826fd306b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -156,6 +156,9 @@ class ForwardBatch: # Position information positions: torch.Tensor = None + # For decode + decode_seq_lens_cpu: Optional[torch.Tensor] = None + # For extend extend_num_tokens: Optional[int] = None extend_seq_lens: Optional[torch.Tensor] = None @@ -280,6 +283,8 @@ class ForwardBatch: if ret.forward_mode.is_decode(): if ret.positions is None: ret.positions = clamp_position(batch.seq_lens) + if ret.decode_seq_lens_cpu is None: + ret.decode_seq_lens_cpu = batch.decode_seq_lens else: ret.extend_seq_lens = torch.tensor( batch.extend_seq_lens, dtype=torch.int32