diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 0b055b101..db8bb9514 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -186,5 +186,5 @@ Please consult the documentation below to learn more about the parameters you ma * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you. * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8. -* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. When providing this argument, `attention_backend` argument is overridden. +* `enable_flashinfer_mla`: Use the FlashInfer MLA wrapper as the attention backend to accelerate DeepSeek models. * `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on. diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 5f1dbf2e7..ad180d1bd 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -83,7 +83,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. -- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off. +- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). (Experimental) - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
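As context for the documentation hunks above, here is a minimal launch sketch showing how the documented flag is exercised. Only `--enable-flashinfer-mla` comes from the docs in this patch; the model path, `--trust-remote-code`, and the `subprocess` wrapper are illustrative assumptions.

```python
# Minimal sketch: start an sglang server with the FlashInfer MLA backend enabled.
# The model path and --trust-remote-code are placeholders, not part of this patch.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
        "--trust-remote-code",
        "--enable-flashinfer-mla",  # flag documented in server_arguments.md above
    ],
    check=True,
)
```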
diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 23b02342a..745598643 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -29,8 +29,9 @@ class AttentionBackend(ABC): num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - **kwargs, + spec_info: Optional[SpecInfo], ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() @@ -41,8 +42,9 @@ class AttentionBackend(ABC): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - **kwargs, + spec_info: Optional[SpecInfo], ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index bbcb011dc..7063b6f4b 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -269,10 +269,9 @@ class FlashInferAttnBackend(AttentionBackend): num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - forward_mode: ForwardMode, encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, spec_info: Optional[SpecInfo], - **kwargs, ): if forward_mode.is_decode_or_idle(): decode_wrappers = [] @@ -340,10 +339,9 @@ class FlashInferAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - forward_mode: ForwardMode, encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, spec_info: Optional[SpecInfo], - **kwargs, ): if forward_mode.is_decode_or_idle(): self.indices_updater_decode.update( diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index f4540ed40..e7088df5c 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -10,7 +10,6 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html """ from dataclasses import dataclass -from functools import partial from typing import TYPE_CHECKING, Optional, Union import torch @@ -28,12 +27,14 @@ from sglang.srt.utils import is_flashinfer_available if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInfo if is_flashinfer_available(): from flashinfer import ( BatchMLAPagedAttentionWrapper, BatchPrefillWithRaggedKVCacheWrapper, ) + from flashinfer.cascade import merge_state @dataclass @@ -62,7 +63,6 @@ class FlashInferMLAAttnBackend(AttentionBackend): # Parse constants self.max_context_len = model_runner.model_config.context_len - self.device = model_runner.device global_config.enable_flashinfer_mla = True @@ -85,6 +85,10 @@ class FlashInferMLAAttnBackend(AttentionBackend): (max_bs + 1,), dtype=torch.int32, device=model_runner.device ) + self.kv_last_page_len = torch.ones( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) + self.q_indptr_decode = torch.arange( 0, max_bs + 1, dtype=torch.int32, device=model_runner.device ) @@ -122,7 +126,6 @@ class FlashInferMLAAttnBackend(AttentionBackend): forward_batch.seq_lens, 
forward_batch.seq_lens_sum, decode_wrapper=self.decode_wrapper, - init_metadata_replay=False, ) self.forward_metadata = DecodeMetadata(self.decode_wrapper) else: @@ -158,20 +161,13 @@ class FlashInferMLAAttnBackend(AttentionBackend): cuda_graph_kv_indices = kv_indices_buf self.cuda_graph_kv_indices = cuda_graph_kv_indices - self.cuda_graph_qo_indptr = self.q_indptr_decode.clone() - self.cuda_graph_kv_indptr = self.kv_indptr.clone() - self.cuda_graph_kv_lens = torch.ones( - (max_bs,), dtype=torch.int32, device=self.device + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.uint8, + device="cuda", ) - - # For fast decode plan in graph replaying - self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu") - self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu") - self.fast_decode_kwargs = { - "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu, - "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu, - "kv_indices": self.cuda_graph_kv_indices, - } + self.cuda_graph_qk_indptr = self.kv_indptr.clone() + self.cuda_graph_qo_indptr = self.kv_indptr.clone() def init_forward_metadata_capture_cuda_graph( self, @@ -179,17 +175,18 @@ class FlashInferMLAAttnBackend(AttentionBackend): num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - **kwargs, + spec_info: Optional[SpecInfo], ): if forward_mode.is_decode_or_idle(): decode_wrapper = BatchMLAPagedAttentionWrapper( self.workspace_buffer, use_cuda_graph=True, - qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1], - kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1], + qo_indptr=self.qo_indptr[: num_tokens + 1], + kv_indptr=self.kv_indptr[: num_tokens + 1], kv_indices=self.cuda_graph_kv_indices, - kv_len_arr=self.cuda_graph_kv_lens[:num_tokens], + kv_len_arr=self.kv_last_page_len[:num_tokens], backend="auto", ) @@ -199,11 +196,9 @@ class FlashInferMLAAttnBackend(AttentionBackend): seq_lens, seq_lens_sum, decode_wrapper=decode_wrapper, - init_metadata_replay=False, ) self.decode_cuda_graph_metadata[bs] = decode_wrapper self.forward_metadata = DecodeMetadata(decode_wrapper) - decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper) else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -213,30 +208,16 @@ class FlashInferMLAAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - seq_lens_cpu: torch.Tensor, - **kwargs, + spec_info: Optional[SpecInfo], ): if forward_mode.is_decode_or_idle(): - kv_len_arr_cpu = seq_lens_cpu[:bs] - self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum( - kv_len_arr_cpu, dim=0 - ) - self.fast_decode_kwargs.update( - { - "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1], - "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1], - "kv_len_arr_cpu": kv_len_arr_cpu, - } - ) - self.indices_updater_decode.update( req_pool_indices[:bs], seq_lens[:bs], seq_lens_sum, decode_wrapper=self.decode_cuda_graph_metadata[bs], - init_metadata_replay=True, - **self.fast_decode_kwargs, ) else: raise ValueError(f"Invalid forward mode: {forward_mode=}") @@ -336,6 +317,7 @@ class FlashInferMLAIndicesUpdaterDecode: # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr + self.kv_last_page_len = attn_backend.kv_last_page_len self.req_to_token = model_runner.req_to_token_pool.req_to_token self.q_indptr = attn_backend.q_indptr_decode @@ -345,8 +327,6 @@ class 
FlashInferMLAIndicesUpdaterDecode: seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrapper: BatchMLAPagedAttentionWrapper, - init_metadata_replay: bool = False, - **fast_decode_kwargs, ): decode_wrapper = decode_wrapper or self.decode_wrapper self.call_begin_forward( @@ -356,8 +336,6 @@ class FlashInferMLAIndicesUpdaterDecode: seq_lens_sum, self.q_indptr, self.kv_indptr, - init_metadata_replay, - **fast_decode_kwargs, ) def call_begin_forward( @@ -368,19 +346,14 @@ class FlashInferMLAIndicesUpdaterDecode: paged_kernel_lens_sum: int, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, - init_metadata_replay: bool = False, - **fast_decode_kwargs, ): bs = len(req_pool_indices) q_indptr = q_indptr[: bs + 1] kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indices = ( - torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda") - if not init_metadata_replay - else fast_decode_kwargs["kv_indices"] + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device="cuda" ) - kv_lens = paged_kernel_lens.to(torch.int32) sm_scale = self.scaling @@ -393,36 +366,21 @@ class FlashInferMLAIndicesUpdaterDecode: kv_indices, self.req_to_token.shape[1], ) - if not init_metadata_replay: - wrapper.plan( - q_indptr, - kv_indptr, - kv_indices, - kv_lens, - self.num_local_heads, - self.kv_lora_rank, - self.qk_rope_head_dim, - 1, - False, - sm_scale, - self.data_type, - self.data_type, - ) - else: - wrapper.plan( - fast_decode_kwargs["qo_indptr_cpu"], - fast_decode_kwargs["kv_indptr_cpu"], - kv_indices, - fast_decode_kwargs["kv_len_arr_cpu"], - self.num_local_heads, - self.kv_lora_rank, - self.qk_rope_head_dim, - 1, - False, - sm_scale, - self.data_type, - self.data_type, - ) + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens, + self.num_local_heads, + self.kv_lora_rank, + self.qk_rope_head_dim, + 1, + False, + sm_scale, + self.data_type, + self.data_type, + ) class FlashInferMLAIndicesUpdaterPrefill: @@ -442,6 +400,7 @@ class FlashInferMLAIndicesUpdaterPrefill: # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr + self.kv_last_page_len = attn_backend.kv_last_page_len self.qo_indptr = attn_backend.qo_indptr self.req_to_token = model_runner.req_to_token_pool.req_to_token self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged @@ -538,42 +497,3 @@ class FlashInferMLAIndicesUpdaterPrefill: self.q_data_type, self.data_type, ) - - -def fast_mla_decode_plan( - self, - qo_indptr_cpu: torch.Tensor, - kv_indptr_cpu: torch.Tensor, - kv_indices: torch.Tensor, - kv_len_arr_cpu: torch.Tensor, - num_heads: int, - head_dim_ckv: int, - head_dim_kpe: int, - page_size: int, - causal: bool, - sm_scale: float, - q_data_type: torch.dtype, - kv_data_type: torch.dtype, -) -> None: - """A faster version of BatchMLAPagedAttentionWrapper::plan, - for skipping the stream synchronization in original plan function during - cuda graph replaying. 
- """ - self._causal = causal - self._page_size = page_size - self._sm_scale = sm_scale - - with self.device as device: - stream = torch.cuda.current_stream(device).cuda_stream - self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_cpu, - kv_indptr_cpu, - kv_len_arr_cpu, - num_heads, - head_dim_ckv, - causal, - stream, - ) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index ce68af89d..7bb6615ed 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -230,10 +230,9 @@ class TritonAttnBackend(AttentionBackend): num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - forward_mode: ForwardMode, encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, spec_info: Optional[SpecInfo], - **kwargs, ): assert encoder_lens is None, "Not supported" @@ -309,10 +308,9 @@ class TritonAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - forward_mode: ForwardMode, encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, spec_info: Optional[SpecInfo], - **kwargs, ): # NOTE: encoder_lens expected to be zeros or None if forward_mode.is_decode_or_idle(): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ffb08fdb5..3d34cefb1 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -582,9 +582,6 @@ class ScheduleBatch: return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None - # For decode - decode_seq_lens: List[int] = None - # For extend and mixed chunekd prefill prefix_lens: List[int] = None extend_lens: List[int] = None @@ -1171,10 +1168,8 @@ class ScheduleBatch: def get_model_worker_batch(self): if self.forward_mode.is_decode_or_idle(): - decode_seq_lens = self.seq_lens.cpu() extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None else: - decode_seq_lens = None extend_seq_lens = self.extend_lens extend_prefix_lens = self.prefix_lens extend_logprob_start_lens = self.extend_logprob_start_lens @@ -1199,7 +1194,6 @@ class ScheduleBatch: top_logprobs_nums=self.top_logprobs_nums, global_num_tokens=self.global_num_tokens, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, - decode_seq_lens=decode_seq_lens, extend_num_tokens=self.extend_num_tokens, extend_seq_lens=extend_seq_lens, extend_prefix_lens=extend_prefix_lens, @@ -1273,9 +1267,6 @@ class ModelWorkerBatch: global_num_tokens: Optional[List[int]] can_run_dp_cuda_graph: bool - # For decode - decode_seq_lens: Optional[torch.Tensor] - # For extend extend_num_tokens: Optional[int] extend_seq_lens: Optional[List[int]] diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index fd9280239..dd1b0da94 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -199,10 +199,6 @@ class CudaGraphRunner: if self.enable_torch_compile: set_torch_compile_config() - self.seq_lens_cpu = torch.full( - (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 - ) - # Graph inputs with torch.device("cuda"): self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) @@ -373,9 +369,9 @@ class CudaGraphRunner: num_tokens, req_pool_indices, seq_lens, + encoder_lens, 
forward_batch.forward_mode, - encoder_lens=encoder_lens, - spec_info=forward_batch.spec_info, + forward_batch.spec_info, ) # Run and capture @@ -438,7 +434,6 @@ class CudaGraphRunner: if bs != raw_bs: self.seq_lens.fill_(1) self.out_cache_loc.zero_() - self.seq_lens_cpu.fill_(1) # Common inputs self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) @@ -446,8 +441,6 @@ class CudaGraphRunner: self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) self.positions[:raw_num_token].copy_(forward_batch.positions) - if forward_batch.decode_seq_lens_cpu is not None: - self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu) if self.is_encoder_decoder: self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) @@ -463,10 +456,9 @@ class CudaGraphRunner: self.req_pool_indices, self.seq_lens, forward_batch.seq_lens_sum + (bs - raw_bs), + self.encoder_lens, forward_batch.forward_mode, - encoder_lens=self.encoder_lens, - spec_info=forward_batch.spec_info, - seq_lens_cpu=self.seq_lens_cpu, + forward_batch.spec_info, ) # Replay diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 92ef9a3fa..cdd03bec4 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -152,9 +152,6 @@ class ForwardBatch: # Position information positions: torch.Tensor = None - # For decode - decode_seq_lens_cpu: Optional[torch.Tensor] = None - # For extend extend_num_tokens: Optional[int] = None extend_seq_lens: Optional[torch.Tensor] = None @@ -259,8 +256,6 @@ class ForwardBatch: if ret.forward_mode.is_decode(): if ret.positions is None: ret.positions = clamp_position(batch.seq_lens) - if ret.decode_seq_lens_cpu is None: - ret.decode_seq_lens_cpu = batch.decode_seq_lens else: ret.extend_seq_lens = torch.tensor( batch.extend_seq_lens, dtype=torch.int32
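Summarizing the interface change threaded through the hunks above: the CUDA-graph metadata hooks drop the `**kwargs` catch-all (and the `seq_lens_cpu` fast-plan path) in favor of explicit `encoder_lens` and `spec_info` parameters, with `encoder_lens` placed before `forward_mode` so `CudaGraphRunner` can pass every argument positionally. The sketch below shows what a backend override looks like after the patch; the parameter lists follow the diff, while the class name and bodies are placeholders.

```python
# Sketch of the hook signatures after this patch (parameter order as in the diff;
# MyAttnBackend and the method bodies are placeholders, not code from the patch).
from typing import Optional

import torch


class MyAttnBackend:  # hypothetical AttentionBackend subclass
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],  # explicit, no longer via **kwargs
        forward_mode: "ForwardMode",
        spec_info: Optional["SpecInfo"],
    ) -> None:
        # Build and plan the wrappers used while capturing the graph.
        ...

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional["SpecInfo"],
    ) -> None:
        # Re-plan the captured wrappers with the real batch before replay.
        ...
```

Passing `encoder_lens` and `spec_info` positionally matches the updated call sites in `CudaGraphRunner` shown in the last two hunks above.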