Support MHA with chunked prefix cache for DeepSeek chunked prefill (#5113)
This commit is contained in:
@@ -181,6 +181,28 @@ class ForwardBatch:
|
||||
# Per-request start offsets for logprob computation in the extend region (host copy).
extend_logprob_start_lens_cpu: Optional[List[int]] = None
# Token ids used for input-logprob computation, resident on GPU.
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

# For MLA chunked prefix cache used in chunked prefill.
# The cached prefix is processed in fixed-size chunks; the fields below are
# filled lazily by prepare_chunked_prefix_cache_info() and consumed by the
# attention backend.
# Tell attention backend whether the kv cache needs to be attended in current pass
attn_attend_prefix_cache: Optional[bool] = None
# Number of prefix cache chunks
num_prefix_chunks: Optional[int] = None
# Index of current chunk, used by attention backend
prefix_chunk_idx: Optional[int] = None
# Maximum number of tokens in each chunk per sequence. Computed from maximum chunk capacity
prefix_chunk_len: Optional[int] = None
# Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size)
prefix_chunk_starts: Optional[torch.Tensor] = None
# Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size)
prefix_chunk_seq_lens: Optional[torch.Tensor] = None
# Accumulated (cumulative, leading zero) lengths of prefix cache for each chunk,
# (num_prefix_chunks, batch_size + 1)
prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None
# Max lengths of prefix cache for each chunk, (num_prefix_chunks,)
prefix_chunk_max_seq_lens: Optional[List[int]] = None
# Number of tokens in each prefix cache chunk, (num_prefix_chunks,)
prefix_chunk_num_tokens: Optional[List[int]] = None
# KV indices for each chunk, one flat int32 tensor per chunk
prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None

# For multimodal
mm_inputs: Optional[List[MultimodalInputs]] = None
|
||||
|
||||
@@ -484,6 +506,128 @@ class ForwardBatch:
|
||||
)
|
||||
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
||||
|
||||
def get_max_chunk_capacity(self):
    """Return the maximum total number of tokens allowed in one prefix chunk.

    This caps the temporary memory used when attending to the cached prefix
    chunk-by-chunk during chunked prefill.

    TODO: make this configurable (e.g. via server args) instead of hard-coded.
    """
    # 128K tokens per chunk (128 * 1024).
    return 1 << 17
|
||||
|
||||
def set_prefix_chunk_idx(self, idx: int) -> None:
    """Record which prefix chunk the attention backend is currently processing."""
    self.prefix_chunk_idx = idx
|
||||
|
||||
def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool) -> None:
    """Tell the attention backend whether the prefix KV cache needs to be
    attended in the current pass."""
    self.attn_attend_prefix_cache = attn_attend_prefix_cache
|
||||
|
||||
def prepare_chunked_kv_indices(self, device: torch.device):
    """Precompute, for every prefix chunk, the flat KV-cache indices of its tokens.

    Launches the gather kernel once per chunk with one program per request;
    results are stored in self.prefix_chunk_kv_indices as a list of int32
    tensors, one per chunk.
    """
    req_to_token = self.req_to_token_pool.req_to_token
    kv_indices_per_chunk = []

    # Iterate the per-chunk rows of the chunk metadata in lockstep.
    for starts_row, seq_lens_row, cu_seq_lens_row, num_tokens in zip(
        self.prefix_chunk_starts,
        self.prefix_chunk_seq_lens,
        self.prefix_chunk_cu_seq_lens,
        self.prefix_chunk_num_tokens,
    ):
        chunk_kv_indices = torch.empty(num_tokens, dtype=torch.int32, device=device)

        # Grid: one program per request in the batch.
        create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
            req_to_token,
            self.req_pool_indices,
            starts_row,
            seq_lens_row,
            cu_seq_lens_row,
            chunk_kv_indices,
            req_to_token.shape[1],
        )
        kv_indices_per_chunk.append(chunk_kv_indices)

    self.prefix_chunk_kv_indices = kv_indices_per_chunk
|
||||
|
||||
# TODO: Implement a better way to allocate chunk lengths that uses memory spaces more efficiently.
def get_prefix_chunk_seq_lens(
    self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int
):
    """Compute per-chunk start positions and lengths of the cached prefix.

    Every chunk is assumed to span the same number of tokens (prefix_chunk_len).
    Example: 4 sequences with prefix lengths [256, 512, 768, 1024] and
    prefix_chunk_len = 256 give num_prefix_chunks = cdiv(1024, 256) = 4 and::

        starts   = [[0, 0, 0, 0], [256]*4, [512]*4, [768]*4]
        seq_lens = [[256, 256, 256, 256],
                    [0, 256, 256, 256],
                    [0, 0, 256, 256],
                    [0, 0, 0, 256]]

    Returns:
        (prefix_chunk_starts, prefix_chunk_seq_lens): int32 tensors of shape
        (num_prefix_chunks, batch_size) on prefix_lens' device.
    """
    device = prefix_lens.device

    # Chunk k starts at k * prefix_chunk_len, identically for every sequence.
    chunk_ids = torch.arange(num_prefix_chunks, device=device, dtype=torch.int32)
    prefix_chunk_starts = (
        (chunk_ids * prefix_chunk_len).unsqueeze(1).expand(-1, self.batch_size)
    )

    # A chunk ends at its nominal boundary or at the sequence's prefix end,
    # whichever comes first.
    prefix_chunk_ends = torch.minimum(
        prefix_lens.unsqueeze(0), prefix_chunk_starts + prefix_chunk_len
    ).to(torch.int32)

    # Sequences whose prefix ends before a chunk's start contribute length 0.
    prefix_chunk_seq_lens = (
        (prefix_chunk_ends - prefix_chunk_starts).clamp_min(0).to(torch.int32)
    )

    return prefix_chunk_starts, prefix_chunk_seq_lens
