From 90a4b7d98a5c2d7b9d3dcc7dad6809a6d3d9ca8f Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 28 Feb 2025 18:13:56 -0800
Subject: [PATCH] [Feature] Support ragged prefill in flashinfer mla backend (#3967)

Co-authored-by: Yineng Zhang
Co-authored-by: pankajroark
---
 docs/backend/server_arguments.md               |   3 +-
 .../layers/attention/flashinfer_backend.py     | 454 +++++-------
 .../attention/flashinfer_mla_backend.py        | 140 +++---
 python/sglang/srt/managers/schedule_batch.py   |   1 +
 .../sglang/srt/model_executor/model_runner.py  |   1 +
 python/sglang/srt/models/deepseek_v2.py        |   5 +-
 python/sglang/srt/server_args.py               |   6 +
 test/srt/run_suite.py                          |   1 +
 test/srt/test_mla_flashinfer.py                | 104 ++++
 9 files changed, 308 insertions(+), 407 deletions(-)
 create mode 100644 test/srt/test_mla_flashinfer.py

diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 7879ada57..7a614b61a 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -133,7 +133,6 @@ Please consult the documentation below to learn more about the parameters you ma
 * `attention_backend`: The backend for attention computation and KV cache management.
 * `sampling_backend`: The backend for sampling.
-* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models. (In Experiment Stage)
 
 ## Constrained Decoding
@@ -186,3 +185,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
+* `enable_flashinfer_mla`: Use the FlashInfer MLA wrapper as the attention backend to accelerate DeepSeek models.
+* `flashinfer_mla_disable_ragged`: Disable the ragged prefill wrapper in the FlashInfer MLA attention backend. Only takes effect when `enable_flashinfer_mla` is turned on.
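For orientation, the decision logic that these two flags drive is small and is easiest to see in one place. The sketch below is illustrative only and is not a hunk of this patch: the helper name `choose_prefill_path` is invented, while the conditions mirror the code added to `flashinfer_mla_backend.py` and `deepseek_v2.py` further down in this diff.

```python
# Illustrative summary of how the two server args select the MLA prefill path.
# Not part of the patch; the conditions mirror flashinfer_mla_backend.py
# (use_ragged) and deepseek_v2.py (no_absorb).

def choose_prefill_path(
    enable_flashinfer_mla: bool,
    flashinfer_mla_disable_ragged: bool,
    is_extend: bool,
    extend_prefix_lens_cpu,
) -> str:
    if not enable_flashinfer_mla:
        return "default attention backend (no MLA wrappers)"

    # Ragged prefill is eligible only when the disable flag is off and no
    # request in the batch has a cached prefix.
    extend_no_prefix = not any(extend_prefix_lens_cpu)
    use_ragged = (not flashinfer_mla_disable_ragged) and extend_no_prefix

    if is_extend and use_ragged:
        # deepseek_v2.py: no_absorb() is True here, so prefill runs without
        # weight absorption through BatchPrefillWithRaggedKVCacheWrapper.
        return "ragged prefill (no weight absorption)"
    # Everything else goes through BatchMLAPagedAttentionWrapper.
    return "paged MLA prefill/decode (weight absorption)"


print(choose_prefill_path(True, False, True, [0, 0]))  # ragged prefill
print(choose_prefill_path(True, True, True, [0, 0]))   # paged MLA
print(choose_prefill_path(True, False, True, [5, 0]))  # paged MLA (prefix cached)
```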
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index e39bdd2d6..7063b6f4b 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -37,7 +37,6 @@ if is_flashinfer_available(): BatchPrefillWithRaggedKVCacheWrapper, ) from flashinfer.cascade import merge_state - from flashinfer.mla import BatchMLAPagedAttentionWrapper class WrapperDispatch(Enum): @@ -47,16 +46,12 @@ class WrapperDispatch(Enum): @dataclass class DecodeMetadata: - decode_wrappers: List[ - Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper] - ] + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper] @dataclass class PrefillMetadata: - prefill_wrappers: List[ - Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper] - ] + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper] use_ragged: bool extend_no_prefix: bool @@ -109,12 +104,6 @@ class FlashInferAttnBackend(AttentionBackend): if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures: global_config.flashinfer_workspace_size = 512 * 1024 * 1024 - self.enable_flashinfer_mla = False - if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures: - if global_server_args_dict["enable_flashinfer_mla"]: - self.enable_flashinfer_mla = True - global_config.enable_flashinfer_mla = True - # Allocate buffers global global_workspace_buffer if global_workspace_buffer is None: @@ -132,13 +121,6 @@ class FlashInferAttnBackend(AttentionBackend): ) for _ in range(self.num_wrappers) ] - if self.enable_flashinfer_mla: - self.qo_indptr = [ - torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) - for _ in range(self.num_wrappers) - ] else: assert self.num_wrappers == 1 self.kv_indptr = [kv_indptr_buf] @@ -162,48 +144,24 @@ class FlashInferAttnBackend(AttentionBackend): self.decode_wrappers = [] for _ in range(self.num_wrappers): if not skip_prefill: - if ( - self.enable_flashinfer_mla - and not global_server_args_dict["disable_radix_cache"] - ): - # use mla paged prefill - self.prefill_wrappers_paged.append( - BatchMLAPagedAttentionWrapper( - self.workspace_buffer, - backend="fa2", - ) - ) - self.prefill_wrappers_verify.append( - BatchMLAPagedAttentionWrapper( - self.workspace_buffer, - backend="fa2", - ) - ) - else: - self.prefill_wrappers_paged.append( - BatchPrefillWithPagedKVCacheWrapper( - self.workspace_buffer, - "NHD", - backend="fa2", - ) - ) - self.prefill_wrappers_verify.append( - BatchPrefillWithPagedKVCacheWrapper( - self.workspace_buffer, "NHD" - ) - ) - if self.enable_flashinfer_mla: - self.decode_wrappers.append( - BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2") - ) - else: - self.decode_wrappers.append( - BatchDecodeWithPagedKVCacheWrapper( + self.prefill_wrappers_paged.append( + BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, "NHD", - use_tensor_cores=self.decode_use_tensor_cores, + backend="fa2", ) ) + self.prefill_wrappers_verify.append( + BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") + ) + + self.decode_wrappers.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_tensor_cores=self.decode_use_tensor_cores, + ) + ) # Create indices updater if not skip_prefill: @@ -259,10 +217,7 @@ class FlashInferAttnBackend(AttentionBackend): else: prefix_lens = forward_batch.extend_prefix_lens - if self.is_multimodal or ( - 
self.enable_flashinfer_mla - and not global_server_args_dict["disable_radix_cache"] - ): + if self.is_multimodal: use_ragged = False extend_no_prefix = False else: @@ -321,32 +276,20 @@ class FlashInferAttnBackend(AttentionBackend): if forward_mode.is_decode_or_idle(): decode_wrappers = [] for i in range(self.num_wrappers): - if self.enable_flashinfer_mla: - decode_wrappers.append( - BatchMLAPagedAttentionWrapper( - self.workspace_buffer, - use_cuda_graph=True, - qo_indptr=self.qo_indptr[i][: num_tokens + 1], - kv_indptr=self.kv_indptr[i][: num_tokens + 1], - kv_indices=self.cuda_graph_kv_indices[i], - kv_len_arr=self.kv_last_page_len[:num_tokens], - backend="fa2", - ) - ) - else: - decode_wrappers.append( - BatchDecodeWithPagedKVCacheWrapper( - self.workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=self.decode_use_tensor_cores, - paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1], - paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], - paged_kv_last_page_len_buffer=self.kv_last_page_len[ - :num_tokens - ], - ) + decode_wrappers.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=self.decode_use_tensor_cores, + paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1], + paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buffer=self.kv_last_page_len[ + :num_tokens + ], ) + ) + seq_lens_sum = seq_lens.sum().item() self.indices_updater_decode.update( req_pool_indices, @@ -435,114 +378,64 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch: ForwardBatch, save_kv_cache=True, ): - if global_config.enable_flashinfer_mla: - cache_loc = ( - forward_batch.out_cache_loc - if not layer.is_cross_attention - else forward_batch.encoder_out_cache_loc - ) + prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[ + self._get_wrapper_idx(layer) + ] + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) - logits_soft_cap = layer.logit_cap - - if global_server_args_dict["disable_radix_cache"]: - # use mla ragged prefill - o, _ = self.prefill_wrapper_ragged.forward_return_lse( - q.view(-1, layer.tp_q_head_num, layer.head_dim), - k.view(-1, layer.tp_k_head_num, layer.head_dim), - v.view(-1, layer.tp_v_head_num, layer.v_head_dim), - causal=True, - sm_scale=layer.scaling, - logits_soft_cap=logits_soft_cap, - ) - - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - cache_loc, - k, - v, - ) - else: - # use mla paged prefill - prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[ - self._get_wrapper_idx(layer) - ] - if k is not None: - assert v is not None - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v - ) - qall = q.view(-1, layer.tp_q_head_num, layer.head_dim) - k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - - o = prefill_wrapper_paged.run( - qall[:, :, : layer.v_head_dim], - qall[:, :, layer.v_head_dim :], - k_buf[:, :, : layer.v_head_dim], - k_buf[:, :, layer.v_head_dim :], - ) - - return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) - else: - prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[ - self._get_wrapper_idx(layer) - ] - cache_loc = ( - forward_batch.out_cache_loc - if not layer.is_cross_attention - else forward_batch.encoder_out_cache_loc - ) - - logits_soft_cap = layer.logit_cap - - if not self.forward_metadata.use_ragged: - if k is not None: - assert v is not 
None - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale - ) - - o = prefill_wrapper_paged.forward( - q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=not layer.is_cross_attention, - sm_scale=layer.scaling, - window_left=layer.sliding_window_size, - logits_soft_cap=logits_soft_cap, - k_scale=layer.k_scale, - v_scale=layer.v_scale, - ) - else: - o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( - q.view(-1, layer.tp_q_head_num, layer.head_dim), - k.view(-1, layer.tp_k_head_num, layer.head_dim), - v.view(-1, layer.tp_v_head_num, layer.head_dim), - causal=True, - sm_scale=layer.scaling, - logits_soft_cap=logits_soft_cap, - ) - - if self.forward_metadata.extend_no_prefix: - o = o1 - else: - o2, s2 = prefill_wrapper_paged.forward_return_lse( - q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=False, - sm_scale=layer.scaling, - logits_soft_cap=layer.logit_cap, - ) - - o, _ = merge_state(o1, s1, o2, s2) + logits_soft_cap = layer.logit_cap + if not self.forward_metadata.use_ragged: + if k is not None: + assert v is not None if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( layer, cache_loc, k, v, layer.k_scale, layer.v_scale ) - return o.view(-1, layer.tp_q_head_num * layer.head_dim) + o = prefill_wrapper_paged.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=not layer.is_cross_attention, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=logits_soft_cap, + k_scale=layer.k_scale, + v_scale=layer.v_scale, + ) + else: + o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( + q.view(-1, layer.tp_q_head_num, layer.head_dim), + k.view(-1, layer.tp_k_head_num, layer.head_dim), + v.view(-1, layer.tp_v_head_num, layer.head_dim), + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=logits_soft_cap, + ) + + if self.forward_metadata.extend_no_prefix: + o = o1 + else: + o2, s2 = prefill_wrapper_paged.forward_return_lse( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=False, + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + ) + + o, _ = merge_state(o1, s1, o2, s2) + + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + + return o.view(-1, layer.tp_q_head_num * layer.head_dim) def forward_decode( self, @@ -562,45 +455,23 @@ class FlashInferAttnBackend(AttentionBackend): else forward_batch.encoder_out_cache_loc ) - if self.enable_flashinfer_mla: - if k is not None: - assert v is not None - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - cache_loc, - k, - v, - ) - reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim) - k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - reshaped_k = k_buffer.view(-1, 1, layer.head_dim) - o = decode_wrapper.run( - reshaped_q[:, :, : layer.v_head_dim], - reshaped_q[:, :, layer.v_head_dim :], - reshaped_k[:, :, : layer.v_head_dim], - reshaped_k[:, :, layer.v_head_dim :], - ) + if k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) - return 
o.view(-1, layer.tp_q_head_num * layer.v_head_dim) - else: - if k is not None: - assert v is not None - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale - ) + o = decode_wrapper.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + k_scale=layer.k_scale, + v_scale=layer.v_scale, + ) - o = decode_wrapper.forward( - q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - sm_scale=layer.scaling, - logits_soft_cap=layer.logit_cap, - k_scale=layer.k_scale, - v_scale=layer.v_scale, - ) - - return o.view(-1, layer.tp_q_head_num * layer.head_dim) + return o.view(-1, layer.tp_q_head_num * layer.head_dim) def _get_wrapper_idx(self, layer: RadixAttention): if self.num_wrappers == 1: @@ -648,9 +519,7 @@ class FlashInferIndicesUpdaterDecode: req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - decode_wrappers: List[ - Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper] - ], + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], spec_info: Optional[SpecInfo], ): @@ -662,9 +531,7 @@ class FlashInferIndicesUpdaterDecode: req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - decode_wrappers: List[ - Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper] - ], + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], spec_info: Optional[SpecInfo], ): @@ -745,9 +612,7 @@ class FlashInferIndicesUpdaterDecode: def call_begin_forward( self, - wrapper: Union[ - BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper - ], + wrapper: BatchDecodeWithPagedKVCacheWrapper, req_pool_indices: torch.Tensor, paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, @@ -775,37 +640,18 @@ class FlashInferIndicesUpdaterDecode: kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices bs = kv_indptr.shape[0] - 1 - if global_config.enable_flashinfer_mla: - sm_scale = 1.0 / math.sqrt(192) - q_indptr = torch.arange(0, bs + 1).to(0).int() - kv_lens = paged_kernel_lens.to(torch.int32) - wrapper.plan( - q_indptr, - kv_indptr, - kv_indices, - kv_lens, - self.num_qo_heads, - 512, - 64, - 1, - False, - sm_scale, - self.data_type, - self.data_type, - ) - else: - wrapper.begin_forward( - kv_indptr, - kv_indices, - self.kv_last_page_len[:bs], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - 1, - data_type=self.data_type, - q_data_type=self.q_data_type, - non_blocking=True, - ) + wrapper.begin_forward( + kv_indptr, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + data_type=self.data_type, + q_data_type=self.q_data_type, + non_blocking=True, + ) class FlashInferIndicesUpdaterPrefill: @@ -845,9 +691,7 @@ class FlashInferIndicesUpdaterPrefill: seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, - prefill_wrappers: List[ - Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper] - ], + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], spec_info: Optional[SpecInfo], @@ -861,9 +705,7 @@ class FlashInferIndicesUpdaterPrefill: seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, - 
prefill_wrappers: List[
-            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -972,9 +814,7 @@ class FlashInferIndicesUpdaterPrefill:
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: Union[
-            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -1020,62 +860,30 @@ class FlashInferIndicesUpdaterPrefill:
         # extend part
         if use_ragged:
-            if global_config.enable_flashinfer_mla:
-                wrapper_ragged.begin_forward(
-                    qo_indptr=qo_indptr,
-                    kv_indptr=qo_indptr,
-                    num_qo_heads=self.num_qo_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_dim_qk=192,
-                    head_dim_vo=128,
-                    q_data_type=self.q_data_type,
-                )
-            else:
-                wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    self.num_qo_heads,
-                    self.num_kv_heads,
-                    self.head_dim,
-                    q_data_type=self.q_data_type,
-                )
-
-        if not global_config.enable_flashinfer_mla:
-            # cached part
-            wrapper_paged.begin_forward(
+            wrapper_ragged.begin_forward(
+                qo_indptr,
                 qo_indptr,
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
-                1,
                 q_data_type=self.q_data_type,
-                custom_mask=custom_mask,
-                non_blocking=True,
-            )
-        elif (
-            global_config.enable_flashinfer_mla
-            and not global_server_args_dict["disable_radix_cache"]
-        ):
-            # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
-            wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_qo_heads,
-                512,
-                64,
-                1,
-                True,
-                1 / math.sqrt(192),
-                self.data_type,
-                self.data_type,
             )
+        # cached part
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            q_data_type=self.q_data_type,
+            custom_mask=custom_mask,
+            non_blocking=True,
+        )
+
 
 class FlashInferMultiStepDraftBackend:
     """
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index a0b18c422..e7088df5c 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -2,13 +2,13 @@ from __future__ import annotations
 
 """
 Support attention backend for flashinfer MLA.
-When radix cache is enabled, the backend only uses BatchMLAPaged wrapper when forwarding.
-When radix cache is disabled, the backend uses BatchPrefill wrappers for prefilling (with or without prefix cache),
+The flashinfer_mla_disable_ragged flag controls whether the ragged prefill wrapper is used; it defaults to false.
+When it's set to false, the backend uses the BatchRagged wrapper for prefill batches without cached prefixes and the BatchMLAPaged wrapper otherwise.
+When it's set to true, the backend uses only the BatchMLAPaged wrapper for prefilling,
 and uses BatchMLAPaged wrapper for decoding.
More details can be found in https://docs.flashinfer.ai/api/mla.html """ -import math from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union @@ -18,7 +18,6 @@ from sglang.global_config import global_config from sglang.srt.layers.attention import AttentionBackend from sglang.srt.layers.attention.flashinfer_backend import ( create_flashinfer_kv_indices_triton, - should_use_tensor_core, ) from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -32,11 +31,10 @@ if TYPE_CHECKING: if is_flashinfer_available(): from flashinfer import ( - BatchPrefillWithPagedKVCacheWrapper, + BatchMLAPagedAttentionWrapper, BatchPrefillWithRaggedKVCacheWrapper, ) from flashinfer.cascade import merge_state - from flashinfer.mla import BatchMLAPagedAttentionWrapper @dataclass @@ -46,9 +44,7 @@ class DecodeMetadata: @dataclass class PrefillMetadata: - prefill_wrapper: Union[ - BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper - ] + prefill_wrapper: BatchMLAPagedAttentionWrapper use_ragged: bool @@ -62,7 +58,6 @@ class FlashInferMLAAttnBackend(AttentionBackend): def __init__( self, model_runner: ModelRunner, - kv_indptr_buf: Optional[torch.Tensor] = None, ): super().__init__() @@ -82,12 +77,9 @@ class FlashInferMLAAttnBackend(AttentionBackend): self.workspace_buffer = global_workspace_buffer max_bs = model_runner.req_to_token_pool.size - if kv_indptr_buf is None: - self.kv_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) - else: - self.kv_indptr = kv_indptr_buf + self.kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) self.qo_indptr = torch.zeros( (max_bs + 1,), dtype=torch.int32, device=model_runner.device @@ -97,22 +89,19 @@ class FlashInferMLAAttnBackend(AttentionBackend): (max_bs,), dtype=torch.int32, device=model_runner.device ) + self.q_indptr_decode = torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=model_runner.device + ) + self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( self.workspace_buffer, "NHD" ) - if not global_server_args_dict["disable_radix_cache"]: - # use mla paged prefill - self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper( - self.workspace_buffer, - backend="auto", - ) - else: - self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( - self.workspace_buffer, - "NHD", - backend="auto", - ) + self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper( + self.workspace_buffer, + backend="auto", + ) + self.decode_wrapper = BatchMLAPagedAttentionWrapper( self.workspace_buffer, backend="auto" ) @@ -141,7 +130,11 @@ class FlashInferMLAAttnBackend(AttentionBackend): self.forward_metadata = DecodeMetadata(self.decode_wrapper) else: prefix_lens = forward_batch.extend_prefix_lens - use_ragged = global_server_args_dict["disable_radix_cache"] + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + use_ragged = ( + not global_server_args_dict["flashinfer_mla_disable_ragged"] + and extend_no_prefix + ) self.indices_updater_prefill.update( forward_batch.req_pool_indices, @@ -241,45 +234,37 @@ class FlashInferMLAAttnBackend(AttentionBackend): forward_batch: ForwardBatch, save_kv_cache=True, ): + cache_loc = forward_batch.out_cache_loc logits_soft_cap = layer.logit_cap + prefill_wrapper_paged = self.forward_metadata.prefill_wrapper + qall = q.view(-1, layer.tp_q_head_num, layer.head_dim) + k_buf = 
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - if not global_server_args_dict["disable_radix_cache"]: - # use mla paged prefill - prefill_wrapper_paged = self.forward_metadata.prefill_wrapper - if k is not None: - assert v is not None - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) - qall = q.view(-1, layer.tp_q_head_num, layer.head_dim) - k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + # Save kv cache + if save_kv_cache and k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if self.forward_metadata.use_ragged: + # ragged prefill + o, _ = self.prefill_wrapper_ragged.forward_return_lse( + qall, + k.view(-1, layer.tp_k_head_num, layer.head_dim), + v.view(-1, layer.tp_k_head_num, layer.v_head_dim), + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=logits_soft_cap, + ) + else: + # mla paged prefill o = prefill_wrapper_paged.run( qall[:, :, : layer.v_head_dim], qall[:, :, layer.v_head_dim :], k_buf[:, :, : layer.v_head_dim], k_buf[:, :, layer.v_head_dim :], ) - else: - # use mla ragged prefill - o, _ = self.prefill_wrapper_ragged.forward_return_lse( - q.view(-1, layer.tp_q_head_num, layer.head_dim), - k.view(-1, layer.tp_k_head_num, layer.head_dim), - v.view(-1, layer.tp_v_head_num, layer.v_head_dim), - causal=True, - sm_scale=layer.scaling, - logits_soft_cap=logits_soft_cap, - ) - - # FIXME: Here should be another prefill_paged to call - - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - cache_loc, - k, - v, - ) return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) @@ -334,6 +319,7 @@ class FlashInferMLAIndicesUpdaterDecode: self.kv_indptr = attn_backend.kv_indptr self.kv_last_page_len = attn_backend.kv_last_page_len self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.q_indptr = attn_backend.q_indptr_decode def update( self, @@ -342,12 +328,13 @@ class FlashInferMLAIndicesUpdaterDecode: seq_lens_sum: int, decode_wrapper: BatchMLAPagedAttentionWrapper, ): - decode_wrappers = decode_wrapper or self.decode_wrapper + decode_wrapper = decode_wrapper or self.decode_wrapper self.call_begin_forward( decode_wrapper, req_pool_indices, seq_lens, seq_lens_sum, + self.q_indptr, self.kv_indptr, ) @@ -357,14 +344,19 @@ class FlashInferMLAIndicesUpdaterDecode: req_pool_indices: torch.Tensor, paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, + q_indptr: torch.Tensor, kv_indptr: torch.Tensor, ): bs = len(req_pool_indices) + q_indptr = q_indptr[: bs + 1] kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] kv_indices = torch.empty( paged_kernel_lens_sum, dtype=torch.int32, device="cuda" ) + kv_lens = paged_kernel_lens.to(torch.int32) + sm_scale = self.scaling + create_flashinfer_kv_indices_triton[(bs,)]( self.req_to_token, req_pool_indices, @@ -375,9 +367,6 @@ class FlashInferMLAIndicesUpdaterDecode: self.req_to_token.shape[1], ) - sm_scale = self.scaling - q_indptr = torch.arange(0, bs + 1).to(0).int() - kv_lens = paged_kernel_lens.to(torch.int32) wrapper.plan( q_indptr, kv_indptr, @@ -397,12 +386,9 @@ class FlashInferMLAIndicesUpdaterDecode: class FlashInferMLAIndicesUpdaterPrefill: def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): # Parse Constants - self.num_qo_heads = ( + self.num_local_heads = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() ) - self.num_kv_heads = 
model_runner.model_config.get_num_kv_heads( - get_attention_tp_size() - ) self.kv_lora_rank = model_runner.model_config.kv_lora_rank self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim @@ -425,9 +411,7 @@ class FlashInferMLAIndicesUpdaterPrefill: seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, - prefill_wrapper_paged: Union[ - BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper - ], + prefill_wrapper_paged: BatchMLAPagedAttentionWrapper, use_ragged: bool, ): if use_ragged: @@ -453,9 +437,7 @@ class FlashInferMLAIndicesUpdaterPrefill: def call_begin_forward( self, wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper, - wrapper_paged: Union[ - BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper - ], + wrapper_paged: BatchMLAPagedAttentionWrapper, req_pool_indices: torch.Tensor, paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, @@ -466,7 +448,6 @@ class FlashInferMLAIndicesUpdaterPrefill: use_ragged: bool, ): bs = len(req_pool_indices) - # Normal extend kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] kv_indices = torch.empty( @@ -488,19 +469,18 @@ class FlashInferMLAIndicesUpdaterPrefill: qo_indptr = qo_indptr[: bs + 1] sm_scale = self.scaling - # extend part if use_ragged: + # ragged prefill wrapper_ragged.begin_forward( qo_indptr=qo_indptr, kv_indptr=qo_indptr, - num_qo_heads=self.num_qo_heads, - num_kv_heads=self.num_kv_heads, + num_qo_heads=self.num_local_heads, + num_kv_heads=self.num_local_heads, head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim, head_dim_vo=self.v_head_dim, q_data_type=self.q_data_type, ) - - if not global_server_args_dict["disable_radix_cache"]: + else: # mla paged prefill kv_len_arr = kv_indptr[1:] - kv_indptr[:-1] wrapper_paged.plan( @@ -508,7 +488,7 @@ class FlashInferMLAIndicesUpdaterPrefill: kv_indptr, kv_indices, kv_len_arr, - self.num_qo_heads, + self.num_local_heads, self.kv_lora_rank, self.qk_rope_head_dim, 1, @@ -517,5 +497,3 @@ class FlashInferMLAIndicesUpdaterPrefill: self.q_data_type, self.data_type, ) - - # FIXME: Here should be some logic for prefill paged when not using radix cache? 
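Both updaters above reduce to building a flat page table that `plan()`/`begin_forward()` can consume. The snippet below is a plain-PyTorch sketch of that bookkeeping, assuming `req_to_token` maps a (request slot, position) pair to a KV-cache slot; the helper name `build_kv_indices` is invented for the example, and the real code fills `kv_indices` with the `create_flashinfer_kv_indices_triton` kernel rather than a Python loop.

```python
# Illustrative only: what call_begin_forward() prepares before wrapper.plan().
import torch


def build_kv_indices(req_to_token, req_pool_indices, kv_lens):
    """Flatten per-request KV-cache slot ids into (kv_indptr, kv_indices)."""
    bs = req_pool_indices.numel()
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)

    # Request i's cached slots end up in kv_indices[kv_indptr[i]:kv_indptr[i + 1]].
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(bs):
        kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[
            req_pool_indices[i], : kv_lens[i]
        ]
    return kv_indptr, kv_indices


# Toy batch: two requests with 3 and 2 cached tokens.
req_to_token = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]], dtype=torch.int32)
kv_indptr, kv_indices = build_kv_indices(
    req_to_token, torch.tensor([0, 1]), torch.tensor([3, 2], dtype=torch.int32)
)
print(kv_indptr.tolist())   # [0, 3, 5]
print(kv_indices.tolist())  # [10, 11, 12, 20, 21]
```

These `kv_indptr`/`kv_indices` arrays, together with `q_indptr` and the per-request `kv_len_arr`, are what the patch hands to `BatchMLAPagedAttentionWrapper.plan()` alongside the DeepSeek head sizes (`kv_lora_rank` and `qk_rope_head_dim`).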
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index ea7280485..a0db44c71 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -67,6 +67,7 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
 }
 
 logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 51a311c04..4b486ec8b 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -182,6 +182,7 @@ class ModelRunner:
                 "device": server_args.device,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "disable_radix_cache": server_args.disable_radix_cache,
+                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
             }
         )
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index f5182c828..c68403ea9 100755
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -520,10 +520,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         def no_absorb() -> bool:
             if global_server_args_dict["enable_flashinfer_mla"]:
-                # Flashinfer MLA: Only do not use absorb when prefilling/extending without radix cache
+                # Flashinfer MLA: skip weight absorption only when the ragged prefill wrapper will be used (extend with no cached prefix)
                 return (
-                    global_server_args_dict["disable_radix_cache"]
+                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
                     and forward_batch.forward_mode.is_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
                 )
             else:
                 # Triton: Use normal computation for prefill and use weight absorption for extend/decode
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c62a3dbda..ddb10e390 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -167,6 +167,7 @@ class ServerArgs:
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False
     enable_flashinfer_mla: bool = False
+    flashinfer_mla_disable_ragged: bool = False
 
     def __post_init__(self):
         # Set missing default values
@@ -713,6 +714,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer MLA optimization",
         )
+        parser.add_argument(
+            "--flashinfer-mla-disable-ragged",
+            action="store_true",
+            help="Disable the ragged prefill wrapper when running FlashInfer MLA",
+        )
 
         # Speculative decoding
         parser.add_argument(
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index b02bbec56..8aa8e0fd1 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -23,6 +23,7 @@ suites = {
         "test_gguf.py",
         "test_input_embeddings.py",
         "test_mla.py",
+        "test_mla_flashinfer.py",
         "test_mla_fp8.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py
new file mode 100644
index 000000000..fc29e958f
--- /dev/null
+++ b/test/srt/test_mla_flashinfer.py
@@ -0,0 +1,104 @@
+import unittest
+from types import SimpleNamespace
+
+import torch
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+
popen_launch_server, +) + + +class TestFlashinferMLA(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "sgl-project/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if torch.cuda.is_available() and torch.version.cuda: + other_args.extend( + [ + "--enable-torch-compile", + "--cuda-graph-max-bs", + "2", + "--enable-flashinfer-mla", + ] + ) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.62) + + +class TestFlashinferMLANoRagged(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "sgl-project/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if torch.cuda.is_available() and torch.version.cuda: + other_args.extend( + [ + "--enable-torch-compile", + "--disable-cuda-graph", + "--cuda-graph-max-bs", + "2", + "--enable-flashinfer-mla", + "--flashinfer-mla-disable-ragged", + ] + ) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.62) + + +if __name__ == "__main__": + unittest.main()
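A final piece of context on the design: the MLA backend only takes the ragged path when no request in the batch has a cached prefix. If a prefix were present, the ragged output over the new tokens would have to be combined with a paged pass over the cached part, which is what `flashinfer.cascade.merge_state` does in the non-MLA prefill path retained earlier in this patch. The snippet below is a reference sketch of that log-sum-exp merge, assuming per-token, per-head outputs `o` and natural-log `lse` values; it is illustrative and not the library implementation.

```python
# Reference semantics of merging two partial-attention results (cf. merge_state).
import torch


def merge_attention_states(o1, lse1, o2, lse2):
    """o*: [tokens, heads, head_dim]; lse*: [tokens, heads] log-sum-exp values."""
    lse_max = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - lse_max)  # numerically stable softmax-style weights
    w2 = torch.exp(lse2 - lse_max)
    o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / (w1 + w2).unsqueeze(-1)
    lse = lse_max + torch.log(w1 + w2)
    return o, lse
```

Restricting ragged prefill to prefix-free batches lets the MLA backend skip this merge entirely, which is presumably why the earlier FIXME about a missing paged pass could be dropped in this patch.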