Support MHA with chunked prefix cache for DeepSeek chunked prefill (#5113)

This commit is contained in:
Baizhou Zhang
2025-04-15 22:01:22 -07:00
committed by GitHub
parent dd83e7e9c3
commit a42736bbb8
10 changed files with 734 additions and 46 deletions

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel.flash_attn import flash_attn_with_kvcache
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@dataclass
@@ -593,41 +593,87 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
if (
not global_server_args_dict["disable_chunked_prefix_cache"]
and forward_batch.attn_attend_prefix_cache is not None
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
# Do multi-head attention with chunked prefix cache
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
if forward_batch.attn_attend_prefix_cache:
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
output, lse, *rest = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
return_softmax_lse=True,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
output, lse, *rest = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=True,
)
return output, lse
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_decode(
self,

View File

@@ -83,6 +83,7 @@ global_server_args_dict = {
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
"disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
}
logger = logging.getLogger(__name__)

View File

@@ -181,6 +181,28 @@ class ForwardBatch:
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
# For MLA chunked prefix cache used in chunked prefill
# Tell attention backend whether the kv cache needs to be attended in current pass
attn_attend_prefix_cache: Optional[bool] = None
# Number of prefix cache chunks
num_prefix_chunks: Optional[int] = None
# Index of current chunk, used by attention backend
prefix_chunk_idx: Optional[int] = None
# Maximum number of tokens in each chunk per sequence. Computed from maximum chunk capacity
prefix_chunk_len: Optional[int] = None
# Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size)
prefix_chunk_starts: Optional[torch.Tensor] = None
# Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size)
prefix_chunk_seq_lens: Optional[torch.Tensor] = None
# Accumulated lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size + 1)
prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None
# Max lengths of prefix cache for each chunk, (num_prefix_chunks,)
prefix_chunk_max_seq_lens: Optional[List[int]] = None
# Number of tokens in each prefix cache chunk, (num_prefix_chunks,)
prefix_chunk_num_tokens: Optional[List[int]] = None
# KV Indices for each chunk
prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
# For multimodal
mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -484,6 +506,128 @@ class ForwardBatch:
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
def get_max_chunk_capacity(self):
    """Return the maximum total number of tokens one prefix chunk may hold."""
    # Fixed budget of 128 * 1024 tokens per chunk across the whole batch.
    # TODO: Should be changed to a better value, maybe passed through server args
    return 131072
def set_prefix_chunk_idx(self, idx: int):
    """Record which prefix chunk the attention backend should process next."""
    self.prefix_chunk_idx = idx
def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool):
    """Tell the attention backend whether the current pass attends the cached
    prefix chunks (True) or only the extend part of the sequences (False)."""
    self.attn_attend_prefix_cache = attn_attend_prefix_cache
def prepare_chunked_kv_indices(self, device: torch.device):
    """Precompute the flattened kv-pool indices for every prefix chunk.

    After this call, ``prefix_chunk_kv_indices[i]`` holds, sequence by
    sequence, the token-pool slots of chunk ``i``'s cached prefix tokens.
    The gather itself is done on ``device`` by a triton kernel launched
    with one program per sequence in the batch.
    """
    self.prefix_chunk_kv_indices = []
    for idx in range(self.num_prefix_chunks):
        # Per-chunk views of the metadata computed in
        # prepare_chunked_prefix_cache_info, each shaped (batch_size,).
        chunk_starts = self.prefix_chunk_starts[idx]
        chunk_seq_lens = self.prefix_chunk_seq_lens[idx]
        chunk_cu_seq_lens = self.prefix_chunk_cu_seq_lens[idx]
        num_chunk_tokens = self.prefix_chunk_num_tokens[idx]
        # Output buffer: one int32 kv index per token in this chunk.
        chunk_kv_indices = torch.empty(
            num_chunk_tokens, dtype=torch.int32, device=device
        )
        # Grid = (batch_size,): program pid fills its sequence's slice.
        create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
            self.req_to_token_pool.req_to_token,
            self.req_pool_indices,
            chunk_starts,
            chunk_seq_lens,
            chunk_cu_seq_lens,
            chunk_kv_indices,
            self.req_to_token_pool.req_to_token.shape[1],
        )
        self.prefix_chunk_kv_indices.append(chunk_kv_indices)
