Support MHA with chunked prefix cache for DeepSeek chunked prefill (#5113)

Baizhou Zhang
2025-04-15 22:01:22 -07:00
committed by GitHub
parent dd83e7e9c3
commit a42736bbb8
10 changed files with 734 additions and 46 deletions
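
In brief: when an MLA model such as DeepSeek runs chunked prefill with a cached prefix, the FlashAttention backend can now answer the query tokens with ordinary multi-head attention, one prefix chunk at a time, instead of doing absorbed MLA against the paged KV cache. Each attention call also returns its softmax log-sum-exp (LSE) so the partial results can be combined afterwards. A schematic sketch of that flow (not code from this commit; attend and merge_state are hypothetical stand-ins for the kernel call and the LSE-based merge, which live outside the file shown in the diff below):

def chunked_prefix_extend(q, k_extend, v_extend, prefix_chunks, attend, merge_state):
    # Extend (new) tokens attend causally to themselves.
    out, lse = attend(q, k_extend, v_extend, causal=True)
    # Each cached prefix chunk is attended non-causally and folded into the result.
    for k_chunk, v_chunk in prefix_chunks:
        chunk_out, chunk_lse = attend(q, k_chunk, v_chunk, causal=False)
        out, lse = merge_state(out, lse, chunk_out, chunk_lse)
    return out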


@@ -16,7 +16,7 @@ if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@dataclass
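
The newly imported flash_attn_varlen_func is what the chunked-prefix path in the second hunk below relies on: unlike flash_attn_with_kvcache, which gathers K/V from the paged pool through a page_table, it takes Q/K/V packed along the token dimension and described by cumulative sequence lengths. A standalone toy example of that calling convention (shapes and values are made up; only the argument names mirror the diff):

import torch
from sgl_kernel.flash_attn import flash_attn_varlen_func

num_heads, head_dim = 8, 128
seq_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
# One boundary per packed sequence: tensor([0, 3, 8], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
total_tokens = int(seq_lens.sum())
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")

out, lse, *rest = flash_attn_varlen_func(
    q=q,
    k=k,
    v=v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seq_lens.max()),
    max_seqlen_k=int(seq_lens.max()),
    softmax_scale=head_dim**-0.5,
    causal=True,
    return_softmax_lse=True,
)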
@@ -593,41 +593,87 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
if (
not global_server_args_dict["disable_chunked_prefix_cache"]
and forward_batch.attn_attend_prefix_cache is not None
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
# Do multi-head attention with chunked prefix cache
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
if forward_batch.attn_attend_prefix_cache:
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
output, lse, *rest = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
return_softmax_lse=True,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
output, lse, *rest = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=True,
)
return output, lse
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_decode(
self,
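
Note on the (output, lse) returns above: the two flash_attn_varlen_func calls hand back the softmax log-sum-exp precisely so that the extend-only result and the per-prefix-chunk results can be combined later. That merge happens outside this file; a minimal sketch of the math it relies on (the helper name and tensor layouts are assumptions, not this commit's code):

import torch

def merge_attention_states(o_a, lse_a, o_b, lse_b):
    # o_*: [tokens, heads, head_dim]; lse_*: [heads, tokens] (assumed layouts).
    lse_a = lse_a.transpose(0, 1).unsqueeze(-1).float()  # -> [tokens, heads, 1]
    lse_b = lse_b.transpose(0, 1).unsqueeze(-1).float()
    max_lse = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - max_lse)
    w_b = torch.exp(lse_b - max_lse)
    # Weight each partial output by its (stabilized) softmax normalizer and renormalize.
    merged = (o_a.float() * w_a + o_b.float() * w_b) / (w_a + w_b)
    merged_lse = (max_lse + torch.log(w_a + w_b)).squeeze(-1).transpose(0, 1)
    return merged.to(o_a.dtype), merged_lse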