Add Support for Page Size greater than 1 for Flashinfer MLA Backend (#8593)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -24,9 +24,7 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
create_flashinfer_kv_indices_triton,
|
||||
)
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
@@ -72,11 +70,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
q_indptr_decode_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Parse constants
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
self.device = model_runner.device
|
||||
self.skip_prefill = skip_prefill
|
||||
self.page_size = model_runner.page_size
|
||||
|
||||
# Allocate buffers
|
||||
global global_workspace_buffer
|
||||
@@ -97,15 +95,25 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
else:
|
||||
self.kv_indptr = kv_indptr_buf
|
||||
|
||||
self.kv_indices = torch.empty(
|
||||
(max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
|
||||
dtype=torch.int32,
|
||||
device=model_runner.device,
|
||||
)
|
||||
|
||||
if not self.skip_prefill:
|
||||
self.qo_indptr = torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
|
||||
if q_indptr_decode_buf is None:
|
||||
# A hack to pre-initialize large batch size for dp attention
|
||||
if model_runner.server_args.enable_dp_attention:
|
||||
max_bs = model_runner.server_args.dp_size * max_bs
|
||||
self.q_indptr_decode = torch.arange(
|
||||
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
|
||||
else:
|
||||
self.q_indptr_decode = q_indptr_decode_buf
|
||||
|
||||
@@ -148,6 +156,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
self.prefill_cuda_graph_metadata = {} # For verify
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
self.indices_updater_decode.update(
|
||||
forward_batch.req_pool_indices,
|
||||
@@ -205,16 +214,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
max_num_tokens: int,
|
||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if kv_indices_buf is None:
|
||||
cuda_graph_kv_indices = torch.zeros(
|
||||
(max_bs * self.max_context_len,),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
else:
|
||||
cuda_graph_kv_indices = kv_indices_buf
|
||||
|
||||
self.cuda_graph_kv_indices = cuda_graph_kv_indices
|
||||
self.cuda_graph_kv_indices = (
|
||||
self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
|
||||
)
|
||||
self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
|
||||
self.cuda_graph_kv_indptr = self.kv_indptr.clone()
|
||||
self.cuda_graph_kv_lens = torch.ones(
|
||||
@@ -240,6 +242,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
):
|
||||
|
||||
if forward_mode.is_decode_or_idle():
|
||||
decode_wrapper = BatchMLAPagedAttentionWrapper(
|
||||
self.workspace_buffer,
|
||||
@@ -250,7 +253,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
|
||||
backend="auto",
|
||||
)
|
||||
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
self.indices_updater_decode.update(
|
||||
req_pool_indices,
|
||||
@@ -321,11 +323,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
spec_info: Optional[SpecInfo],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
|
||||
if forward_mode.is_decode_or_idle():
|
||||
assert seq_lens_cpu is not None
|
||||
kv_len_arr_cpu = seq_lens_cpu[:bs]
|
||||
num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
|
||||
self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
|
||||
kv_len_arr_cpu, dim=0
|
||||
num_pages_per_req, dim=0
|
||||
)
|
||||
self.fast_decode_kwargs.update(
|
||||
{
|
||||
@@ -334,7 +338,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
"kv_len_arr_cpu": kv_len_arr_cpu,
|
||||
}
|
||||
)
|
||||
|
||||
self.indices_updater_decode.update(
|
||||
req_pool_indices[:bs],
|
||||
seq_lens[:bs],
|
||||
@@ -381,7 +384,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
):
|
||||
|
||||
cache_loc = forward_batch.out_cache_loc
|
||||
logits_soft_cap = layer.logit_cap
|
||||
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
|
||||
@@ -401,7 +403,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
q_rope = q_rope.view(
|
||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||
)
|
||||
|
||||
if self.forward_metadata.use_ragged:
|
||||
# ragged prefill
|
||||
if q_rope is not None:
|
||||
@@ -422,6 +423,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
|
||||
q.dtype
|
||||
)
|
||||
k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
|
||||
|
||||
if q_rope is None:
|
||||
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
q, q_rope = (
|
||||
@@ -483,17 +486,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
q_nope = reshaped_q[:, :, : layer.v_head_dim]
|
||||
q_rope = reshaped_q[:, :, layer.v_head_dim :]
|
||||
|
||||
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
|
||||
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
|
||||
q.dtype
|
||||
)
|
||||
k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
|
||||
|
||||
o = q_nope.new_empty(q_nope.shape)
|
||||
# Direct call to run without the wrapper
|
||||
o = decode_wrapper.run(
|
||||
q_nope,
|
||||
q_rope,
|
||||
k_buffer[:, :, : layer.v_head_dim],
|
||||
k_buffer[:, :, layer.v_head_dim :],
|
||||
k_buf[:, :, : layer.v_head_dim],
|
||||
k_buf[:, :, layer.v_head_dim :],
|
||||
out=o,
|
||||
)
|
||||
|
||||
@@ -512,9 +515,10 @@ class FlashInferMLAIndicesUpdaterDecode:
|
||||
self.scaling = model_runner.model_config.scaling
|
||||
self.data_type = model_runner.dtype
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
self.page_size = model_runner.page_size
|
||||
# Buffers and wrappers
|
||||
self.kv_indptr = attn_backend.kv_indptr
|
||||
self.kv_indices = attn_backend.kv_indices
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.q_indptr = attn_backend.q_indptr_decode
|
||||
|
||||
@@ -558,13 +562,17 @@ class FlashInferMLAIndicesUpdaterDecode:
|
||||
kv_lens = paged_kernel_lens.to(torch.int32)
|
||||
sm_scale = self.scaling
|
||||
if spec_info is None:
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
num_pages_per_req = (
|
||||
paged_kernel_lens + self.page_size - 1
|
||||
) // self.page_size
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = (
|
||||
torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
|
||||
self.kv_indices[: kv_indptr[-1]]
|
||||
if not init_metadata_replay
|
||||
else fast_decode_kwargs["kv_indices"]
|
||||
)
|
||||
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
@@ -573,39 +581,40 @@ class FlashInferMLAIndicesUpdaterDecode:
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.shape[1],
|
||||
self.page_size,
|
||||
)
|
||||
else:
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
|
||||
if not init_metadata_replay:
|
||||
wrapper.plan(
|
||||
q_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_lens,
|
||||
self.num_local_heads,
|
||||
self.kv_lora_rank,
|
||||
self.qk_rope_head_dim,
|
||||
1,
|
||||
False,
|
||||
sm_scale,
|
||||
self.data_type,
|
||||
self.data_type,
|
||||
qo_indptr=q_indptr,
|
||||
kv_indptr=kv_indptr,
|
||||
kv_indices=kv_indices,
|
||||
kv_len_arr=kv_lens,
|
||||
num_heads=self.num_local_heads,
|
||||
head_dim_ckv=self.kv_lora_rank,
|
||||
head_dim_kpe=self.qk_rope_head_dim,
|
||||
page_size=self.page_size,
|
||||
causal=False,
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=self.data_type,
|
||||
kv_data_type=self.data_type,
|
||||
)
|
||||
else:
|
||||
wrapper.plan(
|
||||
fast_decode_kwargs["qo_indptr_cpu"],
|
||||
fast_decode_kwargs["kv_indptr_cpu"],
|
||||
kv_indices,
|
||||
fast_decode_kwargs["kv_len_arr_cpu"],
|
||||
self.num_local_heads,
|
||||
self.kv_lora_rank,
|
||||
self.qk_rope_head_dim,
|
||||
1,
|
||||
False,
|
||||
sm_scale,
|
||||
self.data_type,
|
||||
self.data_type,
|
||||
qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
|
||||
kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
|
||||
kv_indices=kv_indices,
|
||||
kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
|
||||
num_heads=self.num_local_heads,
|
||||
head_dim_ckv=self.kv_lora_rank,
|
||||
head_dim_kpe=self.qk_rope_head_dim,
|
||||
page_size=self.page_size,
|
||||
causal=False,
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=self.data_type,
|
||||
kv_data_type=self.data_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -627,12 +636,14 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
# Buffers and wrappers
|
||||
self.kv_indptr = attn_backend.kv_indptr
|
||||
self.qo_indptr = attn_backend.qo_indptr
|
||||
self.kv_indices = attn_backend.kv_indices
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
|
||||
self.page_size = model_runner.page_size
|
||||
|
||||
def update(
|
||||
self,
|
||||
req_pool_indices: torch.Tnesor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
@@ -646,7 +657,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
else:
|
||||
paged_kernel_lens = seq_lens
|
||||
paged_kernel_lens_sum = seq_lens_sum
|
||||
|
||||
self.call_begin_forward(
|
||||
self.prefill_wrapper_ragged,
|
||||
prefill_wrapper_paged,
|
||||
@@ -680,13 +690,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
|
||||
if spec_info is None:
|
||||
assert len(seq_lens) == len(req_pool_indices)
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
num_pages_per_req = (
|
||||
paged_kernel_lens + self.page_size - 1
|
||||
) // self.page_size
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum,
|
||||
dtype=torch.int32,
|
||||
device=req_pool_indices.device,
|
||||
)
|
||||
kv_indices = self.kv_indices[: kv_indptr[-1]]
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
@@ -695,6 +704,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.shape[1],
|
||||
self.page_size,
|
||||
)
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
@@ -712,7 +722,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
self.req_to_token,
|
||||
)
|
||||
)
|
||||
|
||||
if use_ragged:
|
||||
# ragged prefill
|
||||
wrapper_ragged.begin_forward(
|
||||
@@ -726,20 +735,26 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
)
|
||||
else:
|
||||
# mla paged prefill
|
||||
kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
|
||||
if spec_info is not None:
|
||||
assert (
|
||||
self.page_size == 1
|
||||
), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
|
||||
kv_lens = kv_indptr[1:] - kv_indptr[:-1]
|
||||
else:
|
||||
kv_lens = paged_kernel_lens.to(torch.int32)
|
||||
wrapper_paged.plan(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_len_arr,
|
||||
self.num_local_heads,
|
||||
self.kv_lora_rank,
|
||||
self.qk_rope_head_dim,
|
||||
1,
|
||||
True,
|
||||
sm_scale,
|
||||
self.q_data_type,
|
||||
self.data_type,
|
||||
qo_indptr=qo_indptr,
|
||||
kv_indptr=kv_indptr,
|
||||
kv_indices=kv_indices,
|
||||
kv_len_arr=kv_lens,
|
||||
num_heads=self.num_local_heads,
|
||||
head_dim_ckv=self.kv_lora_rank,
|
||||
head_dim_kpe=self.qk_rope_head_dim,
|
||||
page_size=self.page_size,
|
||||
causal=True,
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.data_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -834,6 +849,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
call_fn(i, forward_batch)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
|
||||
kv_indices = torch.zeros(
|
||||
(
|
||||
self.speculative_num_steps,
|
||||
@@ -869,6 +885,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||
|
||||
def call_fn(i, forward_batch):
|
||||
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
|
||||
forward_batch.batch_size,
|
||||
@@ -885,6 +902,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self, forward_batch: ForwardBatch, bs: int
|
||||
):
|
||||
|
||||
def call_fn(i, forward_batch):
|
||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
|
||||
@@ -9,18 +9,89 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
|
||||
|
||||
@triton.jit
|
||||
def create_flashinfer_kv_indices_triton(
|
||||
req_to_token_ptr, # [max_batch, max_context_len]
|
||||
req_to_token_ptr,
|
||||
req_pool_indices_ptr,
|
||||
page_kernel_lens_ptr,
|
||||
kv_indptr,
|
||||
kv_start_idx,
|
||||
kv_indices_ptr,
|
||||
req_to_token_ptr_stride: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr = 1,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 512
|
||||
pid = tl.program_id(axis=0)
|
||||
"""
|
||||
Create KV indices for FlashInfer attention backend.
|
||||
|
||||
# find the req pool idx, this is for batch to token
|
||||
This Triton kernel builds a lookup table that maps from logical request/token
|
||||
coordinates to physical token locations in the global KV cache pool. It's used
|
||||
by FlashInfer attention backends to efficiently access scattered KV cache data.
|
||||
|
||||
The kernel processes each request in parallel and converts the req_to_token
|
||||
lookup table into a flat list of token indices that can be used by attention kernels.
|
||||
|
||||
general idea:
|
||||
blocktables/kv_indices_ptr = [batch_size * max_pages(for graph mode with
|
||||
fixed number of pages)]
|
||||
max_pages = max_context_len / PAGE_SIZE
|
||||
kv_indices_ptr will store the flat list of the pages used by each request
|
||||
Args:
|
||||
Input Arguments (non-mutable):
|
||||
|
||||
req_to_token_ptr: Request to token location look up table
|
||||
Shape: [max_batch, max_context_len]
|
||||
req_pool_indices_ptr: Request to pool index look up table. Each request uses
|
||||
one pool.
|
||||
Shape: [batch_size]
|
||||
page_kernel_lens_ptr: sequence lengths per request
|
||||
Shape: [batch_size]
|
||||
kv_indptr: Should be computed based on number of pages used by each request.
|
||||
It is used by flashinfer attention kernels to index into the kv_indices_ptr
|
||||
per request.
|
||||
Shape: [batch_size + 1]
|
||||
kv_indptr[i] = start index in kv_indices for request i
|
||||
kv_start_idx: Pointer to array containing start offsets for each request in SGL.
|
||||
Can be None. If provided, adds offset to token positions.
|
||||
|
||||
req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
|
||||
Equal to max_context_len.
|
||||
|
||||
PAGE_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
|
||||
|
||||
Outputs:
|
||||
kv_indices_ptr: Pointer to output array where KV indices will be stored.
|
||||
Shape: [total_num_pages],
|
||||
where total_num_pages = sum(ceil(seq_lens / PAGE_SIZE))
|
||||
|
||||
Example:
|
||||
If we have:
|
||||
- req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
|
||||
- page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
|
||||
- req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
|
||||
in radix tree, use them as a pointer to the token location in the global KV cache pool)
|
||||
|
||||
The kernel will output:
|
||||
If PAGE_SIZE = 1:
|
||||
packed
|
||||
- kv_indptr (passed in as input arg): [0,3,5]
|
||||
- kv_indices = [10, 11, 12, 20, 21]
|
||||
padded - max_pages is 10 tokens per req
|
||||
- kv_indptr (passed in as input arg): [0,10, 20]
|
||||
- kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
|
||||
20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
|
||||
|
||||
If PAGE_SIZE = 2
|
||||
packed:
|
||||
- kv_indptr (passed in as input arg): [0,2,3]
|
||||
- kv_indices = [5,6,10]
|
||||
padded: max_pages is 4
|
||||
- kv_indptr (passed in as input arg): [0,4,8,..] (note that 4 is the max_pages)
|
||||
- kv_indices = [5, 6, -1, -1,
|
||||
10, -1, -1, -1]
|
||||
This allows attention kernels to directly access the correct KV cache
|
||||
entries for each request's tokens.
|
||||
"""
|
||||
BLOCK_SIZE: tl.constexpr = 512
|
||||
NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
|
||||
pid = tl.program_id(axis=0)
|
||||
req_pool_index = tl.load(req_pool_indices_ptr + pid)
|
||||
kv_indices_offset = tl.load(kv_indptr + pid)
|
||||
|
||||
@@ -31,19 +102,27 @@ def create_flashinfer_kv_indices_triton(
|
||||
kv_end = kv_start
|
||||
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
||||
|
||||
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||
for i in range(num_loop):
|
||||
# index into req_to_token_ptr needs to be int64
|
||||
offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
|
||||
mask = offset < kv_end - kv_start
|
||||
data = tl.load(
|
||||
req_to_token_ptr
|
||||
+ req_pool_index * req_to_token_ptr_stride
|
||||
+ kv_start
|
||||
+ offset,
|
||||
mask=mask,
|
||||
kv_range = kv_end - kv_start
|
||||
num_pages = tl.cdiv(kv_range, PAGE_SIZE)
|
||||
num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
|
||||
req_to_token_block_start = (
|
||||
req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
|
||||
)
|
||||
for i in range(num_loops):
|
||||
token_offsets_in_block = (
|
||||
tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
|
||||
) * PAGE_SIZE
|
||||
page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
|
||||
valid_tokens = token_offsets_in_block < kv_range
|
||||
valid_pages = page_offsets_in_block < num_pages
|
||||
token_numbers = tl.load(
|
||||
req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
|
||||
)
|
||||
tl.store(
|
||||
kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
|
||||
token_numbers // PAGE_SIZE, # write the page numbers to kv_indices_ptr
|
||||
mask=valid_pages,
|
||||
)
|
||||
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
@@ -636,6 +636,10 @@ class ServerArgs:
|
||||
logger.warning(
|
||||
"DeepSeek MTP does not require setting speculative_draft_model_path."
|
||||
)
|
||||
if self.page_size != 1 and self.attention_backend == "flashinfer":
|
||||
raise ValueError(
|
||||
"Speculative decoding with page_size != 1 is not supported. Please set page_size to 1."
|
||||
)
|
||||
|
||||
# Auto choose parameters
|
||||
if self.speculative_num_steps is None:
|
||||
|
||||
Reference in New Issue
Block a user