From 714f3e6362791ccc54a8845e5c6261d1e6d156cc Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Tue, 18 Feb 2025 02:06:43 +0800
Subject: [PATCH] feat: support flashinfer mla with prefix cache (#3643)

---
 .../layers/attention/flashinfer_backend.py    | 129 ++++++++++++++----
 python/sglang/srt/managers/schedule_batch.py  |   1 +
 .../sglang/srt/model_executor/model_runner.py |   1 +
 python/sglang/srt/models/deepseek_v2.py       |   7 +-
 4 files changed, 107 insertions(+), 31 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 2c4c6c65b..a3e194ccb 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -54,7 +54,9 @@ class DecodeMetadata:
 
 @dataclass
 class PrefillMetadata:
-    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
+    prefill_wrappers: List[
+        Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]
     use_ragged: bool
     extend_no_prefix: bool
 
@@ -160,16 +162,36 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             if not skip_prefill:
-                self.prefill_wrappers_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        backend="fa2",
+                if (
+                    self.enable_flashinfer_mla
+                    and not global_server_args_dict["disable_radix_cache"]
+                ):
+                    # use mla paged prefill
+                    self.prefill_wrappers_paged.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            backend="fa2",
+                        )
+                    )
+                    self.prefill_wrappers_verify.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    self.prefill_wrappers_paged.append(
+                        BatchPrefillWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            backend="fa2",
+                        )
+                    )
+                    self.prefill_wrappers_verify.append(
+                        BatchPrefillWithPagedKVCacheWrapper(
+                            self.workspace_buffer, "NHD"
+                        )
                     )
-                )
-                self.prefill_wrappers_verify.append(
-                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-                )
             if self.enable_flashinfer_mla:
                 self.decode_wrappers.append(
                     BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
@@ -237,7 +259,10 @@ class FlashInferAttnBackend(AttentionBackend):
             else:
                 prefix_lens = forward_batch.extend_prefix_lens
 
-            if self.is_multimodal:
+            if self.is_multimodal or (
+                self.enable_flashinfer_mla
+                and not global_server_args_dict["disable_radix_cache"]
+            ):
                 use_ragged = False
                 extend_no_prefix = False
             else:
@@ -419,23 +444,43 @@ class FlashInferAttnBackend(AttentionBackend):
 
             logits_soft_cap = layer.logit_cap
 
-            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
+            if global_server_args_dict["disable_radix_cache"]:
+                # use mla ragged prefill
+                o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
 
-            o = o1
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            else:
+                # use mla paged prefill
+                prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                    self._get_wrapper_idx(layer)
+                ]
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v
+                        )
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
+                o = prefill_wrapper_paged.run(
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                    k_buf[:, :, : layer.v_head_dim],
+                    k_buf[:, :, layer.v_head_dim :],
                 )
 
             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
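The paged branch above slices both the query tensor and the cached key buffer at layer.v_head_dim: under MLA weight absorption, each head packs the compressed-kv part and the rotary part side by side. A minimal sketch of that split, assuming DeepSeek-V2 sizes (kv_lora_rank = 512, so layer.v_head_dim = 512; qk_rope_head_dim = 64, so layer.head_dim = 576); the shapes and names below are illustrative, not taken from this patch:

import torch

num_tokens, num_q_heads = 8, 16
kv_lora_rank, qk_rope_head_dim = 512, 64      # DeepSeek-V2 dimensions (assumed)
head_dim = kv_lora_rank + qk_rope_head_dim    # 576, i.e. layer.head_dim

qall = torch.randn(num_tokens, num_q_heads, head_dim)
k_buf = torch.randn(num_tokens, 1, head_dim)  # MLA caches one shared latent head

q_nope = qall[:, :, :kv_lora_rank]  # query against the compressed kv (ckv)
q_rope = qall[:, :, kv_lora_rank:]  # query against the rotary key part (kpe)
ckv = k_buf[:, :, :kv_lora_rank]    # cached compressed kv
kpe = k_buf[:, :, kv_lora_rank:]    # cached rotary key
assert q_nope.shape[-1] == 512 and kpe.shape[-1] == 64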
@@ -800,7 +845,9 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
+        prefill_wrappers: List[
+            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -814,7 +861,9 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
+        prefill_wrappers: List[
+            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -923,7 +972,9 @@ class FlashInferIndicesUpdaterPrefill:
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
+        wrapper_paged: Union[
+            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -1004,6 +1055,26 @@
                 custom_mask=custom_mask,
                 non_blocking=True,
             )
+        elif (
+            global_config.enable_flashinfer_mla
+            and not global_server_args_dict["disable_radix_cache"]
+        ):
+            # mla paged prefill
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            wrapper_paged.plan(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                True,
+                1 / math.sqrt(192),
+                self.data_type,
+                self.data_type,
+            )
 
 
 class FlashInferMultiStepDraftBackend:
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 8ff0ff7e7..f4ffed10b 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -66,6 +66,7 @@ global_server_args_dict = {
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+    "disable_radix_cache": ServerArgs.disable_radix_cache,
 }
 
 logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 3242c0d61..b51c5161b 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -177,6 +177,7 @@ class ModelRunner:
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "disable_radix_cache": server_args.disable_radix_cache,
             }
         )
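The positional arguments to wrapper_paged.plan(...) above carry several model-specific constants. The annotated restatement below is a sketch only; the parameter labels are assumed from flashinfer's BatchMLAPagedAttentionWrapper.plan signature, and the numeric values are DeepSeek-V2 dimensions rather than backend defaults:

import math

wrapper_paged.plan(
    qo_indptr,
    kv_indptr,
    kv_indices,
    kv_len_arr,          # per-request kv lengths: kv_indptr[1:] - kv_indptr[:-1]
    self.num_qo_heads,   # number of query/output heads
    512,                 # head_dim_ckv: DeepSeek-V2 kv_lora_rank
    64,                  # head_dim_kpe: DeepSeek-V2 qk_rope_head_dim
    1,                   # page_size: one token per page
    True,                # causal masking for prefill
    1 / math.sqrt(192),  # sm_scale; 192 = qk_nope_head_dim (128) + qk_rope_head_dim (64)
    self.data_type,      # query data type
    self.data_type,      # kv data type
)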
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index df4f9ed14..5778e6e4d 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -511,8 +511,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if global_server_args_dict["enable_flashinfer_mla"]:
-            if forward_batch.forward_mode.is_extend():
-                return self.forward_normal(positions, hidden_states, forward_batch)
+            if global_server_args_dict["disable_radix_cache"]:
+                if forward_batch.forward_mode.is_extend():
+                    return self.forward_normal(positions, hidden_states, forward_batch)
+                else:
+                    return self.forward_absorb(positions, hidden_states, forward_batch)
             else:
                 return self.forward_absorb(positions, hidden_states, forward_batch)
         else:
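The net effect of this last hunk: with the radix cache enabled, the attention module now always takes the absorbed path, which the new paged MLA prefill wrapper can serve even when part of the sequence is a cached prefix. A self-contained sketch of the implied decision table (forward_normal and forward_absorb are the two existing MLA paths in DeepseekV2AttentionMLA):

def pick_mla_path(disable_radix_cache: bool, is_extend: bool) -> str:
    # Mirrors the dispatch in DeepseekV2AttentionMLA.forward (sketch only).
    if disable_radix_cache:
        # No prefix cache: extend batches can use the ragged, non-absorbed prefill.
        return "forward_normal" if is_extend else "forward_absorb"
    # Prefix cache enabled: always absorb, backed by the paged MLA prefill.
    return "forward_absorb"

assert pick_mla_path(disable_radix_cache=True, is_extend=True) == "forward_normal"
assert pick_mla_path(disable_radix_cache=False, is_extend=True) == "forward_absorb"
assert pick_mla_path(disable_radix_cache=False, is_extend=False) == "forward_absorb"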