Support MHA with chunked prefix cache for flashinfer/flashmla backend, support page size > 1 for MHA chunked prefix (#8616)
Co-authored-by: xuyongfei.xyf <xuyongfei.xyf@antgroup.com>
This commit is contained in:
@@ -776,14 +776,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
o = result
|
o = result
|
||||||
else:
|
else:
|
||||||
if (
|
if (
|
||||||
not global_server_args_dict["disable_chunked_prefix_cache"]
|
forward_batch.attn_attend_prefix_cache is not None
|
||||||
and forward_batch.attn_attend_prefix_cache is not None
|
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
):
|
):
|
||||||
# Do multi-head attention with chunked prefix cache
|
# Do multi-head attention with chunked prefix cache
|
||||||
|
|
||||||
if forward_batch.attn_attend_prefix_cache:
|
if forward_batch.attn_attend_prefix_cache:
|
||||||
|
assert not global_server_args_dict["disable_chunked_prefix_cache"]
|
||||||
# MHA for chunked prefix kv cache when running model with MLA
|
# MHA for chunked prefix kv cache when running model with MLA
|
||||||
assert forward_batch.prefix_chunk_idx is not None
|
assert forward_batch.prefix_chunk_idx is not None
|
||||||
assert forward_batch.prefix_chunk_cu_seq_lens is not None
|
assert forward_batch.prefix_chunk_cu_seq_lens is not None
|
||||||
@@ -792,7 +791,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
chunk_idx = forward_batch.prefix_chunk_idx
|
chunk_idx = forward_batch.prefix_chunk_idx
|
||||||
assert chunk_idx >= 0
|
assert chunk_idx >= 0
|
||||||
|
|
||||||
output, lse, *rest = flash_attn_varlen_func(
|
assert forward_batch.mha_return_lse
|
||||||
|
output = flash_attn_varlen_func(
|
||||||
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
|
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
|
||||||
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
|
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
|
||||||
@@ -806,7 +806,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# MHA for extend part of sequence without attending prefix kv cache
|
# MHA for extend part of sequence without attending prefix kv cache
|
||||||
output, lse, *rest = flash_attn_varlen_func(
|
output = flash_attn_varlen_func(
|
||||||
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
|
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
|
||||||
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
|
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
|
||||||
@@ -816,9 +816,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
max_seqlen_k=metadata.max_seq_len_q,
|
max_seqlen_k=metadata.max_seq_len_q,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=True,
|
||||||
return_softmax_lse=True,
|
return_softmax_lse=forward_batch.mha_return_lse,
|
||||||
)
|
)
|
||||||
return output, lse
|
if forward_batch.mha_return_lse:
|
||||||
|
output, lse, *rest = output
|
||||||
|
lse = torch.transpose(lse, 0, 1).contiguous()
|
||||||
|
return output, lse
|
||||||
|
return output
|
||||||
else:
|
else:
|
||||||
# Do absorbed multi-latent attention
|
# Do absorbed multi-latent attention
|
||||||
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||||
|
|||||||
@@ -59,6 +59,115 @@ class PrefillMetadata:
|
|||||||
global_workspace_buffer = None
|
global_workspace_buffer = None
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferMhaChunkKVRunner:
|
||||||
|
def __init__(
|
||||||
|
self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
|
||||||
|
):
|
||||||
|
# Parse Constants
|
||||||
|
self.num_local_heads = (
|
||||||
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
|
)
|
||||||
|
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
|
||||||
|
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
|
||||||
|
self.v_head_dim = model_runner.model_config.v_head_dim
|
||||||
|
self.data_type = model_runner.dtype
|
||||||
|
self.q_data_type = model_runner.dtype
|
||||||
|
|
||||||
|
# Buffers and wrappers
|
||||||
|
self.qo_indptr = attn_backend.qo_indptr
|
||||||
|
self.workspace_buffer = attn_backend.workspace_buffer
|
||||||
|
self.fmha_backend = attn_backend.fmha_backend
|
||||||
|
|
||||||
|
self.chunk_ragged_wrappers = []
|
||||||
|
self.ragged_wrapper = attn_backend.prefill_wrapper_ragged
|
||||||
|
|
||||||
|
def update_prefix_chunks(self, num_prefix_chunks: int):
|
||||||
|
while num_prefix_chunks > len(self.chunk_ragged_wrappers):
|
||||||
|
ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
|
||||||
|
self.workspace_buffer, "NHD", backend=self.fmha_backend
|
||||||
|
)
|
||||||
|
self.chunk_ragged_wrappers.append(ragged_wrapper)
|
||||||
|
|
||||||
|
def update_wrapper(
|
||||||
|
self,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
):
|
||||||
|
assert forward_batch.num_prefix_chunks is not None
|
||||||
|
num_prefix_chunks = forward_batch.num_prefix_chunks
|
||||||
|
self.update_prefix_chunks(num_prefix_chunks)
|
||||||
|
|
||||||
|
prefix_lens = forward_batch.extend_prefix_lens
|
||||||
|
seq_lens = forward_batch.seq_lens
|
||||||
|
|
||||||
|
bs = len(seq_lens)
|
||||||
|
qo_indptr = self.qo_indptr
|
||||||
|
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||||
|
qo_indptr = qo_indptr[: bs + 1]
|
||||||
|
|
||||||
|
for chunk_idx in range(forward_batch.num_prefix_chunks):
|
||||||
|
# MHA for chunked prefix kv cache when running model with MLA
|
||||||
|
assert forward_batch.prefix_chunk_idx is not None
|
||||||
|
assert forward_batch.prefix_chunk_cu_seq_lens is not None
|
||||||
|
assert forward_batch.prefix_chunk_max_seq_lens is not None
|
||||||
|
|
||||||
|
kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
|
||||||
|
wrapper = self.chunk_ragged_wrappers[chunk_idx]
|
||||||
|
wrapper.begin_forward(
|
||||||
|
qo_indptr=qo_indptr,
|
||||||
|
kv_indptr=kv_indptr,
|
||||||
|
num_qo_heads=self.num_local_heads,
|
||||||
|
num_kv_heads=self.num_local_heads,
|
||||||
|
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
|
||||||
|
head_dim_vo=self.v_head_dim,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
# ragged prefill
|
||||||
|
self.ragged_wrapper.begin_forward(
|
||||||
|
qo_indptr=qo_indptr,
|
||||||
|
kv_indptr=qo_indptr,
|
||||||
|
num_qo_heads=self.num_local_heads,
|
||||||
|
num_kv_heads=self.num_local_heads,
|
||||||
|
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
|
||||||
|
head_dim_vo=self.v_head_dim,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
):
|
||||||
|
logits_soft_cap = layer.logit_cap
|
||||||
|
if forward_batch.attn_attend_prefix_cache:
|
||||||
|
chunk_idx = forward_batch.prefix_chunk_idx
|
||||||
|
assert chunk_idx >= 0
|
||||||
|
wrapper = self.chunk_ragged_wrappers[chunk_idx]
|
||||||
|
o1, s1 = wrapper.forward_return_lse(
|
||||||
|
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
|
||||||
|
v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
|
||||||
|
causal=False,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
logits_soft_cap=logits_soft_cap,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
o1, s1 = self.ragged_wrapper.forward_return_lse(
|
||||||
|
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
|
||||||
|
v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
|
||||||
|
causal=True,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
logits_soft_cap=logits_soft_cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o1, s1
|
||||||
|
|
||||||
|
|
||||||
class FlashInferMLAAttnBackend(AttentionBackend):
|
class FlashInferMLAAttnBackend(AttentionBackend):
|
||||||
"""Flashinfer attention kernels."""
|
"""Flashinfer attention kernels."""
|
||||||
|
|
||||||
@@ -74,6 +183,12 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
self.skip_prefill = skip_prefill
|
self.skip_prefill = skip_prefill
|
||||||
|
self.enable_chunk_kv = (
|
||||||
|
not skip_prefill
|
||||||
|
and global_server_args_dict["disaggregation_mode"] != "decode"
|
||||||
|
and not global_server_args_dict["disable_chunked_prefix_cache"]
|
||||||
|
and not global_server_args_dict["flashinfer_mla_disable_ragged"]
|
||||||
|
)
|
||||||
self.page_size = model_runner.page_size
|
self.page_size = model_runner.page_size
|
||||||
|
|
||||||
# Allocate buffers
|
# Allocate buffers
|
||||||
@@ -117,11 +232,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
self.q_indptr_decode = q_indptr_decode_buf
|
self.q_indptr_decode = q_indptr_decode_buf
|
||||||
|
|
||||||
fmha_backend = "auto"
|
self.fmha_backend = "auto"
|
||||||
if is_sm100_supported():
|
if is_sm100_supported():
|
||||||
fmha_backend = "cutlass"
|
self.fmha_backend = "cutlass"
|
||||||
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
||||||
self.workspace_buffer, "NHD", backend=fmha_backend
|
self.workspace_buffer, "NHD", backend=self.fmha_backend
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.skip_prefill:
|
if not self.skip_prefill:
|
||||||
@@ -145,6 +260,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
|
self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
|
||||||
model_runner, self
|
model_runner, self
|
||||||
)
|
)
|
||||||
|
if self.enable_chunk_kv:
|
||||||
|
self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)
|
||||||
|
|
||||||
self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
|
self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
|
||||||
model_runner, self
|
model_runner, self
|
||||||
@@ -373,6 +490,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
"""Init the metadata for a forward pass."""
|
||||||
|
self.mha_chunk_kv_cache.update_wrapper(forward_batch)
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
@@ -384,6 +505,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
k_rope: Optional[torch.Tensor] = None,
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
if (
|
||||||
|
forward_batch.attn_attend_prefix_cache is not None
|
||||||
|
and forward_batch.mha_return_lse
|
||||||
|
): # MHA Chunk
|
||||||
|
assert self.enable_chunk_kv
|
||||||
|
assert q_rope is None
|
||||||
|
assert k_rope is None
|
||||||
|
o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
|
||||||
|
return o1, s1
|
||||||
|
|
||||||
cache_loc = forward_batch.out_cache_loc
|
cache_loc = forward_batch.out_cache_loc
|
||||||
logits_soft_cap = layer.logit_cap
|
logits_soft_cap = layer.logit_cap
|
||||||
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
|
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
|
||||||
@@ -412,8 +543,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|||||||
k = torch.cat([k, k_rope], dim=-1)
|
k = torch.cat([k, k_rope], dim=-1)
|
||||||
o = self.prefill_wrapper_ragged.forward(
|
o = self.prefill_wrapper_ragged.forward(
|
||||||
qall,
|
qall,
|
||||||
k.view(-1, layer.tp_k_head_num, layer.head_dim),
|
k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
|
||||||
v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
|
v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
|
||||||
causal=True,
|
causal=True,
|
||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
logits_soft_cap=logits_soft_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
@@ -732,6 +863,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
|||||||
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
|
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
|
||||||
head_dim_vo=self.v_head_dim,
|
head_dim_vo=self.v_head_dim,
|
||||||
q_data_type=self.q_data_type,
|
q_data_type=self.q_data_type,
|
||||||
|
causal=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# mla paged prefill
|
# mla paged prefill
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"enable_symm_mem",
|
"enable_symm_mem",
|
||||||
"quantization",
|
"quantization",
|
||||||
"enable_custom_logit_processor",
|
"enable_custom_logit_processor",
|
||||||
|
"disaggregation_mode",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
|
|||||||
@@ -241,6 +241,9 @@ class ForwardBatch:
|
|||||||
prefix_chunk_num_tokens: Optional[List[int]] = None
|
prefix_chunk_num_tokens: Optional[List[int]] = None
|
||||||
# KV Indices for each chunk
|
# KV Indices for each chunk
|
||||||
prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
|
prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
|
||||||
|
# For MLA chunked prefix cache used in chunked prefill
|
||||||
|
# Tell attention backend whether lse needs to be returned
|
||||||
|
mha_return_lse: Optional[bool] = None
|
||||||
|
|
||||||
# For multimodal
|
# For multimodal
|
||||||
mm_inputs: Optional[List[MultimodalInputs]] = None
|
mm_inputs: Optional[List[MultimodalInputs]] = None
|
||||||
|
|||||||
@@ -518,9 +518,6 @@ class ModelRunner:
|
|||||||
|
|
||||||
if not self.use_mla_backend:
|
if not self.use_mla_backend:
|
||||||
server_args.disable_chunked_prefix_cache = True
|
server_args.disable_chunked_prefix_cache = True
|
||||||
elif self.page_size > 1:
|
|
||||||
logger.info("Disable chunked prefix cache when page size > 1.")
|
|
||||||
server_args.disable_chunked_prefix_cache = True
|
|
||||||
|
|
||||||
if not server_args.disable_chunked_prefix_cache:
|
if not server_args.disable_chunked_prefix_cache:
|
||||||
logger.info("Chunked prefix cache is turned on.")
|
logger.info("Chunked prefix cache is turned on.")
|
||||||
|
|||||||
@@ -995,29 +995,31 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
|
|
||||||
if attention_backend == "ascend":
|
if attention_backend == "ascend":
|
||||||
return AttnForwardMethod.MLA
|
return AttnForwardMethod.MLA
|
||||||
elif attention_backend == "flashinfer":
|
elif (
|
||||||
|
attention_backend == "flashinfer"
|
||||||
|
or attention_backend == "fa3"
|
||||||
|
or attention_backend == "flashmla"
|
||||||
|
):
|
||||||
|
# Use MHA with chunked KV cache when prefilling on long sequences.
|
||||||
|
sum_extend_prefix_lens = (
|
||||||
|
sum(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
if forward_batch.extend_prefix_lens_cpu is not None
|
||||||
|
else 0
|
||||||
|
)
|
||||||
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
||||||
|
disable_ragged = (
|
||||||
|
attention_backend == "flashinfer" or attention_backend == "flashmla"
|
||||||
|
) and self.flashinfer_mla_disable_ragged
|
||||||
if (
|
if (
|
||||||
not self.flashinfer_mla_disable_ragged
|
not disable_ragged
|
||||||
and forward_batch.forward_mode.is_extend()
|
and forward_batch.forward_mode.is_extend()
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
|
||||||
):
|
|
||||||
return AttnForwardMethod.MHA
|
|
||||||
else:
|
|
||||||
return _dispatch_mla_subtype()
|
|
||||||
elif attention_backend == "fa3":
|
|
||||||
# Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
|
|
||||||
if forward_batch.extend_prefix_lens_cpu is not None:
|
|
||||||
sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
|
|
||||||
if (
|
|
||||||
forward_batch.forward_mode.is_extend()
|
|
||||||
and not self.disable_chunked_prefix_cache
|
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
|
||||||
and (
|
and (
|
||||||
sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
|
(
|
||||||
|
sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
|
||||||
|
and not self.disable_chunked_prefix_cache
|
||||||
|
)
|
||||||
or sum_extend_prefix_lens == 0
|
or sum_extend_prefix_lens == 0
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
@@ -1685,7 +1687,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
k[..., self.qk_nope_head_dim :] = k_pe
|
k[..., self.qk_nope_head_dim :] = k_pe
|
||||||
|
|
||||||
output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
|
output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
|
||||||
lse = torch.transpose(lse, 0, 1).contiguous()
|
|
||||||
tmp_output = torch.empty_like(accum_output)
|
tmp_output = torch.empty_like(accum_output)
|
||||||
tmp_lse = torch.empty_like(accum_lse)
|
tmp_lse = torch.empty_like(accum_lse)
|
||||||
merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
|
merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
|
||||||
@@ -1707,55 +1708,26 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
# will be helpful for understanding the purpose of this function.
|
# will be helpful for understanding the purpose of this function.
|
||||||
|
|
||||||
# First do normal mha forward to get output for extended part
|
# First do normal mha forward to get output for extended part
|
||||||
if self.q_lora_rank is not None:
|
return self.forward_normal_prepare(
|
||||||
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
positions, hidden_states, forward_batch, zero_allocator
|
||||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
|
||||||
)
|
|
||||||
q = self.q_a_layernorm(q)
|
|
||||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
|
||||||
else:
|
|
||||||
q = self.q_proj(hidden_states)[0].view(
|
|
||||||
-1, self.num_local_heads, self.qk_head_dim
|
|
||||||
)
|
|
||||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
|
||||||
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
|
||||||
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
|
||||||
latent_cache = latent_cache.unsqueeze(1)
|
|
||||||
kv_a = self.kv_a_layernorm(kv_a)
|
|
||||||
kv = self.kv_b_proj(kv_a)[0]
|
|
||||||
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
|
|
||||||
k_nope = kv[..., : self.qk_nope_head_dim]
|
|
||||||
v = kv[..., self.qk_nope_head_dim :]
|
|
||||||
k_pe = latent_cache[:, :, self.kv_lora_rank :]
|
|
||||||
|
|
||||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
|
||||||
q[..., self.qk_nope_head_dim :] = q_pe
|
|
||||||
k = torch.empty_like(q)
|
|
||||||
k[..., : self.qk_nope_head_dim] = k_nope
|
|
||||||
k[..., self.qk_nope_head_dim :] = k_pe
|
|
||||||
|
|
||||||
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
|
|
||||||
latent_cache[:, :, self.kv_lora_rank :] = k_pe
|
|
||||||
|
|
||||||
# Save latent cache
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
|
||||||
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return q, k, v, forward_batch
|
|
||||||
|
|
||||||
def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
|
def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
|
||||||
|
has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
# Only initialize the info once
|
||||||
|
if has_extend_prefix and forward_batch.num_prefix_chunks is None:
|
||||||
|
forward_batch.prepare_chunked_prefix_cache_info(q.device)
|
||||||
|
if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
|
||||||
|
forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
|
||||||
|
|
||||||
|
forward_batch.mha_return_lse = has_extend_prefix
|
||||||
# Do mha for extended part without prefix
|
# Do mha for extended part without prefix
|
||||||
forward_batch.set_attn_attend_prefix_cache(False)
|
forward_batch.set_attn_attend_prefix_cache(False)
|
||||||
attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
|
attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
|
||||||
lse = torch.transpose(lse, 0, 1).contiguous()
|
|
||||||
|
|
||||||
# Do mha attention with chunked prefix cache if there are any sequence with prefix
|
# Do mha attention with chunked prefix cache if there are any sequence with prefix
|
||||||
if any(forward_batch.extend_prefix_lens_cpu):
|
if has_extend_prefix:
|
||||||
# Only initialize the info once
|
attn_output, lse = attn_output
|
||||||
if forward_batch.num_prefix_chunks is None:
|
|
||||||
forward_batch.prepare_chunked_prefix_cache_info(q.device)
|
|
||||||
|
|
||||||
forward_batch.set_attn_attend_prefix_cache(True)
|
forward_batch.set_attn_attend_prefix_cache(True)
|
||||||
attn_output = self._chunked_prefix_attn_mha(
|
attn_output = self._chunked_prefix_attn_mha(
|
||||||
q=q,
|
q=q,
|
||||||
|
|||||||
Reference in New Issue
Block a user