# Here we suppose the length of each chunk is equal
# For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
# num_prefix_chunks = cdiv(1024, 256) = 4
# prefix_chunk_starts = [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]]
# prefix_chunk_ends = [[256, 256, 256, 256], [256, 512, 512, 512], [256, 512, 768, 768], [256, 512, 768, 1024]]
# prefix_chunk_seq_lens = [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]
# TODO: Implement a better way to allocate chunk lengths that uses memory spaces more efficiently.
def get_prefix_chunk_seq_lens(
    self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int
):
    """Compute per-chunk start offsets and lengths of the cached prefixes.

    Every sequence shares the same chunk boundaries
    ``[idx * prefix_chunk_len, (idx + 1) * prefix_chunk_len)``; a sequence
    whose prefix ends before a chunk starts contributes length 0 there.

    Returns:
        A pair of int32 tensors of shape (num_prefix_chunks, batch_size):
        chunk start positions and the per-sequence token count per chunk.
    """
    chunk_ids = torch.arange(
        num_prefix_chunks, device=prefix_lens.device, dtype=torch.int32
    )
    # Broadcast the shared chunk starts across the batch dimension.
    prefix_chunk_starts = (
        (chunk_ids * prefix_chunk_len).unsqueeze(1).expand(-1, self.batch_size)
    )
    # Clip each chunk's end to the sequence's actual prefix length.
    prefix_chunk_ends = torch.minimum(
        prefix_lens.unsqueeze(0), prefix_chunk_starts + prefix_chunk_len
    ).to(torch.int32)
    prefix_chunk_seq_lens = torch.clamp(
        prefix_chunk_ends - prefix_chunk_starts, min=0
    ).to(torch.int32)
    return prefix_chunk_starts, prefix_chunk_seq_lens
# Called before each attention module if using chunked kv cache for prefill
# Some of the codes are adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
def prepare_chunked_prefix_cache_info(self, device: torch.device):
    """Build all per-chunk metadata needed to attend the prefix kv cache.

    Splits each sequence's cached prefix into equal-length chunks so that
    no chunk exceeds get_max_chunk_capacity() tokens in total, then fills
    in chunk starts/lengths, cumulative lengths, max lengths, token counts
    and the flattened kv indices. Idempotent: a second call in the same
    pass returns immediately.
    """
    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool

    assert isinstance(
        self.token_to_kv_pool, MLATokenToKVPool
    ), "Currently chunked prefix cache can only be used by Deepseek models"

    if self.prefix_chunk_len is not None:
        # Chunked kv cache info already prepared by prior modules
        return

    # -1 marks "no chunk selected yet"; callers advance it per chunk.
    self.prefix_chunk_idx = -1

    # chunk_capacity is the maximum number of tokens in each chunk
    chunk_capacity = self.get_max_chunk_capacity()
    self.prefix_chunk_len = chunk_capacity // self.batch_size

    # Number of chunks needed to cover the longest prefix (ceil division).
    self.num_prefix_chunks = (
        max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1
    ) // self.prefix_chunk_len

    # Here we compute chunk lens twice to avoid stream sync, once on gpu and once on cpu.
    prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = (
        self.get_prefix_chunk_seq_lens(
            self.extend_prefix_lens,
            self.num_prefix_chunks,
            self.prefix_chunk_len,
        )
    )
    _, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens(
        torch.tensor(self.extend_prefix_lens_cpu),
        self.num_prefix_chunks,
        self.prefix_chunk_len,
    )
    self.prefix_chunk_starts = prefix_chunk_starts_cuda
    self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda

    # Metadata for attention backend
    # Cumulative lengths per chunk, shape (num_prefix_chunks, batch_size + 1),
    # with a leading zero column so row i is directly usable as cu_seqlens_k.
    self.prefix_chunk_cu_seq_lens = torch.zeros(
        self.num_prefix_chunks,
        self.batch_size + 1,
        device=device,
        dtype=torch.int32,
    )
    self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum(
        dim=1
    ).to(torch.int32)
    # Host-side lists derived from the CPU copy (no device sync needed).
    self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max(
        dim=1
    ).values.tolist()
    self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist()
    assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity()

    # Precompute the kv indices for each chunk
    self.prepare_chunked_kv_indices(device)
