diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index db8bb9514..0b055b101 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -186,5 +186,5 @@ Please consult the documentation below to learn more about the parameters you ma * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you. * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8. -* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models. +* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. When providing this argument, `attention_backend` argument is overridden. * `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on. diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index ad180d1bd..5f1dbf2e7 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -83,7 +83,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. -- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). (In Experiment Stage) +- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off. - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption. diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 745598643..23b02342a 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -29,9 +29,8 @@ class AttentionBackend(ABC): num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + **kwargs, ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() @@ -42,9 +41,8 @@ class AttentionBackend(ABC): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + **kwargs, ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 7063b6f4b..bbcb011dc 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -269,9 +269,10 @@ class FlashInferAttnBackend(AttentionBackend): num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, + encoder_lens: Optional[torch.Tensor], spec_info: Optional[SpecInfo], + **kwargs, ): if forward_mode.is_decode_or_idle(): decode_wrappers = [] @@ -339,9 +340,10 @@ class FlashInferAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, + encoder_lens: Optional[torch.Tensor], spec_info: Optional[SpecInfo], + **kwargs, ): if forward_mode.is_decode_or_idle(): self.indices_updater_decode.update( diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index e7088df5c..f4540ed40 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -10,6 +10,7 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html """ from dataclasses import dataclass +from functools import partial from typing import TYPE_CHECKING, Optional, Union import torch @@ -27,14 +28,12 @@ from sglang.srt.utils import is_flashinfer_available if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.spec_info import SpecInfo if is_flashinfer_available(): from flashinfer import ( BatchMLAPagedAttentionWrapper, BatchPrefillWithRaggedKVCacheWrapper, ) - from flashinfer.cascade import merge_state @dataclass @@ -63,6 +62,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): # Parse constants self.max_context_len = model_runner.model_config.context_len + self.device = model_runner.device global_config.enable_flashinfer_mla = True @@ -85,10 +85,6 @@ class FlashInferMLAAttnBackend(AttentionBackend): (max_bs + 1,), dtype=torch.int32, device=model_runner.device ) - self.kv_last_page_len = torch.ones( - (max_bs,), dtype=torch.int32, device=model_runner.device - ) - self.q_indptr_decode = torch.arange( 0, max_bs + 1, dtype=torch.int32, device=model_runner.device ) @@ -126,6 +122,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): forward_batch.seq_lens, forward_batch.seq_lens_sum, decode_wrapper=self.decode_wrapper, + init_metadata_replay=False, ) self.forward_metadata = DecodeMetadata(self.decode_wrapper) else: @@ -161,13 +158,20 @@ class FlashInferMLAAttnBackend(AttentionBackend): cuda_graph_kv_indices = kv_indices_buf self.cuda_graph_kv_indices = cuda_graph_kv_indices - self.cuda_graph_custom_mask = torch.zeros( - (max_bs * self.max_context_len), - dtype=torch.uint8, - device="cuda", + self.cuda_graph_qo_indptr = self.q_indptr_decode.clone() + self.cuda_graph_kv_indptr = self.kv_indptr.clone() + self.cuda_graph_kv_lens = torch.ones( + (max_bs,), dtype=torch.int32, device=self.device ) - self.cuda_graph_qk_indptr = self.kv_indptr.clone() - self.cuda_graph_qo_indptr = self.kv_indptr.clone() + + # For fast decode plan in graph replaying + self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu") + self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu") + self.fast_decode_kwargs = { + "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu, + "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu, + "kv_indices": self.cuda_graph_kv_indices, + } def init_forward_metadata_capture_cuda_graph( self, @@ -175,18 +179,17 @@ class FlashInferMLAAttnBackend(AttentionBackend): num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + **kwargs, ): if forward_mode.is_decode_or_idle(): decode_wrapper = BatchMLAPagedAttentionWrapper( self.workspace_buffer, use_cuda_graph=True, - qo_indptr=self.qo_indptr[: num_tokens + 1], - kv_indptr=self.kv_indptr[: num_tokens + 1], + qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1], + kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1], kv_indices=self.cuda_graph_kv_indices, - kv_len_arr=self.kv_last_page_len[:num_tokens], + kv_len_arr=self.cuda_graph_kv_lens[:num_tokens], backend="auto", ) @@ -196,9 +199,11 @@ class FlashInferMLAAttnBackend(AttentionBackend): seq_lens, seq_lens_sum, decode_wrapper=decode_wrapper, + init_metadata_replay=False, ) self.decode_cuda_graph_metadata[bs] = decode_wrapper self.forward_metadata = DecodeMetadata(decode_wrapper) + decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper) else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -208,16 +213,30 @@ class FlashInferMLAAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + seq_lens_cpu: torch.Tensor, + **kwargs, ): if forward_mode.is_decode_or_idle(): + kv_len_arr_cpu = seq_lens_cpu[:bs] + self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum( + kv_len_arr_cpu, dim=0 + ) + self.fast_decode_kwargs.update( + { + "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1], + "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1], + "kv_len_arr_cpu": kv_len_arr_cpu, + } + ) + self.indices_updater_decode.update( req_pool_indices[:bs], seq_lens[:bs], seq_lens_sum, decode_wrapper=self.decode_cuda_graph_metadata[bs], + init_metadata_replay=True, + **self.fast_decode_kwargs, ) else: raise ValueError(f"Invalid forward mode: {forward_mode=}") @@ -317,7 +336,6 @@ class FlashInferMLAIndicesUpdaterDecode: # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr - self.kv_last_page_len = attn_backend.kv_last_page_len self.req_to_token = model_runner.req_to_token_pool.req_to_token self.q_indptr = attn_backend.q_indptr_decode @@ -327,6 +345,8 @@ class FlashInferMLAIndicesUpdaterDecode: seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrapper: BatchMLAPagedAttentionWrapper, + init_metadata_replay: bool = False, + **fast_decode_kwargs, ): decode_wrapper = decode_wrapper or self.decode_wrapper self.call_begin_forward( @@ -336,6 +356,8 @@ class FlashInferMLAIndicesUpdaterDecode: seq_lens_sum, self.q_indptr, self.kv_indptr, + init_metadata_replay, + **fast_decode_kwargs, ) def call_begin_forward( @@ -346,14 +368,19 @@ class FlashInferMLAIndicesUpdaterDecode: paged_kernel_lens_sum: int, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, + init_metadata_replay: bool = False, + **fast_decode_kwargs, ): bs = len(req_pool_indices) q_indptr = q_indptr[: bs + 1] kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + kv_indices = ( + torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda") + if not init_metadata_replay + else fast_decode_kwargs["kv_indices"] ) + kv_lens = paged_kernel_lens.to(torch.int32) sm_scale = self.scaling @@ -366,21 +393,36 @@ class FlashInferMLAIndicesUpdaterDecode: kv_indices, self.req_to_token.shape[1], ) - - wrapper.plan( - q_indptr, - kv_indptr, - kv_indices, - kv_lens, - self.num_local_heads, - self.kv_lora_rank, - self.qk_rope_head_dim, - 1, - False, - sm_scale, - self.data_type, - self.data_type, - ) + if not init_metadata_replay: + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens, + self.num_local_heads, + self.kv_lora_rank, + self.qk_rope_head_dim, + 1, + False, + sm_scale, + self.data_type, + self.data_type, + ) + else: + wrapper.plan( + fast_decode_kwargs["qo_indptr_cpu"], + fast_decode_kwargs["kv_indptr_cpu"], + kv_indices, + fast_decode_kwargs["kv_len_arr_cpu"], + self.num_local_heads, + self.kv_lora_rank, + self.qk_rope_head_dim, + 1, + False, + sm_scale, + self.data_type, + self.data_type, + ) class FlashInferMLAIndicesUpdaterPrefill: @@ -400,7 +442,6 @@ class FlashInferMLAIndicesUpdaterPrefill: # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr - self.kv_last_page_len = attn_backend.kv_last_page_len self.qo_indptr = attn_backend.qo_indptr self.req_to_token = model_runner.req_to_token_pool.req_to_token self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged @@ -497,3 +538,42 @@ class FlashInferMLAIndicesUpdaterPrefill: self.q_data_type, self.data_type, ) + + +def fast_mla_decode_plan( + self, + qo_indptr_cpu: torch.Tensor, + kv_indptr_cpu: torch.Tensor, + kv_indices: torch.Tensor, + kv_len_arr_cpu: torch.Tensor, + num_heads: int, + head_dim_ckv: int, + head_dim_kpe: int, + page_size: int, + causal: bool, + sm_scale: float, + q_data_type: torch.dtype, + kv_data_type: torch.dtype, +) -> None: + """A faster version of BatchMLAPagedAttentionWrapper::plan, + for skipping the stream synchronization in original plan function during + cuda graph replaying. + """ + self._causal = causal + self._page_size = page_size + self._sm_scale = sm_scale + + with self.device as device: + stream = torch.cuda.current_stream(device).cuda_stream + self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_cpu, + kv_indptr_cpu, + kv_len_arr_cpu, + num_heads, + head_dim_ckv, + causal, + stream, + ) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 7bb6615ed..ce68af89d 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -230,9 +230,10 @@ class TritonAttnBackend(AttentionBackend): num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, + encoder_lens: Optional[torch.Tensor], spec_info: Optional[SpecInfo], + **kwargs, ): assert encoder_lens is None, "Not supported" @@ -308,9 +309,10 @@ class TritonAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, + encoder_lens: Optional[torch.Tensor], spec_info: Optional[SpecInfo], + **kwargs, ): # NOTE: encoder_lens expected to be zeros or None if forward_mode.is_decode_or_idle(): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3d34cefb1..ffb08fdb5 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -582,6 +582,9 @@ class ScheduleBatch: return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None + # For decode + decode_seq_lens: List[int] = None + # For extend and mixed chunekd prefill prefix_lens: List[int] = None extend_lens: List[int] = None @@ -1168,8 +1171,10 @@ class ScheduleBatch: def get_model_worker_batch(self): if self.forward_mode.is_decode_or_idle(): + decode_seq_lens = self.seq_lens.cpu() extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None else: + decode_seq_lens = None extend_seq_lens = self.extend_lens extend_prefix_lens = self.prefix_lens extend_logprob_start_lens = self.extend_logprob_start_lens @@ -1194,6 +1199,7 @@ class ScheduleBatch: top_logprobs_nums=self.top_logprobs_nums, global_num_tokens=self.global_num_tokens, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, + decode_seq_lens=decode_seq_lens, extend_num_tokens=self.extend_num_tokens, extend_seq_lens=extend_seq_lens, extend_prefix_lens=extend_prefix_lens, @@ -1267,6 +1273,9 @@ class ModelWorkerBatch: global_num_tokens: Optional[List[int]] can_run_dp_cuda_graph: bool + # For decode + decode_seq_lens: Optional[torch.Tensor] + # For extend extend_num_tokens: Optional[int] extend_seq_lens: Optional[List[int]] diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index dd1b0da94..fd9280239 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -199,6 +199,10 @@ class CudaGraphRunner: if self.enable_torch_compile: set_torch_compile_config() + self.seq_lens_cpu = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) + # Graph inputs with torch.device("cuda"): self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) @@ -369,9 +373,9 @@ class CudaGraphRunner: num_tokens, req_pool_indices, seq_lens, - encoder_lens, forward_batch.forward_mode, - forward_batch.spec_info, + encoder_lens=encoder_lens, + spec_info=forward_batch.spec_info, ) # Run and capture @@ -434,6 +438,7 @@ class CudaGraphRunner: if bs != raw_bs: self.seq_lens.fill_(1) self.out_cache_loc.zero_() + self.seq_lens_cpu.fill_(1) # Common inputs self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) @@ -441,6 +446,8 @@ class CudaGraphRunner: self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) self.positions[:raw_num_token].copy_(forward_batch.positions) + if forward_batch.decode_seq_lens_cpu is not None: + self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu) if self.is_encoder_decoder: self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) @@ -456,9 +463,10 @@ class CudaGraphRunner: self.req_pool_indices, self.seq_lens, forward_batch.seq_lens_sum + (bs - raw_bs), - self.encoder_lens, forward_batch.forward_mode, - forward_batch.spec_info, + encoder_lens=self.encoder_lens, + spec_info=forward_batch.spec_info, + seq_lens_cpu=self.seq_lens_cpu, ) # Replay diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index cdd03bec4..92ef9a3fa 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -152,6 +152,9 @@ class ForwardBatch: # Position information positions: torch.Tensor = None + # For decode + decode_seq_lens_cpu: Optional[torch.Tensor] = None + # For extend extend_num_tokens: Optional[int] = None extend_seq_lens: Optional[torch.Tensor] = None @@ -256,6 +259,8 @@ class ForwardBatch: if ret.forward_mode.is_decode(): if ret.positions is None: ret.positions = clamp_position(batch.seq_lens) + if ret.decode_seq_lens_cpu is None: + ret.decode_seq_lens_cpu = batch.decode_seq_lens else: ret.extend_seq_lens = torch.tensor( batch.extend_seq_lens, dtype=torch.int32