|
||||
|
||||
# Called before each attention module if using chunked kv cache for prefill
# Some of the codes are adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
def prepare_chunked_prefix_cache_info(self, device: torch.device):
    """Populate all prefix_chunk_* metadata for chunked-prefix attention.

    Idempotent: returns immediately if a prior module already prepared the
    info (prefix_chunk_len is set). Only valid for MLA KV pools (asserted).
    """

    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool

    assert isinstance(
        self.token_to_kv_pool, MLATokenToKVPool
    ), "Currently chunked prefix cache can only be used by Deepseek models"

    if self.prefix_chunk_len is not None:
        # Chunked kv cache info already prepared by prior modules
        return

    # -1 marks "no chunk processed yet"; backends advance this per chunk.
    self.prefix_chunk_idx = -1

    # chunk_capacity is the maximum number of tokens in each chunk;
    # divide it evenly across the batch to get the per-sequence chunk length.
    chunk_capacity = self.get_max_chunk_capacity()
    self.prefix_chunk_len = chunk_capacity // self.batch_size

    # Number of chunks needed to cover the longest prefix (ceil division).
    self.num_prefix_chunks = (
        max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1
    ) // self.prefix_chunk_len

    # Here we compute chunk lens twice to avoid stream sync, once on gpu and once on cpu.
    prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = (
        self.get_prefix_chunk_seq_lens(
            self.extend_prefix_lens,
            self.num_prefix_chunks,
            self.prefix_chunk_len,
        )
    )
    _, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens(
        torch.tensor(self.extend_prefix_lens_cpu),
        self.num_prefix_chunks,
        self.prefix_chunk_len,
    )
    self.prefix_chunk_starts = prefix_chunk_starts_cuda
    self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda

    # Metadata for attention backend: cumulative lengths with a leading zero
    # column, shape (num_prefix_chunks, batch_size + 1).
    self.prefix_chunk_cu_seq_lens = torch.zeros(
        self.num_prefix_chunks,
        self.batch_size + 1,
        device=device,
        dtype=torch.int32,
    )
    self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum(
        dim=1
    ).to(torch.int32)
    # Host-side max/sum avoid a device sync when sizing buffers below.
    self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max(
        dim=1
    ).values.tolist()

    self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist()
    assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity()

    # Precompute the kv indices for each chunk
    self.prepare_chunked_kv_indices(device)
|
||||
|
||||
|
||||
def compute_position_triton(
|
||||
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
||||
@@ -561,3 +705,40 @@ def compute_position_torch(
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
    """Return the last valid position (seq_len - 1) per sequence, floored at 0,
    as int64."""
    last_positions = seq_lens - 1
    return last_positions.clamp(min=0).long()
|
||||
|
||||
|
||||
@triton.jit
def create_chunked_prefix_cache_kv_indices(
    req_to_token_ptr,  # (max_batch, max_context_len,)
    req_pool_indices_ptr,  # (batch_size,)
    chunk_start_idx_ptr,  # (batch_size,)
    chunk_seq_lens_ptr,  # (batch_size,)
    chunk_cu_seq_lens_ptr,  # (batch_size + 1,)
    chunk_kv_indices_ptr,  # (num_chunk_tokens,)
    req_to_token_ptr_stride: tl.constexpr,
):
    # Gather the KV-cache token indices for one request's slice of the current
    # prefix chunk. Launched with grid (batch_size,): program `pid` handles
    # request `pid` and writes its indices into the flat output buffer at the
    # offset given by the cumulative sequence lengths.
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    # find the req pool idx, this is for batch to token
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    # Where this request's indices begin in the flat output.
    chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)

    # get the token positions of current chunk
    chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
    chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)

    # Copy chunk_seq_len entries in BLOCK_SIZE tiles; the mask guards the
    # partial tail tile (and makes the whole loop a no-op when len is 0).
    num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < chunk_seq_len
        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + chunk_start_pos
            + offset,
            mask=mask,
        )
        tl.store(
            chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
        )
|
||||
|
||||
@@ -167,6 +167,7 @@ class ModelRunner:
|
||||
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
|
||||
"n_share_experts_fusion": server_args.n_share_experts_fusion,
|
||||
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
|
||||
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
|
||||
"use_mla_backend": self.use_mla_backend,
|
||||
}
|
||||
)
|
||||
@@ -318,6 +319,16 @@ class ModelRunner:
|
||||
if server_args.enable_deepep_moe:
    logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")

# Chunked prefix cache is only enabled for the MLA backend with page_size == 1;
# otherwise force-disable it. NOTE(review): this mutates server_args in place.
if not self.use_mla_backend:
    logger.info("Disable chunked prefix cache for non-MLA backend.")
    server_args.disable_chunked_prefix_cache = True
elif self.page_size > 1:
    logger.info("Disable chunked prefix cache when page size > 1.")
    server_args.disable_chunked_prefix_cache = True

if not server_args.disable_chunked_prefix_cache:
    logger.info("Chunked prefix cache is turned on.")
|
||||
|
||||
def init_torch_distributed(self):
|
||||
logger.info("Init torch distributed begin.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user