feat: support flashinfer mla with prefix cache (#3643)
This commit is contained in:
@@ -54,7 +54,9 @@ class DecodeMetadata:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PrefillMetadata:
|
class PrefillMetadata:
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
|
prefill_wrappers: List[
|
||||||
|
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||||
|
]
|
||||||
use_ragged: bool
|
use_ragged: bool
|
||||||
extend_no_prefix: bool
|
extend_no_prefix: bool
|
||||||
|
|
||||||
@@ -160,16 +162,36 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.decode_wrappers = []
|
self.decode_wrappers = []
|
||||||
for _ in range(self.num_wrappers):
|
for _ in range(self.num_wrappers):
|
||||||
if not skip_prefill:
|
if not skip_prefill:
|
||||||
self.prefill_wrappers_paged.append(
|
if (
|
||||||
BatchPrefillWithPagedKVCacheWrapper(
|
self.enable_flashinfer_mla
|
||||||
self.workspace_buffer,
|
and not global_server_args_dict["disable_radix_cache"]
|
||||||
"NHD",
|
):
|
||||||
backend="fa2",
|
# use mla paged prefill
|
||||||
|
self.prefill_wrappers_paged.append(
|
||||||
|
BatchMLAPagedAttentionWrapper(
|
||||||
|
self.workspace_buffer,
|
||||||
|
backend="fa2",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.prefill_wrappers_verify.append(
|
||||||
|
BatchMLAPagedAttentionWrapper(
|
||||||
|
self.workspace_buffer,
|
||||||
|
backend="fa2",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.prefill_wrappers_paged.append(
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
self.workspace_buffer,
|
||||||
|
"NHD",
|
||||||
|
backend="fa2",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.prefill_wrappers_verify.append(
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
self.workspace_buffer, "NHD"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
self.prefill_wrappers_verify.append(
|
|
||||||
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
|
||||||
)
|
|
||||||
if self.enable_flashinfer_mla:
|
if self.enable_flashinfer_mla:
|
||||||
self.decode_wrappers.append(
|
self.decode_wrappers.append(
|
||||||
BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
|
BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
|
||||||
@@ -237,7 +259,10 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
prefix_lens = forward_batch.extend_prefix_lens
|
prefix_lens = forward_batch.extend_prefix_lens
|
||||||
|
|
||||||
if self.is_multimodal:
|
if self.is_multimodal or (
|
||||||
|
self.enable_flashinfer_mla
|
||||||
|
and not global_server_args_dict["disable_radix_cache"]
|
||||||
|
):
|
||||||
use_ragged = False
|
use_ragged = False
|
||||||
extend_no_prefix = False
|
extend_no_prefix = False
|
||||||
else:
|
else:
|
||||||
@@ -419,23 +444,43 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
logits_soft_cap = layer.logit_cap
|
logits_soft_cap = layer.logit_cap
|
||||||
|
|
||||||
o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
|
if global_server_args_dict["disable_radix_cache"]:
|
||||||
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
# use mla ragged prefill
|
||||||
k.view(-1, layer.tp_k_head_num, layer.head_dim),
|
o, _ = self.prefill_wrapper_ragged.forward_return_lse(
|
||||||
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
|
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
causal=True,
|
k.view(-1, layer.tp_k_head_num, layer.head_dim),
|
||||||
sm_scale=layer.scaling,
|
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
|
||||||
logits_soft_cap=logits_soft_cap,
|
causal=True,
|
||||||
)
|
sm_scale=layer.scaling,
|
||||||
|
logits_soft_cap=logits_soft_cap,
|
||||||
|
)
|
||||||
|
|
||||||
o = o1
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer,
|
||||||
|
cache_loc,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# use mla paged prefill
|
||||||
|
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
||||||
|
self._get_wrapper_idx(layer)
|
||||||
|
]
|
||||||
|
if k is not None:
|
||||||
|
assert v is not None
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, cache_loc, k, v
|
||||||
|
)
|
||||||
|
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
|
||||||
if save_kv_cache:
|
o = prefill_wrapper_paged.run(
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
qall[:, :, : layer.v_head_dim],
|
||||||
layer,
|
qall[:, :, layer.v_head_dim :],
|
||||||
cache_loc,
|
k_buf[:, :, : layer.v_head_dim],
|
||||||
k,
|
k_buf[:, :, layer.v_head_dim :],
|
||||||
v,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
@@ -800,7 +845,9 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[
|
||||||
|
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||||
|
],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[SpecInfo],
|
||||||
@@ -814,7 +861,9 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[
|
||||||
|
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||||
|
],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[SpecInfo],
|
||||||
@@ -923,7 +972,9 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
self,
|
self,
|
||||||
wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
|
wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
|
wrapper_paged: Union[
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
|
||||||
|
],
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
paged_kernel_lens: torch.Tensor,
|
paged_kernel_lens: torch.Tensor,
|
||||||
paged_kernel_lens_sum: int,
|
paged_kernel_lens_sum: int,
|
||||||
@@ -1004,6 +1055,26 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
custom_mask=custom_mask,
|
custom_mask=custom_mask,
|
||||||
non_blocking=True,
|
non_blocking=True,
|
||||||
)
|
)
|
||||||
|
elif (
|
||||||
|
global_config.enable_flashinfer_mla
|
||||||
|
and not global_server_args_dict["disable_radix_cache"]
|
||||||
|
):
|
||||||
|
# mla paged prefill
|
||||||
|
kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
|
||||||
|
wrapper_paged.plan(
|
||||||
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_len_arr,
|
||||||
|
self.num_qo_heads,
|
||||||
|
512,
|
||||||
|
64,
|
||||||
|
1,
|
||||||
|
True,
|
||||||
|
1 / math.sqrt(192),
|
||||||
|
self.data_type,
|
||||||
|
self.data_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlashInferMultiStepDraftBackend:
|
class FlashInferMultiStepDraftBackend:
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ global_server_args_dict = {
|
|||||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||||
"device": ServerArgs.device,
|
"device": ServerArgs.device,
|
||||||
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
||||||
|
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -177,6 +177,7 @@ class ModelRunner:
|
|||||||
"enable_ep_moe": server_args.enable_ep_moe,
|
"enable_ep_moe": server_args.enable_ep_moe,
|
||||||
"device": server_args.device,
|
"device": server_args.device,
|
||||||
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
||||||
|
"disable_radix_cache": server_args.disable_radix_cache,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -511,8 +511,11 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if global_server_args_dict["enable_flashinfer_mla"]:
|
if global_server_args_dict["enable_flashinfer_mla"]:
|
||||||
if forward_batch.forward_mode.is_extend():
|
if global_server_args_dict["disable_radix_cache"]:
|
||||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
if forward_batch.forward_mode.is_extend():
|
||||||
|
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||||
|
else:
|
||||||
|
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||||
else:
|
else:
|
||||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user