def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
@@ -561,3 +705,40 @@ def compute_position_torch(
@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
    """Map sequence lengths to last-token positions: max(len - 1, 0), int64."""
    return torch.clamp_min(seq_lens - 1, 0).to(torch.int64)
# Gather one prefix chunk's kv-pool indices into a flat array.
# Grid = (batch_size,): program `pid` copies its sequence's slice of
# req_to_token into chunk_kv_indices[cu_seq_lens[pid] : cu_seq_lens[pid] + len].
@triton.jit
def create_chunked_prefix_cache_kv_indices(
    req_to_token_ptr,  # (max_batch, max_context_len,)
    req_pool_indices_ptr,  # (batch_size,)
    chunk_start_idx_ptr,  # (batch_size,)
    chunk_seq_lens_ptr,  # (batch_size,)
    chunk_cu_seq_lens_ptr,  # (batch_size + 1,)
    chunk_kv_indices_ptr,  # (num_chunk_tokens,)
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    # find the req pool idx, this is for batch to token
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    # Where this sequence's tokens start in the flat output array.
    chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)

    # get the token positions of current chunk
    chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
    chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)

    # Copy BLOCK_SIZE tokens per iteration; the mask guards the ragged tail
    # (and makes the whole loop a no-op when chunk_seq_len == 0).
    num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < chunk_seq_len
        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + chunk_start_pos
            + offset,
            mask=mask,
        )
        tl.store(
            chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
        )

View File

@@ -167,6 +167,7 @@ class ModelRunner:
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
"use_mla_backend": self.use_mla_backend,
}
)
@@ -318,6 +319,16 @@ class ModelRunner:
if server_args.enable_deepep_moe:
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
if not self.use_mla_backend:
logger.info("Disable chunked prefix cache for non-MLA backend.")
server_args.disable_chunked_prefix_cache = True
elif self.page_size > 1:
logger.info("Disable chunked prefix cache when page size > 1.")
server_args.disable_chunked_prefix_cache = True
if not server_args.disable_chunked_prefix_cache:
logger.info("Chunked prefix cache is turned on.")
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")

View File

@@ -18,6 +18,7 @@
import logging
import os
from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
@@ -78,7 +79,7 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import awq_dequantize, bmm_fp8
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
else:
@@ -94,6 +95,19 @@ expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
class AttnForwardMethod(IntEnum):
    """Strategy used by DeepseekV2AttentionMLA to run attention for one pass."""

    # Use multi-head attention
    MHA = auto()

    # Use absorbed multi-latent attention
    MLA = auto()

    # Use multi-head attention, but with KV cache chunked.
    # This method can avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = auto()
class DeepseekV2MLP(nn.Module):
def __init__(
self,
@@ -694,30 +708,54 @@ class DeepseekV2AttentionMLA(nn.Module):
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]
self.attention_backend = global_server_args_dict["attention_backend"]
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
# TODO: Design a finer way to determine the threshold
self.chunked_prefix_cache_threshold = 8192
def dispatch_attn_forward_method(
    self, forward_batch: ForwardBatch
) -> AttnForwardMethod:
    """Pick the attention execution strategy for this batch.

    The dump merged removed diff lines (stray ``return (`` / ``return False``
    fragments) into the added code, leaving it syntactically invalid; this is
    the reconstructed post-change method with the remnants removed.
    """
    if self.attention_backend == "flashinfer":
        # Flashinfer MLA: Do not absorb when enabling ragged prefill
        if (
            not self.flashinfer_mla_disable_ragged
            and forward_batch.forward_mode.is_extend()
            and not forward_batch.forward_mode.is_target_verify()
            and not forward_batch.forward_mode.is_draft_extend()
            and sum(forward_batch.extend_prefix_lens_cpu) == 0
        ):
            return AttnForwardMethod.MHA
        else:
            return AttnForwardMethod.MLA
    elif self.attention_backend == "fa3":
        # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
        if (
            forward_batch.forward_mode.is_extend()
            and not self.disable_chunked_prefix_cache
            and not forward_batch.forward_mode.is_target_verify()
            and not forward_batch.forward_mode.is_draft_extend()
            and sum(forward_batch.extend_prefix_lens_cpu)
            >= self.chunked_prefix_cache_threshold
        ):
            return AttnForwardMethod.MHA_CHUNKED_KV
        else:
            return AttnForwardMethod.MLA
    else:
        # Triton: Use normal computation for prefill and use weight absorption for extend/decode
        if (
            forward_batch.forward_mode.is_extend()
            and not forward_batch.forward_mode.is_target_verify()
            and not forward_batch.forward_mode.is_draft_extend()
            and sum(forward_batch.extend_prefix_lens_cpu) == 0
        ):
            return AttnForwardMethod.MHA
        else:
            return AttnForwardMethod.MLA
