Support MHA with chunked prefix cache for flashinfer/flashmla backend, support page size > 1 for MHA chunked prefix (#8616)
Co-authored-by: xuyongfei.xyf <xuyongfei.xyf@antgroup.com>
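The change is easiest to follow with the identity it relies on: attention computed separately over disjoint KV ranges (here, the cached prefix chunks plus the newly extended tokens) can be combined exactly, as long as each partial pass also returns its log-sum-exp (LSE). A minimal sketch of the merge rule that kernels such as merge_state_v2 implement (shapes illustrative, helper name hypothetical):

import torch

def merge_attention_states(o_a, lse_a, o_b, lse_b):
    # o_*:   [num_tokens, num_heads, head_dim] partial attention outputs
    # lse_*: [num_tokens, num_heads] log-sum-exp of the attention logits
    lse = torch.logaddexp(lse_a, lse_b)         # normalizer over the union of ranges
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)  # relative weight of range A
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)  # relative weight of range B
    return w_a * o_a + w_b * o_b, lse

This is why the diff threads an mha_return_lse flag through ForwardBatch: the extend-only fast path can skip the LSE, while the chunked path must return it so partial results can be merged.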
@@ -776,14 +776,13 @@ class FlashAttentionBackend(AttentionBackend):
                 o = result
             else:
                 if (
-                    not global_server_args_dict["disable_chunked_prefix_cache"]
-                    and forward_batch.attn_attend_prefix_cache is not None
+                    forward_batch.attn_attend_prefix_cache is not None
                     and not forward_batch.forward_mode.is_target_verify()
                     and not forward_batch.forward_mode.is_draft_extend()
                 ):
                     # Do multi-head attention with chunked prefix cache
-
                     if forward_batch.attn_attend_prefix_cache:
+                        assert not global_server_args_dict["disable_chunked_prefix_cache"]
                         # MHA for chunked prefix kv cache when running model with MLA
                         assert forward_batch.prefix_chunk_idx is not None
                         assert forward_batch.prefix_chunk_cu_seq_lens is not None

@@ -792,7 +791,8 @@ class FlashAttentionBackend(AttentionBackend):
                         chunk_idx = forward_batch.prefix_chunk_idx
                         assert chunk_idx >= 0

-                        output, lse, *rest = flash_attn_varlen_func(
+                        assert forward_batch.mha_return_lse
+                        output = flash_attn_varlen_func(
                             q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                             k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                             v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),

@@ -806,7 +806,7 @@ class FlashAttentionBackend(AttentionBackend):
                         )
                     else:
                         # MHA for extend part of sequence without attending prefix kv cache
-                        output, lse, *rest = flash_attn_varlen_func(
+                        output = flash_attn_varlen_func(
                             q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                             k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                             v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),

@@ -816,9 +816,13 @@ class FlashAttentionBackend(AttentionBackend):
                             max_seqlen_k=metadata.max_seq_len_q,
                             softmax_scale=layer.scaling,
                             causal=True,
-                            return_softmax_lse=True,
+                            return_softmax_lse=forward_batch.mha_return_lse,
                         )
-                return output, lse
+                if forward_batch.mha_return_lse:
+                    output, lse, *rest = output
+                    lse = torch.transpose(lse, 0, 1).contiguous()
+                    return output, lse
+                return output
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
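Note how return_softmax_lse is now driven by forward_batch.mha_return_lse, so the same call site serves both paths; the result only carries the LSE when it was requested. A hedged sketch of that unpacking convention (unpack_flash_attn_result is an illustrative helper, not part of the diff; flash-attn lays the LSE out as [num_heads, num_tokens], hence the transpose):

import torch

def unpack_flash_attn_result(result, return_lse: bool):
    # flash_attn_varlen_func returns a tuple when return_softmax_lse=True,
    # otherwise just the output tensor.
    if return_lse:
        output, lse, *rest = result
        lse = torch.transpose(lse, 0, 1).contiguous()  # -> [num_tokens, num_heads]
        return output, lse
    return result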
@@ -59,6 +59,115 @@ class PrefillMetadata:
 global_workspace_buffer = None


+class FlashInferMhaChunkKVRunner:
+    def __init__(
+        self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
+    ):
+        # Parse Constants
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.data_type = model_runner.dtype
+        self.q_data_type = model_runner.dtype
+
+        # Buffers and wrappers
+        self.qo_indptr = attn_backend.qo_indptr
+        self.workspace_buffer = attn_backend.workspace_buffer
+        self.fmha_backend = attn_backend.fmha_backend
+
+        self.chunk_ragged_wrappers = []
+        self.ragged_wrapper = attn_backend.prefill_wrapper_ragged
+
+    def update_prefix_chunks(self, num_prefix_chunks: int):
+        while num_prefix_chunks > len(self.chunk_ragged_wrappers):
+            ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
+                self.workspace_buffer, "NHD", backend=self.fmha_backend
+            )
+            self.chunk_ragged_wrappers.append(ragged_wrapper)
+
+    def update_wrapper(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.num_prefix_chunks is not None
+        num_prefix_chunks = forward_batch.num_prefix_chunks
+        self.update_prefix_chunks(num_prefix_chunks)
+
+        prefix_lens = forward_batch.extend_prefix_lens
+        seq_lens = forward_batch.seq_lens
+
+        bs = len(seq_lens)
+        qo_indptr = self.qo_indptr
+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+        qo_indptr = qo_indptr[: bs + 1]
+
+        for chunk_idx in range(forward_batch.num_prefix_chunks):
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+            kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
+            wrapper = self.chunk_ragged_wrappers[chunk_idx]
+            wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=False,
+            )
+        # ragged prefill
+        self.ragged_wrapper.begin_forward(
+            qo_indptr=qo_indptr,
+            kv_indptr=qo_indptr,
+            num_qo_heads=self.num_local_heads,
+            num_kv_heads=self.num_local_heads,
+            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+            head_dim_vo=self.v_head_dim,
+            q_data_type=self.q_data_type,
+            causal=True,
+        )
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
+        logits_soft_cap = layer.logit_cap
+        if forward_batch.attn_attend_prefix_cache:
+            chunk_idx = forward_batch.prefix_chunk_idx
+            assert chunk_idx >= 0
+            wrapper = self.chunk_ragged_wrappers[chunk_idx]
+            o1, s1 = wrapper.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
+                causal=False,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+        else:
+            o1, s1 = self.ragged_wrapper.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
+        return o1, s1
+
+
 class FlashInferMLAAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""

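One detail worth calling out: update_prefix_chunks grows chunk_ragged_wrappers lazily, so the runner only ever allocates as many FlashInfer wrappers as the largest number of prefix chunks seen so far, and each chunk keeps a dedicated, already-planned wrapper. The underlying pattern, reduced to a generic self-contained sketch (names are illustrative, not from the diff):

class WrapperPool:
    def __init__(self, make_wrapper):
        self._make = make_wrapper
        self._pool = []

    def ensure(self, n: int):
        # Grow monotonically; existing wrappers (and their plans) are reused
        # across batches instead of being re-allocated per forward pass.
        while len(self._pool) < n:
            self._pool.append(self._make())
        return self._pool[:n]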
@@ -74,6 +183,12 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
+        self.enable_chunk_kv = (
+            not skip_prefill
+            and global_server_args_dict["disaggregation_mode"] != "decode"
+            and not global_server_args_dict["disable_chunked_prefix_cache"]
+            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
+        )
         self.page_size = model_runner.page_size

         # Allocate buffers

@@ -117,11 +232,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.q_indptr_decode = q_indptr_decode_buf

-        fmha_backend = "auto"
+        self.fmha_backend = "auto"
         if is_sm100_supported():
-            fmha_backend = "cutlass"
+            self.fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.workspace_buffer, "NHD", backend=fmha_backend
+            self.workspace_buffer, "NHD", backend=self.fmha_backend
         )

         if not self.skip_prefill:

@@ -145,6 +260,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
                 model_runner, self
             )
+            if self.enable_chunk_kv:
+                self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)

         self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
             model_runner, self

@@ -373,6 +490,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+
     def forward_extend(
         self,
         q: torch.Tensor,

@@ -384,6 +505,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if (
+            forward_batch.attn_attend_prefix_cache is not None
+            and forward_batch.mha_return_lse
+        ):  # MHA Chunk
+            assert self.enable_chunk_kv
+            assert q_rope is None
+            assert k_rope is None
+            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
+            return o1, s1
+
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
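When the MHA-chunk path triggers, forward_extend is invoked once per prefix chunk (plus once for the extend tokens) and each call returns an (output, lse) pair that the model layer merges. A self-contained reference of what that per-chunk decomposition computes, for a single sequence (plain torch with illustrative shapes; the real kernels operate on ragged batches):

import torch

def attention_with_lse(q, k, v, scale, causal):
    # q: [T, H, D], k: [S, H, D], v: [S, H, Dv] -> out [T, H, Dv], lse [T, H]
    logits = torch.einsum("thd,shd->hts", q.float(), k.float()) * scale
    if causal:  # assumes T == S: the extend tokens attending to themselves
        mask = torch.tril(torch.ones(logits.shape[-2:], dtype=torch.bool))
        logits = logits.masked_fill(~mask, float("-inf"))
    lse = torch.logsumexp(logits, dim=-1)  # [H, T]
    out = torch.einsum("hts,shd->thd", torch.softmax(logits, dim=-1), v.float())
    return out, lse.transpose(0, 1)

def chunked_prefix_reference(q, k_extend, v_extend, prefix_chunks, scale):
    # Causal attention over the new tokens, then non-causal attention over each
    # cached prefix chunk, folded in through the LSE merge rule sketched earlier.
    out, lse = attention_with_lse(q, k_extend, v_extend, scale, causal=True)
    for k_c, v_c in prefix_chunks:  # list of (k, v) pairs covering the prefix in order
        o_c, lse_c = attention_with_lse(q, k_c, v_c, scale, causal=False)
        new_lse = torch.logaddexp(lse, lse_c)
        out = (
            torch.exp(lse - new_lse).unsqueeze(-1) * out
            + torch.exp(lse_c - new_lse).unsqueeze(-1) * o_c
        )
        lse = new_lse
    return out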
@@ -412,8 +543,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 k = torch.cat([k, k_rope], dim=-1)
             o = self.prefill_wrapper_ragged.forward(
                 qall,
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,

@@ -732,6 +863,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                 head_dim_vo=self.v_head_dim,
                 q_data_type=self.q_data_type,
+                causal=True,
             )
         else:
             # mla paged prefill

@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_symm_mem",
     "quantization",
     "enable_custom_logit_processor",
+    "disaggregation_mode",
 ]

 # Put some global args for easy access

@@ -241,6 +241,9 @@ class ForwardBatch:
     prefix_chunk_num_tokens: Optional[List[int]] = None
     # KV Indices for each chunk
     prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
+    # For MLA chunked prefix cache used in chunked prefill
+    # Tell attention backend whether lse needs to be returned
+    mha_return_lse: Optional[bool] = None

     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None

@@ -518,9 +518,6 @@ class ModelRunner:

         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
-        elif self.page_size > 1:
-            logger.info("Disable chunked prefix cache when page size > 1.")
-            server_args.disable_chunked_prefix_cache = True

         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")

@@ -995,29 +995,31 @@ class DeepseekV2AttentionMLA(nn.Module):

         if attention_backend == "ascend":
             return AttnForwardMethod.MLA
-        elif attention_backend == "flashinfer":
+        elif (
+            attention_backend == "flashinfer"
+            or attention_backend == "fa3"
+            or attention_backend == "flashmla"
+        ):
+            # Use MHA with chunked KV cache when prefilling on long sequences.
+            sum_extend_prefix_lens = (
+                sum(forward_batch.extend_prefix_lens_cpu)
+                if forward_batch.extend_prefix_lens_cpu is not None
+                else 0
+            )
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
+            disable_ragged = (
+                attention_backend == "flashinfer" or attention_backend == "flashmla"
+            ) and self.flashinfer_mla_disable_ragged
             if (
-                not self.flashinfer_mla_disable_ragged
+                not disable_ragged
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "fa3":
-            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
-            if forward_batch.extend_prefix_lens_cpu is not None:
-                sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not self.disable_chunked_prefix_cache
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
                 and (
-                    sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
+                    (
+                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
+                        and not self.disable_chunked_prefix_cache
+                    )
                     or sum_extend_prefix_lens == 0
                 )
             ):
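The net effect of the merged branch: fa3, flashinfer, and flashmla now share one rule, with the ragged-prefill opt-out only applying to the flashinfer/flashmla backends. A condensed restatement of the decision (illustrative helper, not from the diff; is_plain_extend stands for is_extend and not target-verify/draft-extend):

def use_mha(backend, is_plain_extend, sum_prefix, threshold,
            disable_chunked, mla_disable_ragged):
    # Ragged prefill can be disabled only for the MLA-kernel backends.
    disable_ragged = backend in ("flashinfer", "flashmla") and mla_disable_ragged
    # Prefix is long enough to amortize the chunked path, and the path is enabled.
    long_prefix = sum_prefix >= threshold and not disable_chunked
    return not disable_ragged and is_plain_extend and (long_prefix or sum_prefix == 0)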
@@ -1685,7 +1687,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             k[..., self.qk_nope_head_dim :] = k_pe

             output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
-            lse = torch.transpose(lse, 0, 1).contiguous()
             tmp_output = torch.empty_like(accum_output)
             tmp_lse = torch.empty_like(accum_lse)
             merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
@@ -1707,55 +1708,26 @@ class DeepseekV2AttentionMLA(nn.Module):
         # will be helpful for understanding the purpose of this function.

-        # First do normal mha forward to get output for extended part
-        if self.q_lora_rank is not None:
-            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
-            )
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
         )

-        return q, k, v, forward_batch
-
     def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.prepare_chunked_prefix_cache_info(q.device)
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+
+        forward_batch.mha_return_lse = has_extend_prefix
         # Do mha for extended part without prefix
         forward_batch.set_attn_attend_prefix_cache(False)
-        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
-        lse = torch.transpose(lse, 0, 1).contiguous()
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)

         # Do mha attention with chunked prefix cache if there are any sequence with prefix
-        if any(forward_batch.extend_prefix_lens_cpu):
-            # Only initialize the info once
-            if forward_batch.num_prefix_chunks is None:
-                forward_batch.prepare_chunked_prefix_cache_info(q.device)
-
+        if has_extend_prefix:
+            attn_output, lse = attn_output
             forward_batch.set_attn_attend_prefix_cache(True)
             attn_output = self._chunked_prefix_attn_mha(
                 q=q,
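Read together with the backend changes, the reworked core now follows a single gate: only request the LSE, and only run the chunked-prefix merge, when some request in the batch actually has a cached prefix. A toy restatement of that control flow (the callables stand in for the real attention entry points):

def chunked_kv_core_flow(attn_mha, chunked_prefix_attn, q, k, v, has_extend_prefix):
    if has_extend_prefix:
        # attn_mha returns (output, lse) because mha_return_lse was set
        attn_output, lse = attn_mha(q, k, v, return_lse=True)
        return chunked_prefix_attn(q, attn_output, lse)
    # No prefix anywhere in the batch: plain output, no merge needed
    return attn_mha(q, k, v, return_lse=False)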