feat: enable ragged fa3 by default on hopper 12.4+ (#3442)
This commit is contained in:
@@ -70,6 +70,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.is_multimodal = model_runner.model_config.is_multimodal
|
||||||
|
|
||||||
# Parse constants
|
# Parse constants
|
||||||
self.decode_use_tensor_cores = should_use_tensor_core(
|
self.decode_use_tensor_cores = should_use_tensor_core(
|
||||||
kv_cache_dtype=model_runner.kv_cache_dtype,
|
kv_cache_dtype=model_runner.kv_cache_dtype,
|
||||||
@@ -130,12 +132,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
for _ in range(self.num_wrappers)
|
for _ in range(self.num_wrappers)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Create wrappers
|
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
||||||
# NOTE: we do not use ragged attention when there are multiple wrappers
|
self.workspace_buffer, "NHD"
|
||||||
self.prefill_wrapper_ragged = (
|
|
||||||
BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
|
|
||||||
if self.num_wrappers == 1
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Two wrappers: one for sliding window attention and one for full attention.
|
# Two wrappers: one for sliding window attention and one for full attention.
|
||||||
@@ -217,13 +215,12 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
prefix_lens = forward_batch.extend_prefix_lens
|
prefix_lens = forward_batch.extend_prefix_lens
|
||||||
|
|
||||||
# Some heuristics to check whether to use ragged forward
|
if self.is_multimodal:
|
||||||
if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
|
|
||||||
use_ragged = True
|
|
||||||
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
|
||||||
else:
|
|
||||||
use_ragged = False
|
use_ragged = False
|
||||||
extend_no_prefix = False
|
extend_no_prefix = False
|
||||||
|
else:
|
||||||
|
use_ragged = True
|
||||||
|
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
|
||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
@@ -640,7 +637,6 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
bs = kv_indptr.shape[0] - 1
|
bs = kv_indptr.shape[0] - 1
|
||||||
|
|
||||||
wrapper.end_forward()
|
|
||||||
wrapper.begin_forward(
|
wrapper.begin_forward(
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
@@ -651,6 +647,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
1,
|
1,
|
||||||
data_type=self.data_type,
|
data_type=self.data_type,
|
||||||
q_data_type=self.q_data_type,
|
q_data_type=self.q_data_type,
|
||||||
|
non_blocking=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -860,7 +857,6 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
|
|
||||||
# extend part
|
# extend part
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
wrapper_ragged.end_forward()
|
|
||||||
wrapper_ragged.begin_forward(
|
wrapper_ragged.begin_forward(
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
@@ -871,7 +867,6 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# cached part
|
# cached part
|
||||||
wrapper_paged.end_forward()
|
|
||||||
wrapper_paged.begin_forward(
|
wrapper_paged.begin_forward(
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
@@ -883,6 +878,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
1,
|
1,
|
||||||
q_data_type=self.q_data_type,
|
q_data_type=self.q_data_type,
|
||||||
custom_mask=custom_mask,
|
custom_mask=custom_mask,
|
||||||
|
non_blocking=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1125,6 +1121,7 @@ def fast_decode_plan(
|
|||||||
sm_scale: Optional[float] = None,
|
sm_scale: Optional[float] = None,
|
||||||
rope_scale: Optional[float] = None,
|
rope_scale: Optional[float] = None,
|
||||||
rope_theta: Optional[float] = None,
|
rope_theta: Optional[float] = None,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
|
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
|
||||||
batch_size = len(last_page_len)
|
batch_size = len(last_page_len)
|
||||||
|
|||||||
Reference in New Issue
Block a user