def forward(
self,
@@ -731,8 +769,14 @@ class DeepseekV2AttentionMLA(nn.Module):
), "short-circuiting allreduce will lead to hangs"
return hidden_states
if self.no_absorb(forward_batch):
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
if attn_forward_method == AttnForwardMethod.MHA:
return self.forward_normal(positions, hidden_states, forward_batch)
elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
return self.forward_normal_chunked_kv(
positions, hidden_states, forward_batch
)
else:
if _is_hip:
if (
@@ -1007,6 +1051,127 @@ class DeepseekV2AttentionMLA(nn.Module):
return output
def _chunked_prefix_attn_mha(
    self,
    q: torch.Tensor,
    accum_output: torch.Tensor,
    accum_lse: torch.Tensor,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    """Attend q over every cached prefix chunk and merge into the accumulator.

    Args:
        q: queries for the extend tokens.
        accum_output: attention output accumulated so far (extend-only pass).
        accum_lse: matching log-sum-exp, transposed to token-major layout.
        forward_batch: batch carrying the precomputed chunk metadata.

    Returns:
        The merged attention output after all prefix chunks are processed.
    """
    assert forward_batch.num_prefix_chunks is not None
    for i in range(forward_batch.num_prefix_chunks):
        # Point the attention backend at chunk i's metadata.
        forward_batch.set_prefix_chunk_idx(i)

        # Fetch latent cache from memory pool with precomputed chunked kv indices
        latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mha.layer_id
        )
        latent_cache = latent_cache_buf[
            forward_batch.prefix_chunk_kv_indices[i]
        ].contiguous()

        # Split latent entries into the compressed kv part and the rope part.
        kv_a_normed, k_pe = latent_cache.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        kv_a_normed = kv_a_normed.squeeze(1).contiguous()
        # Up-project the compressed kv back to per-head k_nope / v.
        kv = self.kv_b_proj(kv_a_normed)[0]
        kv = kv.view(
            -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        v = kv[..., self.qk_nope_head_dim :]
        k_nope = kv[..., : self.qk_nope_head_dim]

        # Assemble the full key: k = [k_nope | k_pe] per head.
        k = torch.empty(
            (
                k_nope.shape[0],
                self.num_local_heads,
                self.qk_nope_head_dim + self.qk_rope_head_dim,
            ),
            dtype=v.dtype,
            device=v.device,
        )
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        lse = torch.transpose(lse, 0, 1).contiguous()
        # Numerically-stable merge of this chunk's result into the running
        # accumulator using the log-sum-exp values.
        tmp_output = torch.empty_like(accum_output)
        tmp_lse = torch.empty_like(accum_lse)
        merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
        accum_output, accum_lse = tmp_output, tmp_lse

    return accum_output
def forward_normal_chunked_kv(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    """MHA forward that attends the prefix kv cache chunk by chunk.

    Runs normal MHA on the extend tokens first, then (if any sequence has a
    cached prefix) folds in attention over each prefix chunk via
    _chunked_prefix_attn_mha.
    """
    # In normal mha, the k and v tensors will become overly large when the prefix length is long.
    # To avoid this, we split the kv cache into chunks and process them one after another.
    # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
    # The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
    # will be helpful for understanding the purpose of this function.

    # First do normal mha forward to get output for extended part
    if self.q_lora_rank is not None:
        q = self.q_a_proj(hidden_states)[0]
        q = self.q_a_layernorm(q)
        q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
    else:
        q = self.q_proj(hidden_states)[0].view(
            -1, self.num_local_heads, self.qk_head_dim
        )
    _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
    latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
    kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    latent_cache = latent_cache.unsqueeze(1)
    kv_a = self.kv_a_layernorm(kv_a.contiguous())
    kv = self.kv_b_proj(kv_a)[0]
    kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope = kv[..., : self.qk_nope_head_dim]
    v = kv[..., self.qk_nope_head_dim :]
    k_pe = latent_cache[:, :, self.kv_lora_rank :]
    # Apply rotary embedding to the positional halves of q and k.
    q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
    # In-place writes: q_pe is a view into q, so q now carries rotated values.
    q[..., self.qk_nope_head_dim :] = q_pe
    k = torch.empty_like(q)
    k[..., : self.qk_nope_head_dim] = k_nope
    k[..., self.qk_nope_head_dim :] = k_pe

    # Write the normed kv_a and rotated k_pe back into the latent entry.
    latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
    latent_cache[:, :, self.kv_lora_rank :] = k_pe

    # Save latent cache
    forward_batch.token_to_kv_pool.set_kv_buffer(
        self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
    )

    # Do mha for extended part without prefix
    forward_batch.set_attn_attend_prefix_cache(False)
    attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
    lse = torch.transpose(lse, 0, 1).contiguous()

    # Do mha attention with chunked prefix cache if any sequence has a prefix
    if any(forward_batch.extend_prefix_lens_cpu):
        # Only initialize the info once
        if forward_batch.num_prefix_chunks is None:
            forward_batch.prepare_chunked_prefix_cache_info(q.device)

        forward_batch.set_attn_attend_prefix_cache(True)
        attn_output = self._chunked_prefix_attn_mha(
            q=q,
            accum_output=attn_output,
            accum_lse=lse,
            forward_batch=forward_batch,
        )

    attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
    output, _ = self.o_proj(attn_output)
    return output
class DeepseekV2DecoderLayer(nn.Module):

View File

@@ -186,6 +186,7 @@ class ServerArgs:
warmups: Optional[str] = None
n_share_experts_fusion: int = 0
disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
@@ -1130,6 +1131,11 @@ class ServerArgs:
action="store_true",
help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
)
parser.add_argument(
"--disable-chunked-prefix-cache",
action="store_true",
help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
)
# Server warmups
parser.add_argument(

View File

@@ -0,0 +1,224 @@
import unittest
import torch
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.test.test_utils import CustomTestCase
TEST_CASES = [
# Sequence with same prefix lens
{
"batch_size": 3,
"prefix_lens": [64, 64, 64],
"max_chunk_capacity": 48,
"prefix_chunk_len": 16,
"num_prefix_chunks": 4,
"prefix_chunk_starts": torch.tensor(
[
[0, 0, 0],
[16, 16, 16],
[32, 32, 32],
[48, 48, 48],
],
dtype=torch.int32,
),
"prefix_chunk_seq_lens": torch.tensor(
[
[16, 16, 16],
[16, 16, 16],
[16, 16, 16],
[16, 16, 16],
],
dtype=torch.int32,
),
},
# Sequence with different prefix lens
{
"batch_size": 4,
"prefix_lens": [16, 32, 48, 64],
"max_chunk_capacity": 64,
"prefix_chunk_len": 16,
"num_prefix_chunks": 4,
"prefix_chunk_starts": torch.tensor(
[
[0, 0, 0, 0],
[16, 16, 16, 16],
[32, 32, 32, 32],
[48, 48, 48, 48],
],
dtype=torch.int32,
),
"prefix_chunk_seq_lens": torch.tensor(
[
[16, 16, 16, 16],
[0, 16, 16, 16],
[0, 0, 16, 16],
[0, 0, 0, 16],
],
dtype=torch.int32,
),
},
# Sequence with irregular shapes
{
"batch_size": 2,
"prefix_lens": [1, 64],
"max_chunk_capacity": 31,
"prefix_chunk_len": 15,
"num_prefix_chunks": 5,
"prefix_chunk_starts": torch.tensor(
[
[0, 0],
[15, 15],
[30, 30],
[45, 45],
[60, 60],
],
dtype=torch.int32,
),
"prefix_chunk_seq_lens": torch.tensor(
[
[1, 15],
[0, 15],
[0, 15],
[0, 15],
[0, 4],
],
dtype=torch.int32,
),
},
]
class MockForwardBatch(ForwardBatch):
    """ForwardBatch with an overridable max chunk capacity for testing."""

    def __init__(self, max_chunk_capacity: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-instance override read by get_max_chunk_capacity below.
        self.max_chunk_capacity = max_chunk_capacity

    def get_max_chunk_capacity(self):
        # Replace the fixed production capacity so tests can force the
        # chunking logic to produce several small chunks.
        return self.max_chunk_capacity
class MockReqToTokenPool:
    """Minimal req->token pool stand-in: request i owns the contiguous
    token ids [i * seq_len, (i + 1) * seq_len)."""

    def __init__(self, batch_size, seq_len, device):
        token_ids = torch.arange(batch_size * seq_len, device=device)
        self.req_to_token = token_ids.to(torch.int32).view(batch_size, seq_len)
# Test correctness of triton kernel for computing kv indices
def check_kv_indices(forward_batch):
    """Verify the triton-computed chunk kv indices against a slow reference.

    For every chunk, rebuilds the expected flat index array by slicing
    req_to_token sequence by sequence and asserts it matches the
    precomputed ``prefix_chunk_kv_indices``.
    """
    req_to_token = forward_batch.req_to_token_pool.req_to_token[
        : forward_batch.batch_size, :
    ]
    for chunk_id in range(forward_batch.num_prefix_chunks):
        computed_kv_indices = forward_batch.prefix_chunk_kv_indices[chunk_id]
        pieces = []
        for seq_id in range(forward_batch.batch_size):
            seq_start = forward_batch.prefix_chunk_starts[chunk_id, seq_id].item()
            seq_len = forward_batch.prefix_chunk_seq_lens[chunk_id, seq_id].item()
            pieces.append(req_to_token[seq_id, seq_start : seq_start + seq_len])
        ref_kv_indices = torch.cat(pieces).to(
            dtype=torch.int32, device=computed_kv_indices.device
        )
        assert torch.allclose(computed_kv_indices, ref_kv_indices)
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
class TestPrefixChunkInfo(CustomTestCase):
    """End-to-end checks of ForwardBatch.prepare_chunked_prefix_cache_info."""

    def setUp(self):
        # Common test parameters
        self.num_local_heads = 128
        self.kv_lora_rank = 512
        self.qk_rope_head_dim = 64
        self.device = torch.device("cuda")
        self.dtype = torch.bfloat16
        # Every sequence extends by the same number of new tokens.
        self.extend_len = 64
        self.max_bs = 4
        self.max_seq_len = 128

        # req_to_token_pool
        self.req_to_token_pool = MockReqToTokenPool(
            self.max_bs,
            self.max_seq_len,
            self.device,
        )

        # token_to_kv_pool
        self.token_to_kv_pool = MLATokenToKVPool(
            size=self.max_bs * self.max_seq_len,
            page_size=1,  # only consider page=1 for unit test
            dtype=self.dtype,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            layer_num=1,  # only consider layer=1 for unit test
            device=self.device,
            enable_memory_saver=False,
        )

    def test_prefix_chunk_info(self):
        """Test the standard extend operation."""
        for test_case in TEST_CASES:
            print(
                f"Test case with batch_size={test_case['batch_size']}, prefix_lens={test_case['prefix_lens']}, max_chunk_capacity={test_case['max_chunk_capacity']}"
            )
            batch_size = test_case["batch_size"]
            prefix_lens_cpu = test_case["prefix_lens"]
            assert len(prefix_lens_cpu) == batch_size
            prefix_lens = torch.tensor(prefix_lens_cpu, device=self.device)
            max_chunk_capacity = test_case["max_chunk_capacity"]
            # Total length = cached prefix + newly extended tokens.
            seq_lens_cpu = [
                self.extend_len + prefix_lens_cpu[i] for i in range(batch_size)
            ]
            seq_lens = torch.tensor(seq_lens_cpu, device=self.device)

            # Create forward batch
            # input_ids and out_cache_loc are dummy tensors in this test
            forward_batch = MockForwardBatch(
                max_chunk_capacity=max_chunk_capacity,
                batch_size=batch_size,
                input_ids=torch.randint(
                    0, 100, (batch_size, self.extend_len), device=self.device
                ),
                out_cache_loc=torch.arange(
                    self.max_bs * self.max_seq_len - batch_size * self.extend_len,
                    self.max_bs * self.max_seq_len,
                    device=self.device,
                ),
                seq_lens_sum=sum(seq_lens_cpu),
                forward_mode=ForwardMode.EXTEND,
                req_pool_indices=torch.arange(batch_size, device=self.device),
                seq_lens=seq_lens,
                seq_lens_cpu=seq_lens_cpu,
                extend_prefix_lens=prefix_lens,
                extend_prefix_lens_cpu=prefix_lens_cpu,
            )
            forward_batch.req_to_token_pool = self.req_to_token_pool
            forward_batch.token_to_kv_pool = self.token_to_kv_pool

            forward_batch.prepare_chunked_prefix_cache_info(self.device)
            # Chunk layout must match the hand-computed expectations.
            assert forward_batch.get_max_chunk_capacity() == max_chunk_capacity
            assert forward_batch.prefix_chunk_len == test_case["prefix_chunk_len"]
            assert forward_batch.num_prefix_chunks == test_case["num_prefix_chunks"]
            assert torch.allclose(
                forward_batch.prefix_chunk_starts,
                test_case["prefix_chunk_starts"].to(self.device),
            )
            assert torch.allclose(
                forward_batch.prefix_chunk_seq_lens,
                test_case["prefix_chunk_seq_lens"].to(self.device),
            )
            # Cross-check the triton-generated kv indices against a reference.
            check_kv_indices(forward_batch)
if __name__ == "__main__":
unittest.main()