feat: enable ragged fa3 by default on hopper 12.4+ (#3442)
@@ -70,6 +70,8 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()
 
+        self.is_multimodal = model_runner.model_config.is_multimodal
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -130,12 +132,8 @@ class FlashInferAttnBackend(AttentionBackend):
             for _ in range(self.num_wrappers)
         ]
-
         # Create wrappers
-        # NOTE: we do not use ragged attention when there are multiple wrappers
-        self.prefill_wrapper_ragged = (
-            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
-            if self.num_wrappers == 1
-            else None
+        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.workspace_buffer, "NHD"
         )
 
         # Two wrappers: one for sliding window attention and one for full attention.
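The hunk above replaces the conditional construction (ragged wrapper only when `self.num_wrappers == 1`) with an unconditional one, so a ragged-KV prefill wrapper now always exists. For readers unfamiliar with FlashInfer, here is a minimal standalone sketch of constructing such a wrapper; the 128 MB workspace size is a common default and an assumption here, not a value taken from this commit:

```python
import torch
import flashinfer

# FlashInfer wrappers share a workspace buffer allocated up front.
workspace_buffer = torch.empty(
    128 * 1024 * 1024, dtype=torch.uint8, device="cuda"
)

# "NHD" layout means tensors are (num_tokens, num_heads, head_dim).
prefill_wrapper_ragged = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer, "NHD"
)
```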
@@ -217,13 +215,12 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
-
         # Some heuristics to check whether to use ragged forward
-        if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
-            use_ragged = True
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-        else:
+        if self.is_multimodal:
             use_ragged = False
             extend_no_prefix = False
+        else:
+            use_ragged = True
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
         self.indices_updater_prefill.update(
             forward_batch.req_pool_indices,
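The hunk above inverts the heuristic: previously the ragged path required a large extend batch (>= 4096 tokens) and a single wrapper; now it is the default, disabled only for multimodal models (hence the new `self.is_multimodal` flag from the first hunk). The before/after logic, restated as plain functions for clarity (function names are illustrative only):

```python
def should_use_ragged_old(extend_num_tokens, num_wrappers, prefix_lens_cpu):
    """Old heuristic: ragged only for big extends with a single wrapper."""
    if extend_num_tokens >= 4096 and num_wrappers == 1:
        return True, not any(prefix_lens_cpu)  # (use_ragged, extend_no_prefix)
    return False, False


def should_use_ragged_new(is_multimodal, prefix_lens_cpu):
    """New heuristic: ragged by default, except for multimodal models."""
    if is_multimodal:
        return False, False
    return True, not any(prefix_lens_cpu)
```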
@@ -640,7 +637,6 @@ class FlashInferIndicesUpdaterDecode:
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
bs = kv_indptr.shape[0] - 1
|
||||
|
||||
wrapper.end_forward()
|
||||
wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
@@ -651,6 +647,7 @@ class FlashInferIndicesUpdaterDecode:
             1,
             data_type=self.data_type,
             q_data_type=self.q_data_type,
+            non_blocking=True,
         )
 
 
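The two hunks above drop `wrapper.end_forward()` and pass `non_blocking=True` to `begin_forward`. This matches the FlashInfer 0.2 API, where the old `begin_forward()`/`end_forward()` pair was folded into a single `plan()` call (`begin_forward` remains as an alias) and `plan()` gained a `non_blocking` flag so the host-to-device copy of the index tensors does not synchronize the stream. A hedged sketch of the call pattern, assuming FlashInfer >= 0.2; the head counts and index tensors are illustrative:

```python
import torch
import flashinfer

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Two requests holding 3 and 5 KV-cache pages respectively.
kv_indptr = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(8, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([1, 4], dtype=torch.int32, device="cuda")

# One plan() call replaces the old begin_forward()/end_forward() pair;
# non_blocking=True makes the index uploads asynchronous.
wrapper.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads=32,
    num_kv_heads=32,
    head_dim=128,
    page_size=1,
    q_data_type=torch.float16,
    non_blocking=True,
)
```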
@@ -860,7 +857,6 @@ class FlashInferIndicesUpdaterPrefill:
 
         # extend part
         if use_ragged:
-            wrapper_ragged.end_forward()
             wrapper_ragged.begin_forward(
                 qo_indptr,
                 qo_indptr,
@@ -871,7 +867,6 @@ class FlashInferIndicesUpdaterPrefill:
             )
 
         # cached part
-        wrapper_paged.end_forward()
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
|
||||
1,
|
||||
q_data_type=self.q_data_type,
|
||||
custom_mask=custom_mask,
|
||||
non_blocking=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -1125,6 +1121,7 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
+    **kwargs,
 ) -> None:
     """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
     batch_size = len(last_page_len)
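The final hunk adds `**kwargs` to `fast_decode_plan`, which the docstring identifies as a drop-in replacement for `BatchDecodeWithPagedKVCacheWrapper::plan`. Since callers now pass extra keywords such as `non_blocking=True`, the patched function must absorb arguments it does not use. The general pattern, sketched with illustrative names:

```python
class Wrapper:
    def plan(self, indptr, indices, last_page_len, *, non_blocking=False):
        ...  # full (slower) planning logic


def fast_plan(self, indptr, indices, last_page_len, **kwargs) -> None:
    # **kwargs absorbs keywords like non_blocking that this fast path
    # ignores, so call sites written against Wrapper.plan keep working.
    batch_size = len(last_page_len)
    ...


# Swap in the fast path (analogous to how fast_decode_plan is used).
Wrapper.plan = fast_plan
```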