Support MHA with chunked prefix cache for DeepSeek chunked prefill (#5113)
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel.flash_attn import flash_attn_with_kvcache
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


@dataclass
@@ -593,41 +593,87 @@ class FlashAttentionBackend(AttentionBackend):
                k_descale=k_descale,
                v_descale=v_descale,
            )
            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                layer.head_dim - layer.v_head_dim,
            )
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )
            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]
            o = flash_attn_with_kvcache(
                q=q_rope,
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
                page_table=page_table,
                cache_seqlens=cache_seqlens,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
                causal=True,
                softcap=layer.logit_cap,
                k_descale=k_descale,
                v_descale=v_descale,
            )
            if (
                not global_server_args_dict["disable_chunked_prefix_cache"]
                and forward_batch.attn_attend_prefix_cache is not None
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                # Do multi-head attention with chunked prefix cache

            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
                if forward_batch.attn_attend_prefix_cache:
                    # MHA for chunked prefix kv cache when running model with MLA
                    assert forward_batch.prefix_chunk_idx is not None
                    assert forward_batch.prefix_chunk_cu_seq_lens is not None
                    assert forward_batch.prefix_chunk_max_seq_lens is not None

                    chunk_idx = forward_batch.prefix_chunk_idx
                    assert chunk_idx >= 0

                    output, lse, *rest = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                        max_seqlen_q=metadata.max_seq_len_q,
                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                        softmax_scale=layer.scaling,
                        causal=False,
                        return_softmax_lse=True,
                    )
                else:
                    # MHA for extend part of sequence without attending prefix kv cache
                    output, lse, *rest = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=metadata.cu_seqlens_q,
                        max_seqlen_q=metadata.max_seq_len_q,
                        max_seqlen_k=metadata.max_seq_len_q,
                        softmax_scale=layer.scaling,
                        causal=True,
                        return_softmax_lse=True,
                    )
                return output, lse
            else:
                # Do absorbed multi-latent attention
                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
                k_rope = kv_cache[:, :, layer.v_head_dim :]
                c_kv = kv_cache[:, :, : layer.v_head_dim]
                k_rope_cache = k_rope.view(
                    -1,
                    self.page_size,
                    layer.tp_k_head_num,
                    layer.head_dim - layer.v_head_dim,
                )
                c_kv_cache = c_kv.view(
                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
                )
                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
                q_nope = q_all[:, :, : layer.v_head_dim]
                q_rope = q_all[:, :, layer.v_head_dim :]
                o = flash_attn_with_kvcache(
                    q=q_rope,
                    k_cache=k_rope_cache,
                    v_cache=c_kv_cache,
                    qv=q_nope,
                    page_table=page_table,
                    cache_seqlens=cache_seqlens,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                    max_seqlen_q=max_seqlen_q,
                    softmax_scale=layer.scaling,
                    causal=True,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )

                return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(
        self,
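# Illustrative, self-contained check (plain PyTorch, single head) of why the two
# flash_attn_varlen_func passes above can be combined: causal attention of the new
# "extend" tokens over [prefix + extend] equals the merge of a non-causal pass over
# the prefix keys (causal=False above) and a causal pass over the extend keys only
# (causal=True above), weighted by their log-sum-exp values. This is a sketch for
# reasoning about the technique, not code from this change.
import math

import torch


def _attn(q, k, v, causal):
    s = q @ k.t() / math.sqrt(q.shape[-1])
    if causal:
        s = s.masked_fill(torch.triu(torch.ones_like(s, dtype=torch.bool), 1), float("-inf"))
    return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)


def _merge(o_a, lse_a, o_b, lse_b):
    w = torch.softmax(torch.stack([lse_a, lse_b]), dim=0).unsqueeze(-1)
    return w[0] * o_a + w[1] * o_b


if __name__ == "__main__":
    d, n_prefix, n_new = 16, 5, 3
    q, k_new, v_new = (torch.randn(n_new, d) for _ in range(3))
    k_prefix, v_prefix = torch.randn(n_prefix, d), torch.randn(n_prefix, d)

    # Reference: one causal pass in which every new token sees the whole prefix.
    s = q @ torch.cat([k_prefix, k_new]).t() / math.sqrt(d)
    mask = torch.zeros(n_new, n_prefix + n_new, dtype=torch.bool)
    mask[:, n_prefix:] = torch.triu(torch.ones(n_new, n_new, dtype=torch.bool), 1)
    ref = torch.softmax(s.masked_fill(mask, float("-inf")), dim=-1) @ torch.cat([v_prefix, v_new])

    # Chunked path: non-causal over the prefix, causal over the new tokens, then merge.
    o_p, lse_p = _attn(q, k_prefix, v_prefix, causal=False)
    o_n, lse_n = _attn(q, k_new, v_new, causal=True)
    assert torch.allclose(_merge(o_p, lse_p, o_n, lse_n), ref, atol=1e-5)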
@@ -83,6 +83,7 @@ global_server_args_dict = {
    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
}

logger = logging.getLogger(__name__)
@@ -181,6 +181,28 @@ class ForwardBatch:
    extend_logprob_start_lens_cpu: Optional[List[int]] = None
    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

    # For MLA chunked prefix cache used in chunked prefill
    # Tells the attention backend whether the prefix kv cache needs to be attended to in the current pass
    attn_attend_prefix_cache: Optional[bool] = None
    # Number of prefix cache chunks
    num_prefix_chunks: Optional[int] = None
    # Index of current chunk, used by attention backend
    prefix_chunk_idx: Optional[int] = None
    # Maximum number of tokens in each chunk per sequence. Computed from maximum chunk capacity
    prefix_chunk_len: Optional[int] = None
    # Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size)
    prefix_chunk_starts: Optional[torch.Tensor] = None
    # Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size)
    prefix_chunk_seq_lens: Optional[torch.Tensor] = None
    # Accumulated lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size + 1)
    prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None
    # Max lengths of prefix cache for each chunk, (num_prefix_chunks,)
    prefix_chunk_max_seq_lens: Optional[List[int]] = None
    # Number of tokens in each prefix cache chunk, (num_prefix_chunks,)
    prefix_chunk_num_tokens: Optional[List[int]] = None
    # KV indices for each chunk
    prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None

    # For multimodal
    mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -484,6 +506,128 @@ class ForwardBatch:
        )
        self.mrope_positions = self.mrope_positions.to(torch.int64)

    def get_max_chunk_capacity(self):
        # Maximum number of tokens in each chunk
        # TODO: Should be changed to a better value, maybe passed through server args
        return 128 * 1024

    def set_prefix_chunk_idx(self, idx: int):
        self.prefix_chunk_idx = idx

    def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool):
        self.attn_attend_prefix_cache = attn_attend_prefix_cache

    def prepare_chunked_kv_indices(self, device: torch.device):
        self.prefix_chunk_kv_indices = []
        for idx in range(self.num_prefix_chunks):
            chunk_starts = self.prefix_chunk_starts[idx]
            chunk_seq_lens = self.prefix_chunk_seq_lens[idx]
            chunk_cu_seq_lens = self.prefix_chunk_cu_seq_lens[idx]
            num_chunk_tokens = self.prefix_chunk_num_tokens[idx]

            chunk_kv_indices = torch.empty(
                num_chunk_tokens, dtype=torch.int32, device=device
            )

            create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
                chunk_starts,
                chunk_seq_lens,
                chunk_cu_seq_lens,
                chunk_kv_indices,
                self.req_to_token_pool.req_to_token.shape[1],
            )
            self.prefix_chunk_kv_indices.append(chunk_kv_indices)
    # Here we assume all chunks have the same length
    # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
    # num_prefix_chunks = cdiv(1024, 256) = 4
    # prefix_chunk_starts = [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]]
    # prefix_chunk_ends = [[256, 256, 256, 256], [256, 512, 512, 512], [256, 512, 768, 768], [256, 512, 768, 1024]]
    # prefix_chunk_seq_lens = [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]
    # TODO: Implement a better way to allocate chunk lengths that uses memory space more efficiently.
    def get_prefix_chunk_seq_lens(
        self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int
    ):
        device = prefix_lens.device
        prefix_chunk_starts = (
            torch.arange(num_prefix_chunks, device=device, dtype=torch.int32)
            .unsqueeze(1)
            .expand(-1, self.batch_size)
            * prefix_chunk_len
        )
        prefix_chunk_ends = torch.min(
            prefix_lens.unsqueeze(0),
            prefix_chunk_starts + prefix_chunk_len,
        ).to(torch.int32)

        prefix_chunk_seq_lens = (
            (prefix_chunk_ends - prefix_chunk_starts).clamp(min=0).to(torch.int32)
        )

        return prefix_chunk_starts, prefix_chunk_seq_lens
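    # How the worked example above maps onto this function (hand-derived, for
    # quick reference; it matches what the code returns for those inputs):
    #   starts[i][b]   = i * prefix_chunk_len                 (same for every sequence)
    #   ends[i][b]     = min(prefix_lens[b], starts[i][b] + prefix_chunk_len)
    #   seq_lens[i][b] = max(ends[i][b] - starts[i][b], 0)
    #   e.g. chunk 1: ends = [256, 512, 512, 512] -> seq_lens = [0, 256, 256, 256]
    #        chunk 3: ends = [256, 512, 768, 1024] -> seq_lens = [0, 0, 0, 256]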
    # Called before each attention module when using the chunked kv cache for prefill
    # Some of this code is adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
    def prepare_chunked_prefix_cache_info(self, device: torch.device):

        from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool

        assert isinstance(
            self.token_to_kv_pool, MLATokenToKVPool
        ), "Currently chunked prefix cache can only be used by Deepseek models"

        if self.prefix_chunk_len is not None:
            # Chunked kv cache info already prepared by prior modules
            return

        self.prefix_chunk_idx = -1

        # chunk_capacity is the maximum number of tokens in each chunk
        chunk_capacity = self.get_max_chunk_capacity()
        self.prefix_chunk_len = chunk_capacity // self.batch_size

        self.num_prefix_chunks = (
            max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1
        ) // self.prefix_chunk_len

        # Here we compute the chunk lens twice, once on GPU and once on CPU, to avoid a stream sync.
        prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = (
            self.get_prefix_chunk_seq_lens(
                self.extend_prefix_lens,
                self.num_prefix_chunks,
                self.prefix_chunk_len,
            )
        )
        _, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens(
            torch.tensor(self.extend_prefix_lens_cpu),
            self.num_prefix_chunks,
            self.prefix_chunk_len,
        )
        self.prefix_chunk_starts = prefix_chunk_starts_cuda
        self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda

        # Metadata for attention backend
        self.prefix_chunk_cu_seq_lens = torch.zeros(
            self.num_prefix_chunks,
            self.batch_size + 1,
            device=device,
            dtype=torch.int32,
        )
        self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum(
            dim=1
        ).to(torch.int32)
        self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max(
            dim=1
        ).values.tolist()

        self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist()
        assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity()

        # Precompute the kv indices for each chunk
        self.prepare_chunked_kv_indices(device)

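    # Sketch of the intended call sequence for this metadata (it mirrors how the
    # DeepSeek attention module later in this diff drives it; not a new API):
    #
    #   if any(forward_batch.extend_prefix_lens_cpu):
    #       if forward_batch.num_prefix_chunks is None:
    #           forward_batch.prepare_chunked_prefix_cache_info(device)
    #       forward_batch.set_attn_attend_prefix_cache(False)  # extend-only pass
    #       ...
    #       forward_batch.set_attn_attend_prefix_cache(True)
    #       for i in range(forward_batch.num_prefix_chunks):
    #           forward_batch.set_prefix_chunk_idx(i)
    #           ...  # attend the i-th prefix chunk via prefix_chunk_kv_indices[i]
    #           ...  # and merge each chunk's (output, lse) into the running result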
def compute_position_triton(
    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
@@ -561,3 +705,40 @@ def compute_position_torch(
@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)


@triton.jit
def create_chunked_prefix_cache_kv_indices(
    req_to_token_ptr,  # (max_batch, max_context_len,)
    req_pool_indices_ptr,  # (batch_size,)
    chunk_start_idx_ptr,  # (batch_size,)
    chunk_seq_lens_ptr,  # (batch_size,)
    chunk_cu_seq_lens_ptr,  # (batch_size + 1,)
    chunk_kv_indices_ptr,  # (num_chunk_tokens,)
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    # Find the req pool index, which maps this batch entry to its row in req_to_token
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)

    # Get the token positions of the current chunk
    chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
    chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)

    num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < chunk_seq_len
        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + chunk_start_pos
            + offset,
            mask=mask,
        )
        tl.store(
            chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
        )

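# Reference (non-Triton) equivalent of the kernel above, useful for reasoning
# about what it writes: for each request in the batch, copy the slice of
# req_to_token covering [chunk_start, chunk_start + chunk_seq_len) into the flat
# chunk_kv_indices buffer at that request's cumulative offset. This is an
# illustrative sketch, not code from this change.
def create_chunked_prefix_cache_kv_indices_torch(
    req_to_token,  # (max_batch, max_context_len)
    req_pool_indices,  # (batch_size,)
    chunk_starts,  # (batch_size,)
    chunk_seq_lens,  # (batch_size,)
    chunk_cu_seq_lens,  # (batch_size + 1,)
    chunk_kv_indices,  # (num_chunk_tokens,), output buffer
):
    for pid in range(req_pool_indices.numel()):
        req = req_pool_indices[pid].item()
        start = chunk_starts[pid].item()
        length = chunk_seq_lens[pid].item()
        out_offset = chunk_cu_seq_lens[pid].item()
        chunk_kv_indices[out_offset : out_offset + length] = req_to_token[
            req, start : start + length
        ]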
@@ -167,6 +167,7 @@ class ModelRunner:
                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                "n_share_experts_fusion": server_args.n_share_experts_fusion,
                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                "use_mla_backend": self.use_mla_backend,
            }
        )
@@ -318,6 +319,16 @@ class ModelRunner:
        if server_args.enable_deepep_moe:
            logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")

        if not self.use_mla_backend:
            logger.info("Disable chunked prefix cache for non-MLA backend.")
            server_args.disable_chunked_prefix_cache = True
        elif self.page_size > 1:
            logger.info("Disable chunked prefix cache when page size > 1.")
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

@@ -18,6 +18,7 @@

import logging
import os
from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
@@ -78,7 +79,7 @@ _is_hip = is_hip()
_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import awq_dequantize, bmm_fp8
    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2

    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
else:
@@ -94,6 +95,19 @@ expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)


class AttnForwardMethod(IntEnum):

    # Use multi-head attention
    MHA = auto()

    # Use absorbed multi-latent attention
    MLA = auto()

    # Use multi-head attention, but with KV cache chunked.
    # This method can avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = auto()


class DeepseekV2MLP(nn.Module):
    def __init__(
        self,
@@ -694,30 +708,54 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.disable_chunked_prefix_cache = global_server_args_dict[
            "disable_chunked_prefix_cache"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
        # TODO: Design a finer way to determine the threshold
        self.chunked_prefix_cache_threshold = 8192

    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        if self.attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            return (
            if (
                not self.flashinfer_mla_disable_ragged
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            )
            ):
                return AttnForwardMethod.MHA
            else:
                return AttnForwardMethod.MLA
        elif self.attention_backend == "fa3":
            # Flash Attention: Keep absorbing for all extend/decode
            return False
            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
            if (
                forward_batch.forward_mode.is_extend()
                and not self.disable_chunked_prefix_cache
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu)
                >= self.chunked_prefix_cache_threshold
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            else:
                return AttnForwardMethod.MLA
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            return (
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            )
            ):
                return AttnForwardMethod.MHA
            else:
                return AttnForwardMethod.MLA

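    # Summary of the dispatch above (paraphrased for quick reference):
    #   flashinfer: ragged extend with no cached prefix                  -> MHA, else MLA
    #   fa3:        extend with >= chunked_prefix_cache_threshold (8192)
    #               cached prefix tokens                                 -> MHA_CHUNKED_KV, else MLA
    #   others:     extend with no cached prefix                         -> MHA, else MLA
    # Target-verify and draft-extend modes always stay on the MLA path.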
    def forward(
        self,
@@ -731,8 +769,14 @@ class DeepseekV2AttentionMLA(nn.Module):
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        if self.no_absorb(forward_batch):
        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

        if attn_forward_method == AttnForwardMethod.MHA:
            return self.forward_normal(positions, hidden_states, forward_batch)
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            return self.forward_normal_chunked_kv(
                positions, hidden_states, forward_batch
            )
        else:
            if _is_hip:
                if (
@@ -1007,6 +1051,127 @@ class DeepseekV2AttentionMLA(nn.Module):

        return output

    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:

        assert forward_batch.num_prefix_chunks is not None
        for i in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(i)

            # Fetch latent cache from memory pool with precomputed chunked kv indices
            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                self.attn_mha.layer_id
            )
            latent_cache = latent_cache_buf[
                forward_batch.prefix_chunk_kv_indices[i]
            ].contiguous()

            kv_a_normed, k_pe = latent_cache.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
            kv = self.kv_b_proj(kv_a_normed)[0]
            kv = kv.view(
                -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            v = kv[..., self.qk_nope_head_dim :]
            k_nope = kv[..., : self.qk_nope_head_dim]

            k = torch.empty(
                (
                    k_nope.shape[0],
                    self.num_local_heads,
                    self.qk_nope_head_dim + self.qk_rope_head_dim,
                ),
                dtype=v.dtype,
                device=v.device,
            )
            k[..., : self.qk_nope_head_dim] = k_nope
            k[..., self.qk_nope_head_dim :] = k_pe

            output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
            lse = torch.transpose(lse, 0, 1).contiguous()
            tmp_output = torch.empty_like(accum_output)
            tmp_lse = torch.empty_like(accum_lse)
            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
            accum_output, accum_lse = tmp_output, tmp_lse

        return accum_output

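    # For reference, a plain-PyTorch sketch of the merge performed by merge_state_v2
    # above (assumed semantics; sgl_kernel implements this as a fused kernel).  Given
    # two partial attention results over disjoint key sets, each expressed as
    # (output, lse) with lse = log-sum-exp of the attention scores, the combined
    # result is a convex combination weighted by exp(lse):
    #
    #   def merge_state_torch(o_a, lse_a, o_b, lse_b):
    #       # o_*: (num_tokens, num_heads, head_dim), lse_*: (num_tokens, num_heads)
    #       lse_max = torch.maximum(lse_a, lse_b)
    #       w_a = torch.exp(lse_a - lse_max)
    #       w_b = torch.exp(lse_b - lse_max)
    #       out = (o_a * w_a.unsqueeze(-1) + o_b * w_b.unsqueeze(-1)) / (w_a + w_b).unsqueeze(-1)
    #       lse = lse_max + torch.log(w_a + w_b)
    #       return out, lse
    #
    # Folding each chunk's (output, lse) into (accum_output, accum_lse) this way is
    # what makes the per-chunk passes equivalent to attending over the whole prefix.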
    def forward_normal_chunked_kv(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        # In normal MHA, the k and v tensors become overly large when the prefix length is long.
        # To avoid this, we split the kv cache into chunks and process them one after another.
        # Since MHA is compute-friendly, the for loop introduced here will not add significant overhead.
        # The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
        # will be helpful for understanding the purpose of this function.

        # First do a normal MHA forward pass to get the output for the extended part
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        # Do MHA for the extended part without the prefix
        forward_batch.set_attn_attend_prefix_cache(False)
        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        lse = torch.transpose(lse, 0, 1).contiguous()

        # Do MHA with the chunked prefix cache if any sequence has a prefix
        if any(forward_batch.extend_prefix_lens_cpu):
            # Only initialize the info once
            if forward_batch.num_prefix_chunks is None:
                forward_batch.prepare_chunked_prefix_cache_info(q.device)

            forward_batch.set_attn_attend_prefix_cache(True)
            attn_output = self._chunked_prefix_attn_mha(
                q=q,
                accum_output=attn_output,
                accum_lse=lse,
                forward_batch=forward_batch,
            )

        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekV2DecoderLayer(nn.Module):

@@ -186,6 +186,7 @@ class ServerArgs:
    warmups: Optional[str] = None
    n_share_experts_fusion: int = 0
    disable_shared_experts_fusion: bool = False
    disable_chunked_prefix_cache: bool = False

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
@@ -1130,6 +1131,11 @@ class ServerArgs:
            action="store_true",
            help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
        )
        parser.add_argument(
            "--disable-chunked-prefix-cache",
            action="store_true",
            help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
        )

        # Server warmups
        parser.add_argument(
python/sglang/test/attention/test_prefix_chunk_info.py (new file, 224 lines)
@@ -0,0 +1,224 @@
import unittest

import torch

from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.test.test_utils import CustomTestCase

TEST_CASES = [
    # Sequences with the same prefix lens
    {
        "batch_size": 3,
        "prefix_lens": [64, 64, 64],
        "max_chunk_capacity": 48,
        "prefix_chunk_len": 16,
        "num_prefix_chunks": 4,
        "prefix_chunk_starts": torch.tensor(
            [
                [0, 0, 0],
                [16, 16, 16],
                [32, 32, 32],
                [48, 48, 48],
            ],
            dtype=torch.int32,
        ),
        "prefix_chunk_seq_lens": torch.tensor(
            [
                [16, 16, 16],
                [16, 16, 16],
                [16, 16, 16],
                [16, 16, 16],
            ],
            dtype=torch.int32,
        ),
    },
    # Sequences with different prefix lens
    {
        "batch_size": 4,
        "prefix_lens": [16, 32, 48, 64],
        "max_chunk_capacity": 64,
        "prefix_chunk_len": 16,
        "num_prefix_chunks": 4,
        "prefix_chunk_starts": torch.tensor(
            [
                [0, 0, 0, 0],
                [16, 16, 16, 16],
                [32, 32, 32, 32],
                [48, 48, 48, 48],
            ],
            dtype=torch.int32,
        ),
        "prefix_chunk_seq_lens": torch.tensor(
            [
                [16, 16, 16, 16],
                [0, 16, 16, 16],
                [0, 0, 16, 16],
                [0, 0, 0, 16],
            ],
            dtype=torch.int32,
        ),
    },
    # Sequences with irregular shapes
    {
        "batch_size": 2,
        "prefix_lens": [1, 64],
        "max_chunk_capacity": 31,
        "prefix_chunk_len": 15,
        "num_prefix_chunks": 5,
        "prefix_chunk_starts": torch.tensor(
            [
                [0, 0],
                [15, 15],
                [30, 30],
                [45, 45],
                [60, 60],
            ],
            dtype=torch.int32,
        ),
        "prefix_chunk_seq_lens": torch.tensor(
            [
                [1, 15],
                [0, 15],
                [0, 15],
                [0, 15],
                [0, 4],
            ],
            dtype=torch.int32,
        ),
    },
]

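# How the expected values above follow from prepare_chunked_prefix_cache_info
# (worked out by hand for reference): in the first case, max_chunk_capacity=48 and
# batch_size=3 give prefix_chunk_len = 48 // 3 = 16, and with a maximum prefix of 64
# num_prefix_chunks = ceil(64 / 16) = 4.  The third case exercises uneven division:
# 31 // 2 = 15, ceil(64 / 15) = 5, and the last chunk of the second sequence holds
# only 64 - 4 * 15 = 4 tokens.
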
class MockForwardBatch(ForwardBatch):
    def __init__(self, max_chunk_capacity: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_chunk_capacity = max_chunk_capacity

    def get_max_chunk_capacity(self):
        return self.max_chunk_capacity


class MockReqToTokenPool:
    def __init__(self, batch_size, seq_len, device):
        self.req_to_token = (
            torch.arange(batch_size * seq_len, device=device)
            .reshape(batch_size, seq_len)
            .to(torch.int32)
        )


# Test correctness of triton kernel for computing kv indices
def check_kv_indices(forward_batch):
    for i in range(forward_batch.num_prefix_chunks):
        computed_kv_indices = forward_batch.prefix_chunk_kv_indices[i]
        req_to_token = forward_batch.req_to_token_pool.req_to_token[
            : forward_batch.batch_size, :
        ]
        ref_kv_indices = torch.empty(
            forward_batch.prefix_chunk_num_tokens[i],
            dtype=torch.int32,
            device=computed_kv_indices.device,
        )
        running_ptr = 0
        for j in range(forward_batch.batch_size):
            seq_start = forward_batch.prefix_chunk_starts[i, j].item()
            seq_len = forward_batch.prefix_chunk_seq_lens[i, j].item()
            ref_kv_indices[running_ptr : running_ptr + seq_len].copy_(
                req_to_token[j, seq_start : seq_start + seq_len]
            )
            running_ptr += seq_len
        assert torch.allclose(computed_kv_indices, ref_kv_indices)


@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
class TestPrefixChunkInfo(CustomTestCase):
    def setUp(self):
        # Common test parameters
        self.num_local_heads = 128
        self.kv_lora_rank = 512
        self.qk_rope_head_dim = 64
        self.device = torch.device("cuda")
        self.dtype = torch.bfloat16
        self.extend_len = 64
        self.max_bs = 4
        self.max_seq_len = 128

        # req_to_token_pool
        self.req_to_token_pool = MockReqToTokenPool(
            self.max_bs,
            self.max_seq_len,
            self.device,
        )

        # token_to_kv_pool
        self.token_to_kv_pool = MLATokenToKVPool(
            size=self.max_bs * self.max_seq_len,
            page_size=1,  # only consider page=1 for unit test
            dtype=self.dtype,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            layer_num=1,  # only consider layer=1 for unit test
            device=self.device,
            enable_memory_saver=False,
        )

    def test_prefix_chunk_info(self):
        """Test the standard extend operation."""

        for test_case in TEST_CASES:
            print(
                f"Test case with batch_size={test_case['batch_size']}, prefix_lens={test_case['prefix_lens']}, max_chunk_capacity={test_case['max_chunk_capacity']}"
            )
            batch_size = test_case["batch_size"]
            prefix_lens_cpu = test_case["prefix_lens"]
            assert len(prefix_lens_cpu) == batch_size
            prefix_lens = torch.tensor(prefix_lens_cpu, device=self.device)
            max_chunk_capacity = test_case["max_chunk_capacity"]
            seq_lens_cpu = [
                self.extend_len + prefix_lens_cpu[i] for i in range(batch_size)
            ]
            seq_lens = torch.tensor(seq_lens_cpu, device=self.device)

            # Create forward batch
            # input_ids and out_cache_loc are dummy tensors in this test
            forward_batch = MockForwardBatch(
                max_chunk_capacity=max_chunk_capacity,
                batch_size=batch_size,
                input_ids=torch.randint(
                    0, 100, (batch_size, self.extend_len), device=self.device
                ),
                out_cache_loc=torch.arange(
                    self.max_bs * self.max_seq_len - batch_size * self.extend_len,
                    self.max_bs * self.max_seq_len,
                    device=self.device,
                ),
                seq_lens_sum=sum(seq_lens_cpu),
                forward_mode=ForwardMode.EXTEND,
                req_pool_indices=torch.arange(batch_size, device=self.device),
                seq_lens=seq_lens,
                seq_lens_cpu=seq_lens_cpu,
                extend_prefix_lens=prefix_lens,
                extend_prefix_lens_cpu=prefix_lens_cpu,
            )
            forward_batch.req_to_token_pool = self.req_to_token_pool
            forward_batch.token_to_kv_pool = self.token_to_kv_pool

            forward_batch.prepare_chunked_prefix_cache_info(self.device)
            assert forward_batch.get_max_chunk_capacity() == max_chunk_capacity
            assert forward_batch.prefix_chunk_len == test_case["prefix_chunk_len"]
            assert forward_batch.num_prefix_chunks == test_case["num_prefix_chunks"]
            assert torch.allclose(
                forward_batch.prefix_chunk_starts,
                test_case["prefix_chunk_starts"].to(self.device),
            )
            assert torch.allclose(
                forward_batch.prefix_chunk_seq_lens,
                test_case["prefix_chunk_seq_lens"].to(self.device),
            )

            check_kv_indices(forward_batch)


if __name__ == "__main__":
    unittest.main()