feat: support flashinfer mla with prefix cache (#3643)
This commit is contained in:
@@ -54,7 +54,9 @@ class DecodeMetadata:
|
||||
|
||||
@dataclass
|
||||
class PrefillMetadata:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
|
||||
prefill_wrappers: List[
|
||||
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||
]
|
||||
use_ragged: bool
|
||||
extend_no_prefix: bool
|
||||
|
||||
@@ -160,16 +162,36 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
self.decode_wrappers = []
|
||||
for _ in range(self.num_wrappers):
|
||||
if not skip_prefill:
|
||||
self.prefill_wrappers_paged.append(
|
||||
BatchPrefillWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer,
|
||||
"NHD",
|
||||
backend="fa2",
|
||||
if (
|
||||
self.enable_flashinfer_mla
|
||||
and not global_server_args_dict["disable_radix_cache"]
|
||||
):
|
||||
# use mla paged prefill
|
||||
self.prefill_wrappers_paged.append(
|
||||
BatchMLAPagedAttentionWrapper(
|
||||
self.workspace_buffer,
|
||||
backend="fa2",
|
||||
)
|
||||
)
|
||||
self.prefill_wrappers_verify.append(
|
||||
BatchMLAPagedAttentionWrapper(
|
||||
self.workspace_buffer,
|
||||
backend="fa2",
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.prefill_wrappers_paged.append(
|
||||
BatchPrefillWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer,
|
||||
"NHD",
|
||||
backend="fa2",
|
||||
)
|
||||
)
|
||||
self.prefill_wrappers_verify.append(
|
||||
BatchPrefillWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer, "NHD"
|
||||
)
|
||||
)
|
||||
)
|
||||
self.prefill_wrappers_verify.append(
|
||||
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
||||
)
|
||||
if self.enable_flashinfer_mla:
|
||||
self.decode_wrappers.append(
|
||||
BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
|
||||
@@ -237,7 +259,10 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
else:
|
||||
prefix_lens = forward_batch.extend_prefix_lens
|
||||
|
||||
if self.is_multimodal:
|
||||
if self.is_multimodal or (
|
||||
self.enable_flashinfer_mla
|
||||
and not global_server_args_dict["disable_radix_cache"]
|
||||
):
|
||||
use_ragged = False
|
||||
extend_no_prefix = False
|
||||
else:
|
||||
@@ -419,23 +444,43 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
logits_soft_cap = layer.logit_cap
|
||||
|
||||
o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
|
||||
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k.view(-1, layer.tp_k_head_num, layer.head_dim),
|
||||
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
|
||||
causal=True,
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
)
|
||||
if global_server_args_dict["disable_radix_cache"]:
|
||||
# use mla ragged prefill
|
||||
o, _ = self.prefill_wrapper_ragged.forward_return_lse(
|
||||
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k.view(-1, layer.tp_k_head_num, layer.head_dim),
|
||||
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
|
||||
causal=True,
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
)
|
||||
|
||||
o = o1
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer,
|
||||
cache_loc,
|
||||
k,
|
||||
v,
|
||||
)
|
||||
else:
|
||||
# use mla paged prefill
|
||||
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
||||
self._get_wrapper_idx(layer)
|
||||
]
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v
|
||||
)
|
||||
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer,
|
||||
cache_loc,
|
||||
k,
|
||||
v,
|
||||
o = prefill_wrapper_paged.run(
|
||||
qall[:, :, : layer.v_head_dim],
|
||||
qall[:, :, layer.v_head_dim :],
|
||||
k_buf[:, :, : layer.v_head_dim],
|
||||
k_buf[:, :, layer.v_head_dim :],
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
@@ -800,7 +845,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
prefill_wrappers: List[
|
||||
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||
],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
@@ -814,7 +861,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
prefill_wrappers: List[
|
||||
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||
],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
@@ -923,7 +972,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
def call_begin_forward(
|
||||
self,
|
||||
wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
|
||||
wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
|
||||
wrapper_paged: Union[
|
||||
BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
|
||||
],
|
||||
req_pool_indices: torch.Tensor,
|
||||
paged_kernel_lens: torch.Tensor,
|
||||
paged_kernel_lens_sum: int,
|
||||
@@ -1004,6 +1055,26 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
custom_mask=custom_mask,
|
||||
non_blocking=True,
|
||||
)
|
||||
elif (
|
||||
global_config.enable_flashinfer_mla
|
||||
and not global_server_args_dict["disable_radix_cache"]
|
||||
):
|
||||
# mla paged prefill
|
||||
kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
|
||||
wrapper_paged.plan(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_len_arr,
|
||||
self.num_qo_heads,
|
||||
512,
|
||||
64,
|
||||
1,
|
||||
True,
|
||||
1 / math.sqrt(192),
|
||||
self.data_type,
|
||||
self.data_type,
|
||||
)
|
||||
|
||||
|
||||
class FlashInferMultiStepDraftBackend:
|
||||
|
||||
@@ -66,6 +66,7 @@ global_server_args_dict = {
|
||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||
"device": ServerArgs.device,
|
||||
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -177,6 +177,7 @@ class ModelRunner:
|
||||
"enable_ep_moe": server_args.enable_ep_moe,
|
||||
"device": server_args.device,
|
||||
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
||||
"disable_radix_cache": server_args.disable_radix_cache,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -511,8 +511,11 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
if global_server_args_dict["enable_flashinfer_mla"]:
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||
if global_server_args_dict["disable_radix_cache"]:
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||
else:
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
else:
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user