diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 7ae15f1f2..e4dc1b878 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -195,3 +195,4 @@ Please consult the documentation below to learn more about the parameters you ma * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8. * `enable_flashinfer_mla`: Use the attention backend with FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention_backend flashinfer` instead to enable FlashInfer MLA.** * `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend. +* `disable_chunked_prefix_cache`: Disable the use of the chunked prefix cache for DeepSeek models. Only use it when FA3 is the attention backend. diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 5375008b7..8fc71e03c 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -92,13 +92,15 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - **CUDA Graph & Torch.compile**: Both MLA and Mixture of Experts (MoE) are compatible with CUDA Graph and Torch.compile, which reduces latency and accelerates decoding speed for small batch sizes. +- **Chunked Prefix Cache**: The chunked prefix cache optimization can increase throughput by cutting the prefix cache into chunks, processing them with multi-head attention, and merging their states (see the sketch below). The improvement can be significant when doing chunked prefill on long sequences. Currently, this optimization is only available for the FlashAttention3 backend. + Overall, with these optimizations, we have achieved up to **7x** acceleration in output throughput compared to the previous version.
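The chunked prefix cache bullet above (and the `merge_state_v2` call added in `deepseek_v2.py` later in this diff) relies on the fact that attention over a long KV range can be computed chunk by chunk and then recombined from each chunk's partial output and its log-sum-exp (LSE). Below is a minimal PyTorch sketch of that merge, assuming illustrative tensor shapes and helper names; it is not the `sgl_kernel` implementation.

```python
import torch


def merge_attention_states(o1, lse1, o2, lse2):
    """Merge attention computed over two disjoint KV chunks.

    o1, o2:     [num_tokens, num_heads, head_dim] partial outputs
    lse1, lse2: [num_tokens, num_heads] log-sum-exp of the attention logits
    """
    lse = torch.logaddexp(lse1, lse2)          # combined normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # weight of chunk 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # weight of chunk 2
    return o1 * w1 + o2 * w2, lse


def naive_attention(q, k, v):
    """Single-pass attention that also returns the per-(token, head) LSE."""
    logits = torch.einsum("qhd,khd->qhk", q, k) / (q.shape[-1] ** 0.5)
    lse = torch.logsumexp(logits, dim=-1)
    out = torch.einsum("qhk,khd->qhd", torch.softmax(logits, dim=-1), v)
    return out, lse


if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(4, 2, 8)   # [tokens, heads, head_dim]
    k = torch.randn(16, 2, 8)
    v = torch.randn(16, 2, 8)

    # Attention over the full KV set vs. two 8-token chunks merged afterwards.
    full_out, _ = naive_attention(q, k, v)
    o1, lse1 = naive_attention(q, k[:8], v[:8])
    o2, lse2 = naive_attention(q, k[8:], v[8:])
    merged, _ = merge_attention_states(o1, lse1, o2, lse2)
    assert torch.allclose(full_out, merged, atol=1e-5)
```

The FA3 code path in this PR obtains the per-chunk `lse` via `return_softmax_lse=True` and accumulates chunks in this renormalize-and-add fashion through `merge_state_v2`.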

Multi-head Latent Attention for DeepSeek Series Models

-**Usage**: MLA optimization is enabled by default, to disable, use `--disable-mla`. +**Usage**: MLA optimization is enabled by default. To disable MLA usage, use `--disable-mla`. To disable the chunked prefix cache feature for MLA, use `--disable-chunked-prefix-cache`. **Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details. diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 040393c17..26e54546d 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner -from sgl_kernel.flash_attn import flash_attn_with_kvcache +from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache @dataclass @@ -593,41 +593,87 @@ class FlashAttentionBackend(AttentionBackend): k_descale=k_descale, v_descale=v_descale, ) + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) else: - # Do absorbed multi-latent attention - kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - k_rope = kv_cache[:, :, layer.v_head_dim :] - c_kv = kv_cache[:, :, : layer.v_head_dim] - k_rope_cache = k_rope.view( - -1, - self.page_size, - layer.tp_k_head_num, - layer.head_dim - layer.v_head_dim, - ) - c_kv_cache = c_kv.view( - -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim - ) - q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) - q_nope = q_all[:, :, : layer.v_head_dim] - q_rope = q_all[:, :, layer.v_head_dim :] - o = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope_cache, - v_cache=c_kv_cache, - qv=q_nope, - page_table=page_table, - cache_seqlens=cache_seqlens, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, - max_seqlen_q=max_seqlen_q, - softmax_scale=layer.scaling, - causal=True, - softcap=layer.logit_cap, - k_descale=k_descale, - v_descale=v_descale, - ) + if ( + not global_server_args_dict["disable_chunked_prefix_cache"] + and forward_batch.attn_attend_prefix_cache is not None + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + ): + # Do multi-head attention with chunked prefix cache - return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + if forward_batch.attn_attend_prefix_cache: + # MHA for chunked prefix kv cache when running model with MLA + assert forward_batch.prefix_chunk_idx is not None + assert forward_batch.prefix_chunk_cu_seq_lens is not None + assert forward_batch.prefix_chunk_max_seq_lens is not None + + chunk_idx = forward_batch.prefix_chunk_idx + assert chunk_idx >= 0 + + output, lse, *rest = flash_attn_varlen_func( + q=q.view(-1, layer.tp_q_head_num, layer.head_dim), + k=k.view(-1, layer.tp_k_head_num, layer.head_dim), + v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim), + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx], + max_seqlen_q=metadata.max_seq_len_q, + max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx], + softmax_scale=layer.scaling, + causal=False, + return_softmax_lse=True, + ) + else: + # MHA for extend part 
of sequence without attending prefix kv cache + output, lse, *rest = flash_attn_varlen_func( + q=q.view(-1, layer.tp_q_head_num, layer.head_dim), + k=k.view(-1, layer.tp_k_head_num, layer.head_dim), + v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim), + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k=metadata.cu_seqlens_q, + max_seqlen_q=metadata.max_seq_len_q, + max_seqlen_k=metadata.max_seq_len_q, + softmax_scale=layer.scaling, + causal=True, + return_softmax_lse=True, + ) + return output, lse + else: + # Do absorbed multi-latent attention + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + k_rope = kv_cache[:, :, layer.v_head_dim :] + c_kv = kv_cache[:, :, : layer.v_head_dim] + k_rope_cache = k_rope.view( + -1, + self.page_size, + layer.tp_k_head_num, + layer.head_dim - layer.v_head_dim, + ) + c_kv_cache = c_kv.view( + -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim + ) + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] + o = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, + max_seqlen_q=max_seqlen_q, + softmax_scale=layer.scaling, + causal=True, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + ) + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) def forward_decode( self, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index fa9f40112..6bfd69307 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -83,6 +83,7 @@ global_server_args_dict = { "chunked_prefill_size": ServerArgs.chunked_prefill_size, "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, + "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache, } logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index ba861b850..91cec3210 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -181,6 +181,28 @@ class ForwardBatch: extend_logprob_start_lens_cpu: Optional[List[int]] = None extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None + # For MLA chunked prefix cache used in chunked prefill + # Tell attention backend whether the kv cache needs to be attended in current pass + attn_attend_prefix_cache: Optional[bool] = None + # Number of prefix cache chunks + num_prefix_chunks: Optional[int] = None + # Index of current chunk, used by attention backend + prefix_chunk_idx: Optional[int] = None + # Maximum number of tokens in each chunk per sequence. 
Computed from maximum chunk capacity + prefix_chunk_len: Optional[int] = None + # Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size) + prefix_chunk_starts: Optional[torch.Tensor] = None + # Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size) + prefix_chunk_seq_lens: Optional[torch.Tensor] = None + # Accumulated lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size + 1) + prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None + # Max lengths of prefix cache for each chunk, (num_prefix_chunks,) + prefix_chunk_max_seq_lens: Optional[List[int]] = None + # Number of tokens in each prefix cache chunk, (num_prefix_chunks,) + prefix_chunk_num_tokens: Optional[List[int]] = None + # KV Indices for each chunk + prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None + # For multimodal mm_inputs: Optional[List[MultimodalInputs]] = None @@ -484,6 +506,128 @@ class ForwardBatch: ) self.mrope_positions = self.mrope_positions.to(torch.int64) + def get_max_chunk_capacity(self): + # Maximum number of tokens in each chunk + # TODO: Should be changed to a better value, maybe passed through server args + return 128 * 1024 + + def set_prefix_chunk_idx(self, idx: int): + self.prefix_chunk_idx = idx + + def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool): + self.attn_attend_prefix_cache = attn_attend_prefix_cache + + def prepare_chunked_kv_indices(self, device: torch.device): + self.prefix_chunk_kv_indices = [] + for idx in range(self.num_prefix_chunks): + chunk_starts = self.prefix_chunk_starts[idx] + chunk_seq_lens = self.prefix_chunk_seq_lens[idx] + chunk_cu_seq_lens = self.prefix_chunk_cu_seq_lens[idx] + num_chunk_tokens = self.prefix_chunk_num_tokens[idx] + + chunk_kv_indices = torch.empty( + num_chunk_tokens, dtype=torch.int32, device=device + ) + + create_chunked_prefix_cache_kv_indices[(self.batch_size,)]( + self.req_to_token_pool.req_to_token, + self.req_pool_indices, + chunk_starts, + chunk_seq_lens, + chunk_cu_seq_lens, + chunk_kv_indices, + self.req_to_token_pool.req_to_token.shape[1], + ) + self.prefix_chunk_kv_indices.append(chunk_kv_indices) + + # Here we suppose the length of each chunk is equal + # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256 + # num_prefix_chunks = cdiv(1024, 256) = 4 + # prefix_chunk_starts = [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]] + # prefix_chunk_ends = [[256, 256, 256, 256], [256, 512, 512, 512], [256, 512, 768, 768], [256, 512, 768, 1024]] + # prefix_chunk_seq_lens = [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]] + # TODO: Implement a better way to allocate chunk lengths that uses memory spaces more efficiently. 
+ def get_prefix_chunk_seq_lens( + self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int + ): + device = prefix_lens.device + prefix_chunk_starts = ( + torch.arange(num_prefix_chunks, device=device, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, self.batch_size) + * prefix_chunk_len + ) + prefix_chunk_ends = torch.min( + prefix_lens.unsqueeze(0), + prefix_chunk_starts + prefix_chunk_len, + ).to(torch.int32) + + prefix_chunk_seq_lens = ( + (prefix_chunk_ends - prefix_chunk_starts).clamp(min=0).to(torch.int32) + ) + + return prefix_chunk_starts, prefix_chunk_seq_lens + + # Called before each attention module if using chunked kv cache for prefill + # Some of the codes are adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py + def prepare_chunked_prefix_cache_info(self, device: torch.device): + + from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool + + assert isinstance( + self.token_to_kv_pool, MLATokenToKVPool + ), "Currently chunked prefix cache can only be used by Deepseek models" + + if self.prefix_chunk_len is not None: + # Chunked kv cache info already prepared by prior modules + return + + self.prefix_chunk_idx = -1 + + # chunk_capacity is the maximum number of tokens in each chunk + chunk_capacity = self.get_max_chunk_capacity() + self.prefix_chunk_len = chunk_capacity // self.batch_size + + self.num_prefix_chunks = ( + max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1 + ) // self.prefix_chunk_len + + # Here we compute chunk lens twice to avoid stream sync, once on gpu and once on cpu. + prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = ( + self.get_prefix_chunk_seq_lens( + self.extend_prefix_lens, + self.num_prefix_chunks, + self.prefix_chunk_len, + ) + ) + _, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens( + torch.tensor(self.extend_prefix_lens_cpu), + self.num_prefix_chunks, + self.prefix_chunk_len, + ) + self.prefix_chunk_starts = prefix_chunk_starts_cuda + self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda + + # Metadata for attention backend + self.prefix_chunk_cu_seq_lens = torch.zeros( + self.num_prefix_chunks, + self.batch_size + 1, + device=device, + dtype=torch.int32, + ) + self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum( + dim=1 + ).to(torch.int32) + self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max( + dim=1 + ).values.tolist() + + self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist() + assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity() + + # Precompute the kv indices for each chunk + self.prepare_chunked_kv_indices(device) + def compute_position_triton( extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum @@ -561,3 +705,40 @@ def compute_position_torch( @torch.compile(dynamic=True, backend=get_compiler_backend()) def clamp_position(seq_lens): return torch.clamp((seq_lens - 1), min=0).to(torch.int64) + + +@triton.jit +def create_chunked_prefix_cache_kv_indices( + req_to_token_ptr, # (max_batch, max_context_len,) + req_pool_indices_ptr, # (batch_size,) + chunk_start_idx_ptr, # (batch_size,) + chunk_seq_lens_ptr, # (batch_size,) + chunk_cu_seq_lens_ptr, # (batch_size + 1,) + chunk_kv_indices_ptr, # (num_chunk_tokens,) + req_to_token_ptr_stride: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(axis=0) + + # find the req pool idx, this is for batch to token + req_pool_index = tl.load(req_pool_indices_ptr + pid) + 
chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid) + + # get the token positions of current chunk + chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32) + chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32) + + num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE) + for i in range(num_loop): + offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = offset < chunk_seq_len + data = tl.load( + req_to_token_ptr + + req_pool_index * req_to_token_ptr_stride + + chunk_start_pos + + offset, + mask=mask, + ) + tl.store( + chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask + ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fd84339d7..ca9ffbab2 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -167,6 +167,7 @@ class ModelRunner: "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject, "n_share_experts_fusion": server_args.n_share_experts_fusion, "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion, + "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache, "use_mla_backend": self.use_mla_backend, } ) @@ -318,6 +319,16 @@ class ModelRunner: if server_args.enable_deepep_moe: logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}") + if not self.use_mla_backend: + logger.info("Disable chunked prefix cache for non-MLA backend.") + server_args.disable_chunked_prefix_cache = True + elif self.page_size > 1: + logger.info("Disable chunked prefix cache when page size > 1.") + server_args.disable_chunked_prefix_cache = True + + if not server_args.disable_chunked_prefix_cache: + logger.info("Chunked prefix cache is turned on.") + def init_torch_distributed(self): logger.info("Init torch distributed begin.") diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index df2d6769f..533c3169c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -18,6 +18,7 @@ import logging import os +from enum import IntEnum, auto from typing import Any, Dict, Iterable, Optional, Tuple import torch @@ -78,7 +79,7 @@ _is_hip = is_hip() _is_cuda = is_cuda() if _is_cuda: - from sgl_kernel import awq_dequantize, bmm_fp8 + from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher else: @@ -94,6 +95,19 @@ expert_distribution_recorder = ExpertDistributionRecorder() logger = logging.getLogger(__name__) +class AttnForwardMethod(IntEnum): + + # Use multi-head attention + MHA = auto() + + # Use absorbed multi-latent attention + MLA = auto() + + # Use multi-head attention, but with KV cache chunked. + # This method can avoid OOM when prefix lengths are long. 
+ MHA_CHUNKED_KV = auto() + + class DeepseekV2MLP(nn.Module): def __init__( self, @@ -694,30 +708,54 @@ class DeepseekV2AttentionMLA(nn.Module): self.flashinfer_mla_disable_ragged = global_server_args_dict[ "flashinfer_mla_disable_ragged" ] + self.disable_chunked_prefix_cache = global_server_args_dict[ + "disable_chunked_prefix_cache" + ] self.attention_backend = global_server_args_dict["attention_backend"] self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1" - def no_absorb(self, forward_batch: ForwardBatch) -> bool: + # TODO: Design a finer way to determine the threshold + self.chunked_prefix_cache_threshold = 8192 + + def dispatch_attn_forward_method( + self, forward_batch: ForwardBatch + ) -> AttnForwardMethod: if self.attention_backend == "flashinfer": # Flashinfer MLA: Do not absorb when enabling ragged prefill - return ( + if ( not self.flashinfer_mla_disable_ragged and forward_batch.forward_mode.is_extend() and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend() and sum(forward_batch.extend_prefix_lens_cpu) == 0 - ) + ): + return AttnForwardMethod.MHA + else: + return AttnForwardMethod.MLA elif self.attention_backend == "fa3": - # Flash Attention: Keep absorbing for all extend/decode - return False + # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences. + if ( + forward_batch.forward_mode.is_extend() + and not self.disable_chunked_prefix_cache + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + and sum(forward_batch.extend_prefix_lens_cpu) + >= self.chunked_prefix_cache_threshold + ): + return AttnForwardMethod.MHA_CHUNKED_KV + else: + return AttnForwardMethod.MLA else: # Triton: Use normal computation for prefill and use weight absorption for extend/decode - return ( + if ( forward_batch.forward_mode.is_extend() and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend() and sum(forward_batch.extend_prefix_lens_cpu) == 0 - ) + ): + return AttnForwardMethod.MHA + else: + return AttnForwardMethod.MLA def forward( self, @@ -731,8 +769,14 @@ class DeepseekV2AttentionMLA(nn.Module): ), "short-circuiting allreduce will lead to hangs" return hidden_states - if self.no_absorb(forward_batch): + attn_forward_method = self.dispatch_attn_forward_method(forward_batch) + + if attn_forward_method == AttnForwardMethod.MHA: return self.forward_normal(positions, hidden_states, forward_batch) + elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV: + return self.forward_normal_chunked_kv( + positions, hidden_states, forward_batch + ) else: if _is_hip: if ( @@ -1007,6 +1051,127 @@ class DeepseekV2AttentionMLA(nn.Module): return output + def _chunked_prefix_attn_mha( + self, + q: torch.Tensor, + accum_output: torch.Tensor, + accum_lse: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + + assert forward_batch.num_prefix_chunks is not None + for i in range(forward_batch.num_prefix_chunks): + forward_batch.set_prefix_chunk_idx(i) + + # Fetch latent cache from memory pool with precomputed chunked kv indices + latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer( + self.attn_mha.layer_id + ) + latent_cache = latent_cache_buf[ + forward_batch.prefix_chunk_kv_indices[i] + ].contiguous() + + kv_a_normed, k_pe = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + kv_a_normed = kv_a_normed.squeeze(1).contiguous() + kv = 
self.kv_b_proj(kv_a_normed)[0] + kv = kv.view( + -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim + ) + v = kv[..., self.qk_nope_head_dim :] + k_nope = kv[..., : self.qk_nope_head_dim] + + k = torch.empty( + ( + k_nope.shape[0], + self.num_local_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + ), + dtype=v.dtype, + device=v.device, + ) + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe + + output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) + lse = torch.transpose(lse, 0, 1).contiguous() + tmp_output = torch.empty_like(accum_output) + tmp_lse = torch.empty_like(accum_lse) + merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse) + accum_output, accum_lse = tmp_output, tmp_lse + + return accum_output + + def forward_normal_chunked_kv( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + # In normal mha, the k and v tensors will become overly large when the prefix length is long. + # To avoid this, we split the kv cache into chunks and process them one after another. + # Since mha is compute friendly, the for loop induced here will not introduce significant overhead. + # The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py + # will be helpful for understanding the purpose of this function. + + # First do normal mha forward to get output for extended part + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope = kv[..., : self.qk_nope_head_dim] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + q[..., self.qk_nope_head_dim :] = q_pe + k = torch.empty_like(q) + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe + + latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) + latent_cache[:, :, self.kv_lora_rank :] = k_pe + + # Save latent cache + forward_batch.token_to_kv_pool.set_kv_buffer( + self.attn_mha, forward_batch.out_cache_loc, latent_cache, None + ) + + # Do mha for extended part without prefix + forward_batch.set_attn_attend_prefix_cache(False) + attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) + lse = torch.transpose(lse, 0, 1).contiguous() + + # Do mha attention with chunked prefix cache if there are any sequence with prefix + if any(forward_batch.extend_prefix_lens_cpu): + # Only initialize the info once + if forward_batch.num_prefix_chunks is None: + forward_batch.prepare_chunked_prefix_cache_info(q.device) + + forward_batch.set_attn_attend_prefix_cache(True) + attn_output = self._chunked_prefix_attn_mha( + q=q, + accum_output=attn_output, + accum_lse=lse, + forward_batch=forward_batch, + ) + + attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim) + 
output, _ = self.o_proj(attn_output) + return output + class DeepseekV2DecoderLayer(nn.Module): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6d78b654a..f9878d2c5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -186,6 +186,7 @@ class ServerArgs: warmups: Optional[str] = None n_share_experts_fusion: int = 0 disable_shared_experts_fusion: bool = False + disable_chunked_prefix_cache: bool = False # Debug tensor dumps debug_tensor_dump_output_folder: Optional[str] = None @@ -1130,6 +1131,11 @@ class ServerArgs: action="store_true", help="Disable shared experts fusion by setting n_share_experts_fusion to 0.", ) + parser.add_argument( + "--disable-chunked-prefix-cache", + action="store_true", + help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.", + ) # Server warmups parser.add_argument( diff --git a/python/sglang/test/attention/test_prefix_chunk_info.py b/python/sglang/test/attention/test_prefix_chunk_info.py new file mode 100644 index 000000000..2b85b695b --- /dev/null +++ b/python/sglang/test/attention/test_prefix_chunk_info.py @@ -0,0 +1,224 @@ +import unittest + +import torch + +from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.test.test_utils import CustomTestCase + +TEST_CASES = [ + # Sequence with same prefix lens + { + "batch_size": 3, + "prefix_lens": [64, 64, 64], + "max_chunk_capacity": 48, + "prefix_chunk_len": 16, + "num_prefix_chunks": 4, + "prefix_chunk_starts": torch.tensor( + [ + [0, 0, 0], + [16, 16, 16], + [32, 32, 32], + [48, 48, 48], + ], + dtype=torch.int32, + ), + "prefix_chunk_seq_lens": torch.tensor( + [ + [16, 16, 16], + [16, 16, 16], + [16, 16, 16], + [16, 16, 16], + ], + dtype=torch.int32, + ), + }, + # Sequence with different prefix lens + { + "batch_size": 4, + "prefix_lens": [16, 32, 48, 64], + "max_chunk_capacity": 64, + "prefix_chunk_len": 16, + "num_prefix_chunks": 4, + "prefix_chunk_starts": torch.tensor( + [ + [0, 0, 0, 0], + [16, 16, 16, 16], + [32, 32, 32, 32], + [48, 48, 48, 48], + ], + dtype=torch.int32, + ), + "prefix_chunk_seq_lens": torch.tensor( + [ + [16, 16, 16, 16], + [0, 16, 16, 16], + [0, 0, 16, 16], + [0, 0, 0, 16], + ], + dtype=torch.int32, + ), + }, + # Sequence with irregular shapes + { + "batch_size": 2, + "prefix_lens": [1, 64], + "max_chunk_capacity": 31, + "prefix_chunk_len": 15, + "num_prefix_chunks": 5, + "prefix_chunk_starts": torch.tensor( + [ + [0, 0], + [15, 15], + [30, 30], + [45, 45], + [60, 60], + ], + dtype=torch.int32, + ), + "prefix_chunk_seq_lens": torch.tensor( + [ + [1, 15], + [0, 15], + [0, 15], + [0, 15], + [0, 4], + ], + dtype=torch.int32, + ), + }, +] + + +class MockForwardBatch(ForwardBatch): + def __init__(self, max_chunk_capacity: int, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_chunk_capacity = max_chunk_capacity + + def get_max_chunk_capacity(self): + return self.max_chunk_capacity + + +class MockReqToTokenPool: + def __init__(self, batch_size, seq_len, device): + self.req_to_token = ( + torch.arange(batch_size * seq_len, device=device) + .reshape(batch_size, seq_len) + .to(torch.int32) + ) + + +# Test correctness of triton kernel for computing kv indices +def check_kv_indices(forward_batch): + for i in range(forward_batch.num_prefix_chunks): + computed_kv_indices = forward_batch.prefix_chunk_kv_indices[i] + req_to_token = 
forward_batch.req_to_token_pool.req_to_token[ + : forward_batch.batch_size, : + ] + ref_kv_indices = torch.empty( + forward_batch.prefix_chunk_num_tokens[i], + dtype=torch.int32, + device=computed_kv_indices.device, + ) + running_ptr = 0 + for j in range(forward_batch.batch_size): + seq_start = forward_batch.prefix_chunk_starts[i, j].item() + seq_len = forward_batch.prefix_chunk_seq_lens[i, j].item() + ref_kv_indices[running_ptr : running_ptr + seq_len].copy_( + req_to_token[j, seq_start : seq_start + seq_len] + ) + running_ptr += seq_len + assert torch.allclose(computed_kv_indices, ref_kv_indices) + + +@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") +class TestPrefixChunkInfo(CustomTestCase): + def setUp(self): + # Common test parameters + self.num_local_heads = 128 + self.kv_lora_rank = 512 + self.qk_rope_head_dim = 64 + self.device = torch.device("cuda") + self.dtype = torch.bfloat16 + self.extend_len = 64 + self.max_bs = 4 + self.max_seq_len = 128 + + # req_to_token_pool + self.req_to_token_pool = MockReqToTokenPool( + self.max_bs, + self.max_seq_len, + self.device, + ) + + # token_to_kv_pool + self.token_to_kv_pool = MLATokenToKVPool( + size=self.max_bs * self.max_seq_len, + page_size=1, # only consider page=1 for unit test + dtype=self.dtype, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + layer_num=1, # only consider layer=1 for unit test + device=self.device, + enable_memory_saver=False, + ) + + def test_prefix_chunk_info(self): + """Test the standard extend operation.""" + + for test_case in TEST_CASES: + print( + f"Test case with batch_size={test_case['batch_size']}, prefix_lens={test_case['prefix_lens']}, max_chunk_capacity={test_case['max_chunk_capacity']}" + ) + batch_size = test_case["batch_size"] + prefix_lens_cpu = test_case["prefix_lens"] + assert len(prefix_lens_cpu) == batch_size + prefix_lens = torch.tensor(prefix_lens_cpu, device=self.device) + max_chunk_capacity = test_case["max_chunk_capacity"] + seq_lens_cpu = [ + self.extend_len + prefix_lens_cpu[i] for i in range(batch_size) + ] + seq_lens = torch.tensor(seq_lens_cpu, device=self.device) + + # Create forward batch + # input_ids and out_cache_loc are dummy tensors in this test + forward_batch = MockForwardBatch( + max_chunk_capacity=max_chunk_capacity, + batch_size=batch_size, + input_ids=torch.randint( + 0, 100, (batch_size, self.extend_len), device=self.device + ), + out_cache_loc=torch.arange( + self.max_bs * self.max_seq_len - batch_size * self.extend_len, + self.max_bs * self.max_seq_len, + device=self.device, + ), + seq_lens_sum=sum(seq_lens_cpu), + forward_mode=ForwardMode.EXTEND, + req_pool_indices=torch.arange(batch_size, device=self.device), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + extend_prefix_lens=prefix_lens, + extend_prefix_lens_cpu=prefix_lens_cpu, + ) + forward_batch.req_to_token_pool = self.req_to_token_pool + forward_batch.token_to_kv_pool = self.token_to_kv_pool + + forward_batch.prepare_chunked_prefix_cache_info(self.device) + assert forward_batch.get_max_chunk_capacity() == max_chunk_capacity + assert forward_batch.prefix_chunk_len == test_case["prefix_chunk_len"] + assert forward_batch.num_prefix_chunks == test_case["num_prefix_chunks"] + assert torch.allclose( + forward_batch.prefix_chunk_starts, + test_case["prefix_chunk_starts"].to(self.device), + ) + assert torch.allclose( + forward_batch.prefix_chunk_seq_lens, + test_case["prefix_chunk_seq_lens"].to(self.device), + ) + + check_kv_indices(forward_batch) + + +if __name__ == 
"__main__": + unittest.main() diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 30d6c7f39..90b462aaa 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -7,7 +7,6 @@ import torch from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( - DEFAULT_MLA_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -19,7 +18,7 @@ Integration test for python/sglang/srt/layers/attention/flashattention_backend.p """ # Change to your own model if testing model is not public. MODEL_USED_FOR_TEST = DEFAULT_MODEL_NAME_FOR_TEST -MODEL_USED_FOR_TEST_MLA = DEFAULT_MLA_MODEL_NAME_FOR_TEST +MODEL_USED_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test" # Setting data path to None uses default data path in few_shot_gsm8k eval test. DATA_PATH = None @@ -174,5 +173,57 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): self.assertGreater(avg_spec_accept_length, 1.5) +class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest): + """Test FlashAttention3 with speculative decode enabled.""" + + model = MODEL_USED_FOR_TEST_MLA + + @classmethod + def get_server_args(cls): + args = super().get_server_args() + args.extend( + [ + "--cuda-graph-max-bs", + "2", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + "lmsys/sglang-ci-dsv3-test-NextN", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "3", + ] + ) + return args + + def test_gsm8k(self): + """ + Override the test_gsm8k to further test for average speculative accept length. + """ + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=DATA_PATH, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 1.5) + + if __name__ == "__main__": unittest.main()