[Feature] Support ragged prefill in flashinfer mla backend (#3967)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
Author: Baizhou Zhang
Date: 2025-02-28 18:13:56 -08:00
Committed by: GitHub
Parent: f3b99f73b3
Commit: 90a4b7d98a
9 changed files with 308 additions and 407 deletions
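
Note: only one of the nine changed files is shown below. The hunks here strip the
MLA-specific branches (BatchMLAPagedAttentionWrapper, the enable_flashinfer_mla
flag, and the hard-coded DeepSeek head dimensions) out of the generic FlashInfer
attention backend; per the commit title, the MLA logic, including the new ragged
prefill path, presumably moves to a dedicated MLA backend elsewhere in this commit.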

@@ -37,7 +36,6 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.mla import BatchMLAPagedAttentionWrapper
 
 
 class WrapperDispatch(Enum):
@@ -47,16 +46,12 @@ class WrapperDispatch(Enum):
 
 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[
-        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-    ]
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
 
 
 @dataclass
 class PrefillMetadata:
-    prefill_wrappers: List[
-        Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-    ]
+    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
@@ -109,12 +104,6 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
-        self.enable_flashinfer_mla = False
-        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                self.enable_flashinfer_mla = True
-                global_config.enable_flashinfer_mla = True
-
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -132,13 +121,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
-            if self.enable_flashinfer_mla:
-                self.qo_indptr = [
-                    torch.zeros(
-                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-                    )
-                    for _ in range(self.num_wrappers)
-                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
@@ -162,48 +144,24 @@
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             if not skip_prefill:
-                if (
-                    self.enable_flashinfer_mla
-                    and not global_server_args_dict["disable_radix_cache"]
-                ):
-                    # use mla paged prefill
-                    self.prefill_wrappers_paged.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            backend="fa2",
-                        )
-                    )
-                    self.prefill_wrappers_verify.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            backend="fa2",
-                        )
-                    )
-                else:
-                    self.prefill_wrappers_paged.append(
-                        BatchPrefillWithPagedKVCacheWrapper(
-                            self.workspace_buffer,
-                            "NHD",
-                            backend="fa2",
-                        )
-                    )
-                    self.prefill_wrappers_verify.append(
-                        BatchPrefillWithPagedKVCacheWrapper(
-                            self.workspace_buffer, "NHD"
-                        )
-                    )
-            if self.enable_flashinfer_mla:
-                self.decode_wrappers.append(
-                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
-                )
-            else:
-                self.decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                    )
-                )
+                self.prefill_wrappers_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        backend="fa2",
+                    )
+                )
+                self.prefill_wrappers_verify.append(
+                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                )
+            self.decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_tensor_cores=self.decode_use_tensor_cores,
+                )
+            )
 
         # Create indices updater
         if not skip_prefill:
@@ -259,10 +217,7 @@
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-            if self.is_multimodal or (
-                self.enable_flashinfer_mla
-                and not global_server_args_dict["disable_radix_cache"]
-            ):
+            if self.is_multimodal:
                 use_ragged = False
                 extend_no_prefix = False
             else:
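
Note (context, not part of the diff): use_ragged selects FlashInfer's ragged
(non-paged) prefill kernel for the packed extend tokens, while extend_no_prefix
records that no request carries a cached prefix, letting the paged pass over the
prefix be skipped. A minimal sketch of that decision, assuming per-request prefix
lengths are available on the CPU (choose_prefill_path is a hypothetical helper):

from typing import List, Tuple

def choose_prefill_path(
    is_multimodal: bool, extend_prefix_lens_cpu: List[int]
) -> Tuple[bool, bool]:
    # Multimodal batches use custom masks and stay on the paged path.
    if is_multimodal:
        return False, False  # (use_ragged, extend_no_prefix)
    # Ragged prefill covers the new tokens; if every prefix is empty,
    # the separate paged pass over cached tokens is unnecessary.
    return True, not any(extend_prefix_lens_cpu)

# Two requests, the first resuming from a 16-token cached prefix:
print(choose_prefill_path(False, [16, 0]))  # (True, False)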
@@ -321,32 +276,20 @@
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                if self.enable_flashinfer_mla:
-                    decode_wrappers.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            use_cuda_graph=True,
-                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
-                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
-                            kv_indices=self.cuda_graph_kv_indices[i],
-                            kv_len_arr=self.kv_last_page_len[:num_tokens],
-                            backend="fa2",
-                        )
-                    )
-                else:
-                    decode_wrappers.append(
-                        BatchDecodeWithPagedKVCacheWrapper(
-                            self.workspace_buffer,
-                            "NHD",
-                            use_cuda_graph=True,
-                            use_tensor_cores=self.decode_use_tensor_cores,
-                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                                :num_tokens
-                            ],
-                        )
-                    )
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                            :num_tokens
+                        ],
+                    )
+                )
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
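
Note (context): the wrapper captured into the CUDA graph is constructed over
views of preallocated buffers (self.kv_indptr[i][: num_tokens + 1],
self.cuda_graph_kv_indices[i]) because a replayed graph reuses fixed device
pointers; per-step metadata must be written into that same storage, never
reallocated. An illustration of the pattern, with assumed sizes and a
hypothetical update helper:

import torch

max_bs = 8
# Buffer that would be captured by the graph; only a prefix slice is used.
kv_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)

def update_kv_indptr(seq_lens: torch.Tensor) -> None:
    # Write new prefix sums into the SAME tensor the graph was captured
    # with; reallocating would change the pointer the kernels read.
    bs = seq_lens.numel()
    kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0, dtype=torch.int32)

update_kv_indptr(torch.tensor([3, 5], dtype=torch.int32))
print(kv_indptr[:3].tolist())  # [0, 3, 8]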
@@ -435,114 +378,64 @@
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        if global_config.enable_flashinfer_mla:
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
-
-            logits_soft_cap = layer.logit_cap
-
-            if global_server_args_dict["disable_radix_cache"]:
-                # use mla ragged prefill
-                o, _ = self.prefill_wrapper_ragged.forward_return_lse(
-                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
-                    causal=True,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                )
-
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-            else:
-                # use mla paged prefill
-                prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-                    self._get_wrapper_idx(layer)
-                ]
-                if k is not None:
-                    assert v is not None
-                    if save_kv_cache:
-                        forward_batch.token_to_kv_pool.set_kv_buffer(
-                            layer, cache_loc, k, v
-                        )
-
-                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-                k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
-                o = prefill_wrapper_paged.run(
-                    qall[:, :, : layer.v_head_dim],
-                    qall[:, :, layer.v_head_dim :],
-                    k_buf[:, :, : layer.v_head_dim],
-                    k_buf[:, :, layer.v_head_dim :],
-                )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-                self._get_wrapper_idx(layer)
-            ]
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
-
-            logits_soft_cap = layer.logit_cap
-
-            if not self.forward_metadata.use_ragged:
-                if k is not None:
-                    assert v is not None
-                    if save_kv_cache:
-                        forward_batch.token_to_kv_pool.set_kv_buffer(
-                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                        )
-
-                o = prefill_wrapper_paged.forward(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=not layer.is_cross_attention,
-                    sm_scale=layer.scaling,
-                    window_left=layer.sliding_window_size,
-                    logits_soft_cap=logits_soft_cap,
-                    k_scale=layer.k_scale,
-                    v_scale=layer.v_scale,
-                )
-            else:
-                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                )
-
-                if self.forward_metadata.extend_no_prefix:
-                    o = o1
-                else:
-                    o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                        causal=False,
-                        sm_scale=layer.scaling,
-                        logits_soft_cap=layer.logit_cap,
-                    )
-
-                    o, _ = merge_state(o1, s1, o2, s2)
-
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
-
-            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+            self._get_wrapper_idx(layer)
+        ]
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
+
+        logits_soft_cap = layer.logit_cap
+
+        if not self.forward_metadata.use_ragged:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+
+            o = prefill_wrapper_paged.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                causal=not layer.is_cross_attention,
+                sm_scale=layer.scaling,
+                window_left=layer.sliding_window_size,
+                logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
+        else:
+            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
+            if self.forward_metadata.extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    causal=False,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=layer.logit_cap,
+                )
+
+                o, _ = merge_state(o1, s1, o2, s2)
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
         self,
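
Note on the forward_extend hunk above (context): in the ragged branch, attention
runs twice: a causal pass over the packed new tokens (o1, s1) and a non-causal
pass over the cached prefix (o2, s2). merge_state then combines the partial
outputs exactly via their log-sum-exp (LSE) terms. A PyTorch sketch of the
semantics (the real flashinfer.cascade.merge_state is a fused kernel, and its
LSE base convention may differ):

import torch

def merge_state_reference(o1, s1, o2, s2):
    # o*: [tokens, heads, head_dim] partial attention outputs
    # s*: [tokens, heads] log-sum-exp of each partial softmax
    # Each partial output is a softmax-weighted average over a disjoint
    # slice of the keys; the LSEs recover the unnormalized weights:
    #   o = (o1 * exp(s1) + o2 * exp(s2)) / (exp(s1) + exp(s2))
    s_max = torch.maximum(s1, s2)  # subtract the max for stability
    w1 = torch.exp(s1 - s_max).unsqueeze(-1)
    w2 = torch.exp(s2 - s_max).unsqueeze(-1)
    o = (o1 * w1 + o2 * w2) / (w1 + w2)
    s = s_max + torch.log(torch.exp(s1 - s_max) + torch.exp(s2 - s_max))
    return o, s

Because softmax over disjoint key sets factorizes this way, splitting prefill
into a ragged extend pass plus a paged prefix pass loses no accuracy.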
@@ -562,45 +455,23 @@
             else forward_batch.encoder_out_cache_loc
         )
 
-        if self.enable_flashinfer_mla:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
-            o = decode_wrapper.run(
-                reshaped_q[:, :, : layer.v_head_dim],
-                reshaped_q[:, :, layer.v_head_dim :],
-                reshaped_k[:, :, : layer.v_head_dim],
-                reshaped_k[:, :, layer.v_head_dim :],
-            )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
-
-            o = decode_wrapper.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
-            )
-
-            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+        o = decode_wrapper.forward(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+            sm_scale=layer.scaling,
+            logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
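
Note on the forward_decode hunk above (context on the removed MLA path): the
slices at layer.v_head_dim split each row into a compressed-latent ("nope") part
and a RoPE part. The 512 and 64 match DeepSeek's kv_lora_rank and
qk_rope_head_dim, the same values hard-coded in the removed plan() calls further
down; the sizes below are illustrative, not read from a real config:

import torch

kv_lora_rank = 512       # latent width; plays the role of layer.v_head_dim
qk_rope_head_dim = 64    # rotary width
head_dim = kv_lora_rank + qk_rope_head_dim  # 576; layer.head_dim here

tokens, tp_q_head_num = 4, 16
q = torch.randn(tokens, tp_q_head_num, head_dim)
kv_cache = torch.randn(1024, 1, head_dim)  # one latent row per cached token

# The removed code fed these four pieces to
# BatchMLAPagedAttentionWrapper.run(q_nope, q_pe, ckv, kpe):
q_nope, q_pe = q[..., :kv_lora_rank], q[..., kv_lora_rank:]
ckv, kpe = kv_cache[..., :kv_lora_rank], kv_cache[..., kv_lora_rank:]
assert q_pe.shape[-1] == 64 and ckv.shape[-1] == 512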
@@ -648,9 +519,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -662,9 +531,7 @@
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -745,9 +612,7 @@
     def call_begin_forward(
         self,
-        wrapper: Union[
-            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -775,37 +640,18 @@
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
-        if global_config.enable_flashinfer_mla:
-            sm_scale = 1.0 / math.sqrt(192)
-            q_indptr = torch.arange(0, bs + 1).to(0).int()
-            kv_lens = paged_kernel_lens.to(torch.int32)
-            wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_qo_heads,
-                512,
-                64,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
-            )
-        else:
-            wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                1,
-                data_type=self.data_type,
-                q_data_type=self.q_data_type,
-                non_blocking=True,
-            )
+        wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.data_type,
+            q_data_type=self.q_data_type,
+            non_blocking=True,
+        )
 
 
 class FlashInferIndicesUpdaterPrefill:
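
Note on the decode call_begin_forward hunk above (context): in the removed MLA
plan() call, the positional constants line up with FlashInfer's MLA plan
arguments head_dim_ckv=512, head_dim_kpe=64, and page_size=1, and the softmax
scale is 1/sqrt(192) rather than 1/sqrt(576) because 192 = 128 + 64 is
DeepSeek's query/key width before weight absorption. Absorption re-expresses
the same dot product over the 512-dim latent, so the temperature is unchanged:

import math

qk_nope_head_dim = 128  # per-head no-RoPE q/k width (DeepSeek config)
qk_rope_head_dim = 64   # per-head RoPE width
sm_scale = 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)
assert abs(sm_scale - 1 / math.sqrt(192)) < 1e-12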
@@ -845,9 +691,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[
-            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -861,9 +705,7 @@
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[
-            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -972,9 +814,7 @@
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: Union[
-            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -1020,62 +860,30 @@
         # extend part
         if use_ragged:
-            if global_config.enable_flashinfer_mla:
-                wrapper_ragged.begin_forward(
-                    qo_indptr=qo_indptr,
-                    kv_indptr=qo_indptr,
-                    num_qo_heads=self.num_qo_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_dim_qk=192,
-                    head_dim_vo=128,
-                    q_data_type=self.q_data_type,
-                )
-            else:
-                wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    self.num_qo_heads,
-                    self.num_kv_heads,
-                    self.head_dim,
-                    q_data_type=self.q_data_type,
-                )
-
-        if not global_config.enable_flashinfer_mla:
-            # cached part
-            wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                1,
-                q_data_type=self.q_data_type,
-                custom_mask=custom_mask,
-                non_blocking=True,
-            )
-        elif (
-            global_config.enable_flashinfer_mla
-            and not global_server_args_dict["disable_radix_cache"]
-        ):
-            # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
-            wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_qo_heads,
-                512,
-                64,
-                1,
-                True,
-                1 / math.sqrt(192),
-                self.data_type,
-                self.data_type,
-            )
+            wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                q_data_type=self.q_data_type,
+            )
+
+        # cached part
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            q_data_type=self.q_data_type,
+            custom_mask=custom_mask,
+            non_blocking=True,
+        )
 
 
 class FlashInferMultiStepDraftBackend:
     """