diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 7ae15f1f2..e4dc1b878 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -195,3 +195,4 @@ Please consult the documentation below to learn more about the parameters you ma
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the attention backend with FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention_backend flashinfer` instead to enable FlashfIner MLA.**
* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
+* `disable_chunked_prefix_cache`: Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is the attention backend.
diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md
index 5375008b7..8fc71e03c 100644
--- a/docs/references/deepseek.md
+++ b/docs/references/deepseek.md
@@ -92,13 +92,15 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **CUDA Graph & Torch.compile**: Both MLA and Mixture of Experts (MoE) are compatible with CUDA Graph and Torch.compile, which reduces latency and accelerates decoding speed for small batch sizes.
+- **Chunked Prefix Cache**: Chunked prefix cache optimization can increase throughput by cutting prefix cache into chunks, processing them with multi-head attention and merging their states. Its improvement can be significant when doing chunked prefill on long sequences. Currently this optimization is only available for the FlashAttention3 backend.
+
Overall, with these optimizations, we have achieved up to **7x** acceleration in output throughput compared to the previous version.
-**Usage**: MLA optimization is enabled by default, to disable, use `--disable-mla`.
+**Usage**: MLA optimization is enabled by default. To disable MLA usage, use `--disable-mla`. To disable the chunked prefix cache feature for MLA, use `--disable-chunked-prefix-cache`.
**Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details.
diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 040393c17..26e54546d 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@dataclass
@@ -593,41 +593,87 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
)
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
- # Do absorbed multi-latent attention
- kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
- k_rope = kv_cache[:, :, layer.v_head_dim :]
- c_kv = kv_cache[:, :, : layer.v_head_dim]
- k_rope_cache = k_rope.view(
- -1,
- self.page_size,
- layer.tp_k_head_num,
- layer.head_dim - layer.v_head_dim,
- )
- c_kv_cache = c_kv.view(
- -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
- )
- q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
- q_nope = q_all[:, :, : layer.v_head_dim]
- q_rope = q_all[:, :, layer.v_head_dim :]
- o = flash_attn_with_kvcache(
- q=q_rope,
- k_cache=k_rope_cache,
- v_cache=c_kv_cache,
- qv=q_nope,
- page_table=page_table,
- cache_seqlens=cache_seqlens,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
- max_seqlen_q=max_seqlen_q,
- softmax_scale=layer.scaling,
- causal=True,
- softcap=layer.logit_cap,
- k_descale=k_descale,
- v_descale=v_descale,
- )
+ if (
+ not global_server_args_dict["disable_chunked_prefix_cache"]
+ and forward_batch.attn_attend_prefix_cache is not None
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
+ ):
+ # Do multi-head attention with chunked prefix cache
- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+ if forward_batch.attn_attend_prefix_cache:
+ # MHA for chunked prefix kv cache when running model with MLA
+ assert forward_batch.prefix_chunk_idx is not None
+ assert forward_batch.prefix_chunk_cu_seq_lens is not None
+ assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+ chunk_idx = forward_batch.prefix_chunk_idx
+ assert chunk_idx >= 0
+
+ output, lse, *rest = flash_attn_varlen_func(
+ q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+ cu_seqlens_q=metadata.cu_seqlens_q,
+ cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+ max_seqlen_q=metadata.max_seq_len_q,
+ max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+ softmax_scale=layer.scaling,
+ causal=False,
+ return_softmax_lse=True,
+ )
+ else:
+ # MHA for extend part of sequence without attending prefix kv cache
+ output, lse, *rest = flash_attn_varlen_func(
+ q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+ cu_seqlens_q=metadata.cu_seqlens_q,
+ cu_seqlens_k=metadata.cu_seqlens_q,
+ max_seqlen_q=metadata.max_seq_len_q,
+ max_seqlen_k=metadata.max_seq_len_q,
+ softmax_scale=layer.scaling,
+ causal=True,
+ return_softmax_lse=True,
+ )
+ return output, lse
+ else:
+ # Do absorbed multi-latent attention
+ kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+ k_rope = kv_cache[:, :, layer.v_head_dim :]
+ c_kv = kv_cache[:, :, : layer.v_head_dim]
+ k_rope_cache = k_rope.view(
+ -1,
+ self.page_size,
+ layer.tp_k_head_num,
+ layer.head_dim - layer.v_head_dim,
+ )
+ c_kv_cache = c_kv.view(
+ -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+ )
+ q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+ q_nope = q_all[:, :, : layer.v_head_dim]
+ q_rope = q_all[:, :, layer.v_head_dim :]
+ o = flash_attn_with_kvcache(
+ q=q_rope,
+ k_cache=k_rope_cache,
+ v_cache=c_kv_cache,
+ qv=q_nope,
+ page_table=page_table,
+ cache_seqlens=cache_seqlens,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+ max_seqlen_q=max_seqlen_q,
+ softmax_scale=layer.scaling,
+ causal=True,
+ softcap=layer.logit_cap,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ )
+
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_decode(
self,
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index fa9f40112..6bfd69307 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -83,6 +83,7 @@ global_server_args_dict = {
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
+ "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
}
logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index ba861b850..91cec3210 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -181,6 +181,28 @@ class ForwardBatch:
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
+ # For MLA chunked prefix cache used in chunked prefill
+ # Tell attention backend whether the kv cache needs to be attended in current pass
+ attn_attend_prefix_cache: Optional[bool] = None
+ # Number of prefix cache chunks
+ num_prefix_chunks: Optional[int] = None
+ # Index of current chunk, used by attention backend
+ prefix_chunk_idx: Optional[int] = None
+ # Maximum number of tokens in each chunk per sequence. Computed from maximum chunk capacity
+ prefix_chunk_len: Optional[int] = None
+ # Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size)
+ prefix_chunk_starts: Optional[torch.Tensor] = None
+ # Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size)
+ prefix_chunk_seq_lens: Optional[torch.Tensor] = None
+ # Accumulated lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size + 1)
+ prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None
+ # Max lengths of prefix cache for each chunk, (num_prefix_chunks,)
+ prefix_chunk_max_seq_lens: Optional[List[int]] = None
+ # Number of tokens in each prefix cache chunk, (num_prefix_chunks,)
+ prefix_chunk_num_tokens: Optional[List[int]] = None
+ # KV Indices for each chunk
+ prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
+
# For multimodal
mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -484,6 +506,128 @@ class ForwardBatch:
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
+ def get_max_chunk_capacity(self):
+ # Maximum number of tokens in each chunk
+ # TODO: Should be changed to a better value, maybe passed through server args
+ return 128 * 1024
+
+ def set_prefix_chunk_idx(self, idx: int):
+ self.prefix_chunk_idx = idx
+
+ def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool):
+ self.attn_attend_prefix_cache = attn_attend_prefix_cache
+
+ def prepare_chunked_kv_indices(self, device: torch.device):
+ self.prefix_chunk_kv_indices = []
+ for idx in range(self.num_prefix_chunks):
+ chunk_starts = self.prefix_chunk_starts[idx]
+ chunk_seq_lens = self.prefix_chunk_seq_lens[idx]
+ chunk_cu_seq_lens = self.prefix_chunk_cu_seq_lens[idx]
+ num_chunk_tokens = self.prefix_chunk_num_tokens[idx]
+
+ chunk_kv_indices = torch.empty(
+ num_chunk_tokens, dtype=torch.int32, device=device
+ )
+
+ create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
+ self.req_to_token_pool.req_to_token,
+ self.req_pool_indices,
+ chunk_starts,
+ chunk_seq_lens,
+ chunk_cu_seq_lens,
+ chunk_kv_indices,
+ self.req_to_token_pool.req_to_token.shape[1],
+ )
+ self.prefix_chunk_kv_indices.append(chunk_kv_indices)
+
+ # Here we suppose the length of each chunk is equal
+ # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
+ # num_prefix_chunks = cdiv(1024, 256) = 4
+ # prefix_chunk_starts = [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]]
+ # prefix_chunk_ends = [[256, 256, 256, 256], [256, 512, 512, 512], [256, 512, 768, 768], [256, 512, 768, 1024]]
+ # prefix_chunk_seq_lens = [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]
+ # TODO: Implement a better way to allocate chunk lengths that uses memory spaces more efficiently.
+ def get_prefix_chunk_seq_lens(
+ self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int
+ ):
+ device = prefix_lens.device
+ prefix_chunk_starts = (
+ torch.arange(num_prefix_chunks, device=device, dtype=torch.int32)
+ .unsqueeze(1)
+ .expand(-1, self.batch_size)
+ * prefix_chunk_len
+ )
+ prefix_chunk_ends = torch.min(
+ prefix_lens.unsqueeze(0),
+ prefix_chunk_starts + prefix_chunk_len,
+ ).to(torch.int32)
+
+ prefix_chunk_seq_lens = (
+ (prefix_chunk_ends - prefix_chunk_starts).clamp(min=0).to(torch.int32)
+ )
+
+ return prefix_chunk_starts, prefix_chunk_seq_lens
+
+ # Called before each attention module if using chunked kv cache for prefill
+ # Some of the codes are adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
+ def prepare_chunked_prefix_cache_info(self, device: torch.device):
+
+ from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+
+ assert isinstance(
+ self.token_to_kv_pool, MLATokenToKVPool
+ ), "Currently chunked prefix cache can only be used by Deepseek models"
+
+ if self.prefix_chunk_len is not None:
+ # Chunked kv cache info already prepared by prior modules
+ return
+
+ self.prefix_chunk_idx = -1
+
+ # chunk_capacity is the maximum number of tokens in each chunk
+ chunk_capacity = self.get_max_chunk_capacity()
+ self.prefix_chunk_len = chunk_capacity // self.batch_size
+
+ self.num_prefix_chunks = (
+ max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1
+ ) // self.prefix_chunk_len
+
+ # Here we compute chunk lens twice to avoid stream sync, once on gpu and once on cpu.
+ prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = (
+ self.get_prefix_chunk_seq_lens(
+ self.extend_prefix_lens,
+ self.num_prefix_chunks,
+ self.prefix_chunk_len,
+ )
+ )
+ _, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens(
+ torch.tensor(self.extend_prefix_lens_cpu),
+ self.num_prefix_chunks,
+ self.prefix_chunk_len,
+ )
+ self.prefix_chunk_starts = prefix_chunk_starts_cuda
+ self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda
+
+ # Metadata for attention backend
+ self.prefix_chunk_cu_seq_lens = torch.zeros(
+ self.num_prefix_chunks,
+ self.batch_size + 1,
+ device=device,
+ dtype=torch.int32,
+ )
+ self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum(
+ dim=1
+ ).to(torch.int32)
+ self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max(
+ dim=1
+ ).values.tolist()
+
+ self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist()
+ assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity()
+
+ # Precompute the kv indices for each chunk
+ self.prepare_chunked_kv_indices(device)
+
def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
@@ -561,3 +705,40 @@ def compute_position_torch(
@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
+@triton.jit
+def create_chunked_prefix_cache_kv_indices(
+ req_to_token_ptr, # (max_batch, max_context_len,)
+ req_pool_indices_ptr, # (batch_size,)
+ chunk_start_idx_ptr, # (batch_size,)
+ chunk_seq_lens_ptr, # (batch_size,)
+ chunk_cu_seq_lens_ptr, # (batch_size + 1,)
+ chunk_kv_indices_ptr, # (num_chunk_tokens,)
+ req_to_token_ptr_stride: tl.constexpr,
+):
+ BLOCK_SIZE: tl.constexpr = 512
+ pid = tl.program_id(axis=0)
+
+ # find the req pool idx, this is for batch to token
+ req_pool_index = tl.load(req_pool_indices_ptr + pid)
+ chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)
+
+ # get the token positions of current chunk
+ chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
+ chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)
+
+ num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
+ for i in range(num_loop):
+ offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+ mask = offset < chunk_seq_len
+ data = tl.load(
+ req_to_token_ptr
+ + req_pool_index * req_to_token_ptr_stride
+ + chunk_start_pos
+ + offset,
+ mask=mask,
+ )
+ tl.store(
+ chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
+ )
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index fd84339d7..ca9ffbab2 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -167,6 +167,7 @@ class ModelRunner:
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
+ "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
"use_mla_backend": self.use_mla_backend,
}
)
@@ -318,6 +319,16 @@ class ModelRunner:
if server_args.enable_deepep_moe:
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
+ if not self.use_mla_backend:
+ logger.info("Disable chunked prefix cache for non-MLA backend.")
+ server_args.disable_chunked_prefix_cache = True
+ elif self.page_size > 1:
+ logger.info("Disable chunked prefix cache when page size > 1.")
+ server_args.disable_chunked_prefix_cache = True
+
+ if not server_args.disable_chunked_prefix_cache:
+ logger.info("Chunked prefix cache is turned on.")
+
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index df2d6769f..533c3169c 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -18,6 +18,7 @@
import logging
import os
+from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
@@ -78,7 +79,7 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
if _is_cuda:
- from sgl_kernel import awq_dequantize, bmm_fp8
+ from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
else:
@@ -94,6 +95,19 @@ expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
+class AttnForwardMethod(IntEnum):
+
+ # Use multi-head attention
+ MHA = auto()
+
+ # Use absorbed multi-latent attention
+ MLA = auto()
+
+ # Use multi-head attention, but with KV cache chunked.
+ # This method can avoid OOM when prefix lengths are long.
+ MHA_CHUNKED_KV = auto()
+
+
class DeepseekV2MLP(nn.Module):
def __init__(
self,
@@ -694,30 +708,54 @@ class DeepseekV2AttentionMLA(nn.Module):
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
+ self.disable_chunked_prefix_cache = global_server_args_dict[
+ "disable_chunked_prefix_cache"
+ ]
self.attention_backend = global_server_args_dict["attention_backend"]
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
- def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+ # TODO: Design a finer way to determine the threshold
+ self.chunked_prefix_cache_threshold = 8192
+
+ def dispatch_attn_forward_method(
+ self, forward_batch: ForwardBatch
+ ) -> AttnForwardMethod:
if self.attention_backend == "flashinfer":
# Flashinfer MLA: Do not absorb when enabling ragged prefill
- return (
+ if (
not self.flashinfer_mla_disable_ragged
and forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and sum(forward_batch.extend_prefix_lens_cpu) == 0
- )
+ ):
+ return AttnForwardMethod.MHA
+ else:
+ return AttnForwardMethod.MLA
elif self.attention_backend == "fa3":
- # Flash Attention: Keep absorbing for all extend/decode
- return False
+ # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
+ if (
+ forward_batch.forward_mode.is_extend()
+ and not self.disable_chunked_prefix_cache
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
+ and sum(forward_batch.extend_prefix_lens_cpu)
+ >= self.chunked_prefix_cache_threshold
+ ):
+ return AttnForwardMethod.MHA_CHUNKED_KV
+ else:
+ return AttnForwardMethod.MLA
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
- return (
+ if (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and sum(forward_batch.extend_prefix_lens_cpu) == 0
- )
+ ):
+ return AttnForwardMethod.MHA
+ else:
+ return AttnForwardMethod.MLA
def forward(
self,
@@ -731,8 +769,14 @@ class DeepseekV2AttentionMLA(nn.Module):
), "short-circuiting allreduce will lead to hangs"
return hidden_states
- if self.no_absorb(forward_batch):
+ attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
+
+ if attn_forward_method == AttnForwardMethod.MHA:
return self.forward_normal(positions, hidden_states, forward_batch)
+ elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+ return self.forward_normal_chunked_kv(
+ positions, hidden_states, forward_batch
+ )
else:
if _is_hip:
if (
@@ -1007,6 +1051,127 @@ class DeepseekV2AttentionMLA(nn.Module):
return output
+ def _chunked_prefix_attn_mha(
+ self,
+ q: torch.Tensor,
+ accum_output: torch.Tensor,
+ accum_lse: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> torch.Tensor:
+
+ assert forward_batch.num_prefix_chunks is not None
+ for i in range(forward_batch.num_prefix_chunks):
+ forward_batch.set_prefix_chunk_idx(i)
+
+ # Fetch latent cache from memory pool with precomputed chunked kv indices
+ latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+ self.attn_mha.layer_id
+ )
+ latent_cache = latent_cache_buf[
+ forward_batch.prefix_chunk_kv_indices[i]
+ ].contiguous()
+
+ kv_a_normed, k_pe = latent_cache.split(
+ [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+ )
+ kv_a_normed = kv_a_normed.squeeze(1).contiguous()
+ kv = self.kv_b_proj(kv_a_normed)[0]
+ kv = kv.view(
+ -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
+ )
+ v = kv[..., self.qk_nope_head_dim :]
+ k_nope = kv[..., : self.qk_nope_head_dim]
+
+ k = torch.empty(
+ (
+ k_nope.shape[0],
+ self.num_local_heads,
+ self.qk_nope_head_dim + self.qk_rope_head_dim,
+ ),
+ dtype=v.dtype,
+ device=v.device,
+ )
+ k[..., : self.qk_nope_head_dim] = k_nope
+ k[..., self.qk_nope_head_dim :] = k_pe
+
+ output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+ lse = torch.transpose(lse, 0, 1).contiguous()
+ tmp_output = torch.empty_like(accum_output)
+ tmp_lse = torch.empty_like(accum_lse)
+ merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
+ accum_output, accum_lse = tmp_output, tmp_lse
+
+ return accum_output
+
+ def forward_normal_chunked_kv(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> torch.Tensor:
+ # In normal mha, the k and v tensors will become overly large when the prefix length is long.
+ # To avoid this, we split the kv cache into chunks and process them one after another.
+ # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
+ # The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
+ # will be helpful for understanding the purpose of this function.
+
+ # First do normal mha forward to get output for extended part
+ if self.q_lora_rank is not None:
+ q = self.q_a_proj(hidden_states)[0]
+ q = self.q_a_layernorm(q)
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+ else:
+ q = self.q_proj(hidden_states)[0].view(
+ -1, self.num_local_heads, self.qk_head_dim
+ )
+ _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+ latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+ kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+ latent_cache = latent_cache.unsqueeze(1)
+ kv_a = self.kv_a_layernorm(kv_a.contiguous())
+ kv = self.kv_b_proj(kv_a)[0]
+ kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+ k_nope = kv[..., : self.qk_nope_head_dim]
+ v = kv[..., self.qk_nope_head_dim :]
+ k_pe = latent_cache[:, :, self.kv_lora_rank :]
+
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+ q[..., self.qk_nope_head_dim :] = q_pe
+ k = torch.empty_like(q)
+ k[..., : self.qk_nope_head_dim] = k_nope
+ k[..., self.qk_nope_head_dim :] = k_pe
+
+ latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+ latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+ # Save latent cache
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+ )
+
+ # Do mha for extended part without prefix
+ forward_batch.set_attn_attend_prefix_cache(False)
+ attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+ lse = torch.transpose(lse, 0, 1).contiguous()
+
+ # Do mha attention with chunked prefix cache if there are any sequence with prefix
+ if any(forward_batch.extend_prefix_lens_cpu):
+ # Only initialize the info once
+ if forward_batch.num_prefix_chunks is None:
+ forward_batch.prepare_chunked_prefix_cache_info(q.device)
+
+ forward_batch.set_attn_attend_prefix_cache(True)
+ attn_output = self._chunked_prefix_attn_mha(
+ q=q,
+ accum_output=attn_output,
+ accum_lse=lse,
+ forward_batch=forward_batch,
+ )
+
+ attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+ output, _ = self.o_proj(attn_output)
+ return output
+
class DeepseekV2DecoderLayer(nn.Module):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6d78b654a..f9878d2c5 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -186,6 +186,7 @@ class ServerArgs:
warmups: Optional[str] = None
n_share_experts_fusion: int = 0
disable_shared_experts_fusion: bool = False
+ disable_chunked_prefix_cache: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
@@ -1130,6 +1131,11 @@ class ServerArgs:
action="store_true",
help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
)
+ parser.add_argument(
+ "--disable-chunked-prefix-cache",
+ action="store_true",
+ help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
+ )
# Server warmups
parser.add_argument(
diff --git a/python/sglang/test/attention/test_prefix_chunk_info.py b/python/sglang/test/attention/test_prefix_chunk_info.py
new file mode 100644
index 000000000..2b85b695b
--- /dev/null
+++ b/python/sglang/test/attention/test_prefix_chunk_info.py
@@ -0,0 +1,224 @@
+import unittest
+
+import torch
+
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+
+TEST_CASES = [
+ # Sequence with same prefix lens
+ {
+ "batch_size": 3,
+ "prefix_lens": [64, 64, 64],
+ "max_chunk_capacity": 48,
+ "prefix_chunk_len": 16,
+ "num_prefix_chunks": 4,
+ "prefix_chunk_starts": torch.tensor(
+ [
+ [0, 0, 0],
+ [16, 16, 16],
+ [32, 32, 32],
+ [48, 48, 48],
+ ],
+ dtype=torch.int32,
+ ),
+ "prefix_chunk_seq_lens": torch.tensor(
+ [
+ [16, 16, 16],
+ [16, 16, 16],
+ [16, 16, 16],
+ [16, 16, 16],
+ ],
+ dtype=torch.int32,
+ ),
+ },
+ # Sequence with different prefix lens
+ {
+ "batch_size": 4,
+ "prefix_lens": [16, 32, 48, 64],
+ "max_chunk_capacity": 64,
+ "prefix_chunk_len": 16,
+ "num_prefix_chunks": 4,
+ "prefix_chunk_starts": torch.tensor(
+ [
+ [0, 0, 0, 0],
+ [16, 16, 16, 16],
+ [32, 32, 32, 32],
+ [48, 48, 48, 48],
+ ],
+ dtype=torch.int32,
+ ),
+ "prefix_chunk_seq_lens": torch.tensor(
+ [
+ [16, 16, 16, 16],
+ [0, 16, 16, 16],
+ [0, 0, 16, 16],
+ [0, 0, 0, 16],
+ ],
+ dtype=torch.int32,
+ ),
+ },
+ # Sequence with irregular shapes
+ {
+ "batch_size": 2,
+ "prefix_lens": [1, 64],
+ "max_chunk_capacity": 31,
+ "prefix_chunk_len": 15,
+ "num_prefix_chunks": 5,
+ "prefix_chunk_starts": torch.tensor(
+ [
+ [0, 0],
+ [15, 15],
+ [30, 30],
+ [45, 45],
+ [60, 60],
+ ],
+ dtype=torch.int32,
+ ),
+ "prefix_chunk_seq_lens": torch.tensor(
+ [
+ [1, 15],
+ [0, 15],
+ [0, 15],
+ [0, 15],
+ [0, 4],
+ ],
+ dtype=torch.int32,
+ ),
+ },
+]
+
+
+class MockForwardBatch(ForwardBatch):
+ def __init__(self, max_chunk_capacity: int, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.max_chunk_capacity = max_chunk_capacity
+
+ def get_max_chunk_capacity(self):
+ return self.max_chunk_capacity
+
+
+class MockReqToTokenPool:
+ def __init__(self, batch_size, seq_len, device):
+ self.req_to_token = (
+ torch.arange(batch_size * seq_len, device=device)
+ .reshape(batch_size, seq_len)
+ .to(torch.int32)
+ )
+
+
+# Test correctness of triton kernel for computing kv indices
+def check_kv_indices(forward_batch):
+ for i in range(forward_batch.num_prefix_chunks):
+ computed_kv_indices = forward_batch.prefix_chunk_kv_indices[i]
+ req_to_token = forward_batch.req_to_token_pool.req_to_token[
+ : forward_batch.batch_size, :
+ ]
+ ref_kv_indices = torch.empty(
+ forward_batch.prefix_chunk_num_tokens[i],
+ dtype=torch.int32,
+ device=computed_kv_indices.device,
+ )
+ running_ptr = 0
+ for j in range(forward_batch.batch_size):
+ seq_start = forward_batch.prefix_chunk_starts[i, j].item()
+ seq_len = forward_batch.prefix_chunk_seq_lens[i, j].item()
+ ref_kv_indices[running_ptr : running_ptr + seq_len].copy_(
+ req_to_token[j, seq_start : seq_start + seq_len]
+ )
+ running_ptr += seq_len
+ assert torch.allclose(computed_kv_indices, ref_kv_indices)
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+class TestPrefixChunkInfo(CustomTestCase):
+ def setUp(self):
+ # Common test parameters
+ self.num_local_heads = 128
+ self.kv_lora_rank = 512
+ self.qk_rope_head_dim = 64
+ self.device = torch.device("cuda")
+ self.dtype = torch.bfloat16
+ self.extend_len = 64
+ self.max_bs = 4
+ self.max_seq_len = 128
+
+ # req_to_token_pool
+ self.req_to_token_pool = MockReqToTokenPool(
+ self.max_bs,
+ self.max_seq_len,
+ self.device,
+ )
+
+ # token_to_kv_pool
+ self.token_to_kv_pool = MLATokenToKVPool(
+ size=self.max_bs * self.max_seq_len,
+ page_size=1, # only consider page=1 for unit test
+ dtype=self.dtype,
+ kv_lora_rank=self.kv_lora_rank,
+ qk_rope_head_dim=self.qk_rope_head_dim,
+ layer_num=1, # only consider layer=1 for unit test
+ device=self.device,
+ enable_memory_saver=False,
+ )
+
+ def test_prefix_chunk_info(self):
+ """Test the standard extend operation."""
+
+ for test_case in TEST_CASES:
+ print(
+ f"Test case with batch_size={test_case['batch_size']}, prefix_lens={test_case['prefix_lens']}, max_chunk_capacity={test_case['max_chunk_capacity']}"
+ )
+ batch_size = test_case["batch_size"]
+ prefix_lens_cpu = test_case["prefix_lens"]
+ assert len(prefix_lens_cpu) == batch_size
+ prefix_lens = torch.tensor(prefix_lens_cpu, device=self.device)
+ max_chunk_capacity = test_case["max_chunk_capacity"]
+ seq_lens_cpu = [
+ self.extend_len + prefix_lens_cpu[i] for i in range(batch_size)
+ ]
+ seq_lens = torch.tensor(seq_lens_cpu, device=self.device)
+
+ # Create forward batch
+ # input_ids and out_cache_loc are dummy tensors in this test
+ forward_batch = MockForwardBatch(
+ max_chunk_capacity=max_chunk_capacity,
+ batch_size=batch_size,
+ input_ids=torch.randint(
+ 0, 100, (batch_size, self.extend_len), device=self.device
+ ),
+ out_cache_loc=torch.arange(
+ self.max_bs * self.max_seq_len - batch_size * self.extend_len,
+ self.max_bs * self.max_seq_len,
+ device=self.device,
+ ),
+ seq_lens_sum=sum(seq_lens_cpu),
+ forward_mode=ForwardMode.EXTEND,
+ req_pool_indices=torch.arange(batch_size, device=self.device),
+ seq_lens=seq_lens,
+ seq_lens_cpu=seq_lens_cpu,
+ extend_prefix_lens=prefix_lens,
+ extend_prefix_lens_cpu=prefix_lens_cpu,
+ )
+ forward_batch.req_to_token_pool = self.req_to_token_pool
+ forward_batch.token_to_kv_pool = self.token_to_kv_pool
+
+ forward_batch.prepare_chunked_prefix_cache_info(self.device)
+ assert forward_batch.get_max_chunk_capacity() == max_chunk_capacity
+ assert forward_batch.prefix_chunk_len == test_case["prefix_chunk_len"]
+ assert forward_batch.num_prefix_chunks == test_case["num_prefix_chunks"]
+ assert torch.allclose(
+ forward_batch.prefix_chunk_starts,
+ test_case["prefix_chunk_starts"].to(self.device),
+ )
+ assert torch.allclose(
+ forward_batch.prefix_chunk_seq_lens,
+ test_case["prefix_chunk_seq_lens"].to(self.device),
+ )
+
+ check_kv_indices(forward_batch)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py
index 30d6c7f39..90b462aaa 100644
--- a/test/srt/test_fa3.py
+++ b/test/srt/test_fa3.py
@@ -7,7 +7,6 @@ import torch
from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
- DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
@@ -19,7 +18,7 @@ Integration test for python/sglang/srt/layers/attention/flashattention_backend.p
"""
# Change to your own model if testing model is not public.
MODEL_USED_FOR_TEST = DEFAULT_MODEL_NAME_FOR_TEST
-MODEL_USED_FOR_TEST_MLA = DEFAULT_MLA_MODEL_NAME_FOR_TEST
+MODEL_USED_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
# Setting data path to None uses default data path in few_shot_gsm8k eval test.
DATA_PATH = None
@@ -174,5 +173,57 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
self.assertGreater(avg_spec_accept_length, 1.5)
+class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
+ """Test FlashAttention3 with speculative decode enabled."""
+
+ model = MODEL_USED_FOR_TEST_MLA
+
+ @classmethod
+ def get_server_args(cls):
+ args = super().get_server_args()
+ args.extend(
+ [
+ "--cuda-graph-max-bs",
+ "2",
+ "--speculative-algorithm",
+ "EAGLE",
+ "--speculative-draft",
+ "lmsys/sglang-ci-dsv3-test-NextN",
+ "--speculative-num-steps",
+ "3",
+ "--speculative-eagle-topk",
+ "1",
+ "--speculative-num-draft-tokens",
+ "3",
+ ]
+ )
+ return args
+
+ def test_gsm8k(self):
+ """
+ Override the test_gsm8k to further test for average speculative accept length.
+ """
+ requests.get(self.base_url + "/flush_cache")
+
+ args = SimpleNamespace(
+ num_shots=5,
+ data_path=DATA_PATH,
+ num_questions=200,
+ max_new_tokens=512,
+ parallel=128,
+ host="http://127.0.0.1",
+ port=int(self.base_url.split(":")[-1]),
+ )
+ metrics = run_eval_few_shot_gsm8k(args)
+ print(metrics)
+
+ self.assertGreater(metrics["accuracy"], 0.60)
+
+ server_info = requests.get(self.base_url + "/get_server_info")
+ avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
+ print(f"{avg_spec_accept_length=}")
+ self.assertGreater(avg_spec_accept_length, 1.5)
+
+
if __name__ == "__main__":
unittest.main()