Support MHA with chunked prefix cache for DeepSeek chunked prefill (#5113)

This commit is contained in:
Baizhou Zhang
2025-04-15 22:01:22 -07:00
committed by GitHub
parent dd83e7e9c3
commit a42736bbb8
10 changed files with 734 additions and 46 deletions

View File

@@ -181,6 +181,28 @@ class ForwardBatch:
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
# For MLA chunked prefix cache used in chunked prefill
# Tell the attention backend whether the prefix kv cache needs to be attended to in the current pass
attn_attend_prefix_cache: Optional[bool] = None
# Number of prefix cache chunks
num_prefix_chunks: Optional[int] = None
# Index of current chunk, used by attention backend
prefix_chunk_idx: Optional[int] = None
# Maximum number of prefix tokens per sequence in each chunk, computed from the maximum chunk capacity
prefix_chunk_len: Optional[int] = None
# Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size)
prefix_chunk_starts: Optional[torch.Tensor] = None
# Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size)
prefix_chunk_seq_lens: Optional[torch.Tensor] = None
# Accumulated lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size + 1)
prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None
# Max lengths of prefix cache for each chunk, (num_prefix_chunks,)
prefix_chunk_max_seq_lens: Optional[List[int]] = None
# Number of tokens in each prefix cache chunk, (num_prefix_chunks,)
prefix_chunk_num_tokens: Optional[List[int]] = None
# KV Indices for each chunk
prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
# For multimodal
mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -484,6 +506,128 @@ class ForwardBatch:
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
def get_max_chunk_capacity(self):
# Maximum number of tokens in each chunk
# TODO: Should be changed to a better value, maybe passed through server args
return 128 * 1024
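To make the sizing concrete, here is a small standalone sketch of how this fixed capacity translates into the per-sequence chunk length and chunk count computed in prepare_chunked_prefix_cache_info below (the batch values are made up for illustration):

# Illustrative only: how the 128 * 1024 token capacity is split across a batch.
max_chunk_capacity = 128 * 1024                             # tokens per chunk, whole batch
batch_size = 4
max_prefix_len = 100_000                                    # longest cached prefix (hypothetical)

prefix_chunk_len = max_chunk_capacity // batch_size         # 32768 tokens per sequence
num_prefix_chunks = -(-max_prefix_len // prefix_chunk_len)  # ceil division -> 4 chunks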
def set_prefix_chunk_idx(self, idx: int):
self.prefix_chunk_idx = idx
def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool):
self.attn_attend_prefix_cache = attn_attend_prefix_cache
def prepare_chunked_kv_indices(self, device: torch.device):
self.prefix_chunk_kv_indices = []
for idx in range(self.num_prefix_chunks):
chunk_starts = self.prefix_chunk_starts[idx]
chunk_seq_lens = self.prefix_chunk_seq_lens[idx]
chunk_cu_seq_lens = self.prefix_chunk_cu_seq_lens[idx]
num_chunk_tokens = self.prefix_chunk_num_tokens[idx]
chunk_kv_indices = torch.empty(
num_chunk_tokens, dtype=torch.int32, device=device
)
create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
self.req_to_token_pool.req_to_token,
self.req_pool_indices,
chunk_starts,
chunk_seq_lens,
chunk_cu_seq_lens,
chunk_kv_indices,
self.req_to_token_pool.req_to_token.shape[1],
)
self.prefix_chunk_kv_indices.append(chunk_kv_indices)
# Here we use the same fixed length (prefix_chunk_len) for every chunk
# For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
# num_prefix_chunks = cdiv(1024, 256) = 4
# prefix_chunk_starts = [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]]
# prefix_chunk_ends = [[256, 256, 256, 256], [256, 512, 512, 512], [256, 512, 768, 768], [256, 512, 768, 1024]]
# prefix_chunk_seq_lens = [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]
# TODO: Implement a better way to allocate chunk lengths that uses memory space more efficiently.
def get_prefix_chunk_seq_lens(
self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int
):
device = prefix_lens.device
prefix_chunk_starts = (
torch.arange(num_prefix_chunks, device=device, dtype=torch.int32)
.unsqueeze(1)
.expand(-1, self.batch_size)
* prefix_chunk_len
)
prefix_chunk_ends = torch.min(
prefix_lens.unsqueeze(0),
prefix_chunk_starts + prefix_chunk_len,
).to(torch.int32)
prefix_chunk_seq_lens = (
(prefix_chunk_ends - prefix_chunk_starts).clamp(min=0).to(torch.int32)
)
return prefix_chunk_starts, prefix_chunk_seq_lens
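As a sanity check, the commented example above can be reproduced with a few lines of standalone PyTorch (built directly from tensors rather than a ForwardBatch):

import torch

# Four sequences with cached prefix lengths [256, 512, 768, 1024], chunk length 256.
prefix_lens = torch.tensor([256, 512, 768, 1024], dtype=torch.int32)
batch_size = prefix_lens.numel()
prefix_chunk_len = 256
num_prefix_chunks = 4  # cdiv(1024, 256)

starts = (
    torch.arange(num_prefix_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, batch_size)
    * prefix_chunk_len
)
ends = torch.min(prefix_lens.unsqueeze(0), starts + prefix_chunk_len)
seq_lens = (ends - starts).clamp(min=0)
# starts[1]   -> tensor([256, 256, 256, 256])
# seq_lens[2] -> tensor([0, 0, 256, 256])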
# Called before each attention module when the chunked kv cache is used for prefill
# Some of this code is adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
def prepare_chunked_prefix_cache_info(self, device: torch.device):
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
assert isinstance(
self.token_to_kv_pool, MLATokenToKVPool
), "Currently chunked prefix cache can only be used by Deepseek models"
if self.prefix_chunk_len is not None:
# Chunked kv cache info already prepared by prior modules
return
self.prefix_chunk_idx = -1
# chunk_capacity is the maximum number of tokens in each chunk
chunk_capacity = self.get_max_chunk_capacity()
self.prefix_chunk_len = chunk_capacity // self.batch_size
self.num_prefix_chunks = (
max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1
) // self.prefix_chunk_len
# Here we compute the chunk lengths twice to avoid a stream sync: once on GPU and once on CPU.
prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = (
self.get_prefix_chunk_seq_lens(
self.extend_prefix_lens,
self.num_prefix_chunks,
self.prefix_chunk_len,
)
)
_, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens(
torch.tensor(self.extend_prefix_lens_cpu),
self.num_prefix_chunks,
self.prefix_chunk_len,
)
self.prefix_chunk_starts = prefix_chunk_starts_cuda
self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda
# Metadata for attention backend
self.prefix_chunk_cu_seq_lens = torch.zeros(
self.num_prefix_chunks,
self.batch_size + 1,
device=device,
dtype=torch.int32,
)
self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum(
dim=1
).to(torch.int32)
self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max(
dim=1
).values.tolist()
self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist()
assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity()
# Precompute the kv indices for each chunk
self.prepare_chunked_kv_indices(device)
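For orientation, the sketch below shows how an attention backend could consume this metadata during prefill; attend_to_prefix_chunk, attend_to_extend, and merge_attention_states are hypothetical stand-ins for the backend's real MHA kernels and softmax-merge step, not functions from this PR.

# Hypothetical consumer of the prepared chunk metadata (not the actual backend code).
def run_chunked_prefix_attention(forward_batch, q):
    forward_batch.prepare_chunked_prefix_cache_info(q.device)

    partial_outputs = []
    # Attend to each chunk of the cached prefix, one pass per chunk.
    forward_batch.set_attn_attend_prefix_cache(True)
    for idx in range(forward_batch.num_prefix_chunks):
        forward_batch.set_prefix_chunk_idx(idx)
        kv_indices = forward_batch.prefix_chunk_kv_indices[idx]
        cu_seq_lens = forward_batch.prefix_chunk_cu_seq_lens[idx]
        max_seq_len = forward_batch.prefix_chunk_max_seq_lens[idx]
        partial_outputs.append(
            attend_to_prefix_chunk(q, kv_indices, cu_seq_lens, max_seq_len)
        )

    # Final causal pass over the newly extended tokens.
    forward_batch.set_attn_attend_prefix_cache(False)
    partial_outputs.append(attend_to_extend(q, forward_batch))
    return merge_attention_states(partial_outputs)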
def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
@@ -561,3 +705,40 @@ def compute_position_torch(
@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
@triton.jit
def create_chunked_prefix_cache_kv_indices(
req_to_token_ptr, # (max_batch, max_context_len,)
req_pool_indices_ptr, # (batch_size,)
chunk_start_idx_ptr, # (batch_size,)
chunk_seq_lens_ptr, # (batch_size,)
chunk_cu_seq_lens_ptr, # (batch_size + 1,)
chunk_kv_indices_ptr, # (num_chunk_tokens,)
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
# Find the request pool index, which maps this batch entry to its row in req_to_token
req_pool_index = tl.load(req_pool_indices_ptr + pid)
chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)
# Get the token positions of the current chunk
chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)
num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < chunk_seq_len
data = tl.load(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ chunk_start_pos
+ offset,
mask=mask,
)
tl.store(
chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
)
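For reference, a plain PyTorch equivalent of the gather performed by this kernel (a readability aid, not a drop-in replacement; shapes follow the argument comments above):

# Non-Triton reference: for each request, copy the token indices of its current
# prefix chunk from req_to_token into the flat chunk_kv_indices output buffer.
def create_chunked_prefix_cache_kv_indices_torch(
    req_to_token,       # (max_batch, max_context_len)
    req_pool_indices,   # (batch_size,)
    chunk_starts,       # (batch_size,)
    chunk_seq_lens,     # (batch_size,)
    chunk_cu_seq_lens,  # (batch_size + 1,)
    chunk_kv_indices,   # (num_chunk_tokens,) output buffer
):
    for pid in range(req_pool_indices.numel()):
        row = req_to_token[req_pool_indices[pid]]
        start = int(chunk_starts[pid])
        length = int(chunk_seq_lens[pid])
        out = int(chunk_cu_seq_lens[pid])
        chunk_kv_indices[out : out + length] = row[start : start + length]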

View File

@@ -167,6 +167,7 @@ class ModelRunner:
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
"use_mla_backend": self.use_mla_backend,
}
)
@@ -318,6 +319,16 @@ class ModelRunner:
if server_args.enable_deepep_moe:
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
if not self.use_mla_backend:
logger.info("Disable chunked prefix cache for non-MLA backend.")
server_args.disable_chunked_prefix_cache = True
elif self.page_size > 1:
logger.info("Disable chunked prefix cache when page size > 1.")
server_args.disable_chunked_prefix_cache = True
if not server_args.disable_chunked_prefix_cache:
logger.info("Chunked prefix cache is turned on.")
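Restated as a standalone truth table, the gating above reduces to the following (a sketch mirroring the checks in this hunk, not code from the PR):

# Chunked prefix cache is effectively enabled only for MLA backends with
# page_size == 1 and no explicit disable flag.
def chunked_prefix_cache_enabled(use_mla_backend: bool, page_size: int, disable_flag: bool) -> bool:
    if not use_mla_backend or page_size > 1:
        return False
    return not disable_flag

assert chunked_prefix_cache_enabled(True, 1, False)
assert not chunked_prefix_cache_enabled(True, 2, False)
assert not chunked_prefix_cache_enabled(False, 1, False)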
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")