From 36f6fc50935e69eb3f6801aeceed30ab7591a30e Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 10 Feb 2025 07:43:01 +0800 Subject: [PATCH] feat: enable ragged fa3 by default on hopper 12.4+ (#3442) --- .../layers/attention/flashinfer_backend.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index bbe9a2e1a..75ed9b3fc 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -70,6 +70,8 @@ class FlashInferAttnBackend(AttentionBackend): ): super().__init__() + self.is_multimodal = model_runner.model_config.is_multimodal + # Parse constants self.decode_use_tensor_cores = should_use_tensor_core( kv_cache_dtype=model_runner.kv_cache_dtype, @@ -130,12 +132,8 @@ class FlashInferAttnBackend(AttentionBackend): for _ in range(self.num_wrappers) ] - # Create wrappers - # NOTE: we do not use ragged attention when there are multiple wrappers - self.prefill_wrapper_ragged = ( - BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD") - if self.num_wrappers == 1 - else None + self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( + self.workspace_buffer, "NHD" ) # Two wrappers: one for sliding window attention and one for full attention. @@ -217,13 +215,12 @@ class FlashInferAttnBackend(AttentionBackend): else: prefix_lens = forward_batch.extend_prefix_lens - # Some heuristics to check whether to use ragged forward - if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1: - use_ragged = True - extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) - else: + if self.is_multimodal: use_ragged = False extend_no_prefix = False + else: + use_ragged = True + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) self.indices_updater_prefill.update( forward_batch.req_pool_indices, @@ -640,7 +637,6 @@ class FlashInferIndicesUpdaterDecode: kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices bs = kv_indptr.shape[0] - 1 - wrapper.end_forward() wrapper.begin_forward( kv_indptr, kv_indices, @@ -651,6 +647,7 @@ class FlashInferIndicesUpdaterDecode: 1, data_type=self.data_type, q_data_type=self.q_data_type, + non_blocking=True, ) @@ -860,7 +857,6 @@ class FlashInferIndicesUpdaterPrefill: # extend part if use_ragged: - wrapper_ragged.end_forward() wrapper_ragged.begin_forward( qo_indptr, qo_indptr, @@ -871,7 +867,6 @@ class FlashInferIndicesUpdaterPrefill: ) # cached part - wrapper_paged.end_forward() wrapper_paged.begin_forward( qo_indptr, kv_indptr, @@ -883,6 +878,7 @@ class FlashInferIndicesUpdaterPrefill: 1, q_data_type=self.q_data_type, custom_mask=custom_mask, + non_blocking=True, ) @@ -1125,6 +1121,7 @@ def fast_decode_plan( sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, + **kwargs, ) -> None: """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.""" batch_size = len(last_page_len)