Support MHA with chunked prefix cache for DeepSeek chunked prefill (#5113)
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
 
 @dataclass
@@ -593,41 +593,87 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
             )
             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            k_rope = kv_cache[:, :, layer.v_head_dim :]
-            c_kv = kv_cache[:, :, : layer.v_head_dim]
-            k_rope_cache = k_rope.view(
-                -1,
-                self.page_size,
-                layer.tp_k_head_num,
-                layer.head_dim - layer.v_head_dim,
-            )
-            c_kv_cache = c_kv.view(
-                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
-            )
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
-            o = flash_attn_with_kvcache(
-                q=q_rope,
-                k_cache=k_rope_cache,
-                v_cache=c_kv_cache,
-                qv=q_nope,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
-                max_seqlen_q=max_seqlen_q,
-                softmax_scale=layer.scaling,
-                causal=True,
-                softcap=layer.logit_cap,
-                k_descale=k_descale,
-                v_descale=v_descale,
-            )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        if (
+            not global_server_args_dict["disable_chunked_prefix_cache"]
+            and forward_batch.attn_attend_prefix_cache is not None
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Do multi-head attention with chunked prefix cache
+
+            if forward_batch.attn_attend_prefix_cache:
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+                chunk_idx = forward_batch.prefix_chunk_idx
+                assert chunk_idx >= 0
+
+                output, lse, *rest = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    return_softmax_lse=True,
+                )
+            else:
+                # MHA for extend part of sequence without attending prefix kv cache
+                output, lse, *rest = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=metadata.cu_seqlens_q,
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=metadata.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    return_softmax_lse=True,
+                )
+            return output, lse
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=cache_seqlens,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                max_seqlen_q=max_seqlen_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=k_descale,
+                v_descale=v_descale,
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def forward_decode(
         self,
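
Both flash_attn_varlen_func calls above pass return_softmax_lse=True, and the prefix-chunk call uses causal=False because every extend-token query may attend to all keys of an already-complete prefix chunk; only the extend-vs-extend call stays causal. Returning (output, lse) lets the caller combine the partial results from each chunk into the exact full-attention output via the standard log-sum-exp merge. A minimal sketch of that merge, as a hypothetical helper (not code added by this commit):

    import torch

    def merge_chunk_outputs(o_a, lse_a, o_b, lse_b):
        # o_*:   partial attention outputs over disjoint key chunks, [tokens, heads, dim]
        # lse_*: per-query log-sum-exp of attention logits, [heads, tokens]
        lse = torch.logaddexp(lse_a, lse_b)                      # combined normalizer
        w_a = (lse_a - lse).exp().transpose(0, 1).unsqueeze(-1)  # weight of chunk a
        w_b = (lse_b - lse).exp().transpose(0, 1).unsqueeze(-1)  # weight of chunk b
        return w_a * o_a + w_b * o_b, lse                        # exact merged attention

Applied left to right over all prefix chunks plus the extend part, this reproduces attention over the full sequence without materializing the whole prefix at once.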
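Note on the absorbed multi-latent attention branch: the KV cache stores, per token, the compressed latent c_kv (width layer.v_head_dim) concatenated with the rotary key slice k_rope (width layer.head_dim - layer.v_head_dim), and the query is split the same way into q_nope and q_rope. A shape sketch, assuming DeepSeek-V3-style sizes (kv_lora_rank = 512, qk_rope_head_dim = 64; these numbers are illustrative assumptions, not taken from this diff):

    import torch

    kv_lora_rank, qk_rope_head_dim = 512, 64    # assumed model config
    head_dim = kv_lora_rank + qk_rope_head_dim  # plays the role of layer.head_dim (576)
    v_head_dim = kv_lora_rank                   # plays the role of layer.v_head_dim (512)

    q_all = torch.randn(8, 16, head_dim)        # [tokens, q_heads, 576]
    q_nope = q_all[:, :, :v_head_dim]           # latent part, scored against c_kv
    q_rope = q_all[:, :, v_head_dim:]           # positional part, scored against k_rope
    assert q_nope.shape == (8, 16, 512) and q_rope.shape == (8, 16, 64)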