support speculative decoding kernel in sgl-kernel (#3373)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
Yineng Zhang
2025-02-07 20:29:51 +08:00
committed by GitHub
parent 45c87e083f
commit f9905d59a8
13 changed files with 1298 additions and 132 deletions

View File

@@ -1,124 +1,175 @@
# NOTE: Please run this file to make sure the test cases are correct.
from typing import List

import torch

from sglang.srt.utils import is_cuda_available

# The tree-building CUDA ops live in the sgl-kernel package; only import them
# when CUDA is present so this module stays importable on CPU-only machines.
if is_cuda_available():
    from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
    from sgl_kernel import (
        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
    )

# Tensor layout reference for the kernels used below
# (NOTE(review): reconstructed from the original header comments — confirm
# against the sgl-kernel signatures):
#   parent_list      [bs, topk * depth + 1]
#   selected_index   [bs, draft_token_num - 1]
#   verified_seq_len [bs]
#   tree_mask        flattened: sum(verified_seq_len) * draft_token
#                    + bs * draft_token * draft_token
#   positions        [bs * draft_token]
#   retrive_index    [bs, draft_token, depth + 2]
def build_tree_kernel_efficient_preprocess(
    verified_id: torch.Tensor,
    score_list: List[torch.Tensor],
    token_list: List[torch.Tensor],
    parents_list: List[torch.Tensor],
    num_verify_tokens: int,
):
    """Select the top-scoring draft candidates and assemble kernel inputs.

    Args:
        verified_id: [bs] last verified token id per sequence; it becomes the
            tree root and is prepended to the selected draft tokens.
        score_list: per-step candidate scores; entries are [bs, n, topk] with
            n = 1 for the first step and topk afterwards.
        token_list: per-step candidate token ids, one entry per step, aligned
            1:1 with the flattened scores after concatenation along dim 1.
        parents_list: per-step parent indices; the last step's parents are
            never referenced, so it is dropped from the concatenation.
        num_verify_tokens: number of draft tokens per sequence, root included.

    Returns:
        Tuple of (parent_list, top_scores_index, draft_tokens) where
        draft_tokens is flattened to [bs * num_verify_tokens].
    """
    # [bs, 1 + (num_steps - 1) * topk, topk] -> [bs, num_candidates]
    scores = torch.cat(score_list, dim=1).flatten(1)
    ss_token_list = torch.cat(token_list, dim=1)  # [bs, num_candidates]

    # Keep the best num_verify_tokens - 1 candidates; sorting the indices
    # restores generation (tree) order, which the kernel relies on.
    top_scores = torch.topk(scores, num_verify_tokens - 1, dim=-1)
    top_scores_index = torch.sort(top_scores.indices).values

    draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
    # Prepend the verified root token, then flatten across the batch.
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
    parent_list = torch.cat(parents_list[:-1], dim=1)
    return parent_list, top_scores_index, draft_tokens
def build_tree_kernel_efficient(
    verified_id: torch.Tensor,
    score_list: List[torch.Tensor],
    token_list: List[torch.Tensor],
    parents_list: List[torch.Tensor],
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    topk: int,
    spec_steps: int,
    num_verify_tokens: int,
):
    """Build the draft-token tree using the efficient sgl-kernel CUDA op.

    Args:
        verified_id: [bs] last verified token id per sequence.
        score_list / token_list / parents_list: per-step candidates, see
            `build_tree_kernel_efficient_preprocess`.
        seq_lens: [bs] sequence lengths WITHOUT draft tokens.
        seq_lens_sum: precomputed sum(seq_lens).
        topk: branching factor of the draft tree.
        spec_steps: number of speculative steps (tree depth).
        num_verify_tokens: draft tokens per sequence, root included.

    Returns:
        (tree_mask, positions, retrive_index, retrive_next_token,
         retrive_next_sibling, draft_tokens)
    """
    parent_list, top_scores_index, draft_tokens = (
        build_tree_kernel_efficient_preprocess(
            verified_id,
            score_list,
            token_list,
            parents_list,
            num_verify_tokens,
        )
    )

    bs = seq_lens.numel()
    device = seq_lens.device

    # tree_mask: for bs=1 this is num_verify_tokens rows of
    # (seq_lens_sum + num_verify_tokens) attention flags, flattened; each row
    # is the attending pattern of one draft token.
    # TODO: make these torch.empty and fuse their init into the kernel.
    tree_mask = torch.full(
        (
            seq_lens_sum * num_verify_tokens
            + num_verify_tokens * num_verify_tokens * bs,
        ),
        True,
        device=device,
    )
    retrive_index = torch.full(
        (bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    retrive_next_token = torch.full(
        (bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    retrive_next_sibling = torch.full(
        (bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    # positions: absolute position of each draft token, e.g. draft depths
    # [0, 1, 1, 2] with prompt length 7 give positions [7, 8, 8, 9].
    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)

    # The kernel fills tree_mask / positions / retrive_* in place.
    sgl_build_tree_kernel_efficient(
        parent_list,
        top_scores_index,
        seq_lens.to(torch.int32),
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        spec_steps,
        num_verify_tokens,
    )
    return (
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    )
def build_tree_kernel(
    verified_id: torch.Tensor,
    score_list: List[torch.Tensor],
    token_list: List[torch.Tensor],
    parents_list: List[torch.Tensor],
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    topk: int,
    spec_steps: int,
    num_verify_tokens: int,
):
    """Build the draft-token tree with the legacy sgl-kernel CUDA op.

    Unlike `build_tree_kernel_efficient`, retrive_index here enumerates one
    root-to-leaf path per row ([bs, num_verify_tokens, spec_steps + 2]) and
    non-leaf rows are filtered out afterwards.

    Returns:
        (tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens)
    """
    parent_list, top_scores_index, draft_tokens = (
        build_tree_kernel_efficient_preprocess(
            verified_id,
            score_list,
            token_list,
            parents_list,
            num_verify_tokens,
        )
    )

    bs = seq_lens.numel()
    device = seq_lens.device

    tree_mask = torch.full(
        (
            seq_lens_sum * num_verify_tokens
            + num_verify_tokens * num_verify_tokens * bs,
        ),
        True,
        device=device,
    )
    retrive_index = torch.full(
        (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
    )
    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)

    sgl_build_tree_kernel(
        parent_list,
        top_scores_index,
        seq_lens.to(torch.int32),
        tree_mask,
        positions,
        retrive_index,
        topk,
        spec_steps,
        num_verify_tokens,
    )

    # A row left untouched by the kernel sums to -(spec_steps + 2); keep only
    # rows that describe a real leaf path.
    index = retrive_index.sum(dim=-1) != -spec_steps - 2
    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
    retrive_cum_len = torch.zeros(
        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
    )
    retrive_cum_len[1:] = cum_len
    # TODO: this boolean indexing causes a device sync; optimize this.
    retrive_index = retrive_index[index]
    return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
if __name__ == "__main__":
def test_build_tree_kernel():
def findp(p_i, index, parent_list):
pos = index // 10
index_list = index.tolist()
@@ -311,21 +362,21 @@ if __name__ == "__main__":
bs = verified_seq_len.shape[0]
topk = 10
depth = 5 # depth <= 10
draft_token = 64
num_draft_token = 64
tree_mask = torch.full(
(
torch.sum(verified_seq_len).item() * draft_token
+ draft_token * draft_token * bs,
torch.sum(verified_seq_len).item() * num_draft_token
+ num_draft_token * num_draft_token * bs,
),
True,
).cuda()
retrive_index = torch.full(
(bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
(bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
)
positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)
positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
kernels.build_tree(
sgl_build_tree_kernel(
parent_list.unsqueeze(0),
index.unsqueeze(0),
verified_seq_len,
@@ -334,16 +385,345 @@ if __name__ == "__main__":
retrive_index,
topk,
depth,
draft_token,
grid=(bs, 1, 1),
block=(64, 1, 1),
num_draft_token,
)
retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
c_mask, c_positions, c_retive_index = create_mask(
verified_seq_len, draft_token, index, parent_list, depth
verified_seq_len, num_draft_token, index, parent_list, depth
)
assert torch.allclose(tree_mask, c_mask), "tree mask has error."
assert torch.allclose(positions, c_positions), "positions has error."
assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
def test_build_tree_kernel_efficient():
    """Exercise both tree builders on a fixed bs=2, topk=4, depth=4 example.

    The expected positions / retrieval tables below are hand-verified against
    the candidate scores. Requires CUDA, sgl_kernel, and sglang.
    """
    verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
    # Per-step candidate scores: step 0 is [bs, 1, topk], later steps [bs, topk, topk].
    score_list = [
        torch.tensor(
            [[[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]],
             [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]]],
            dtype=torch.float32,
            device="cuda",
        ),
        torch.tensor(
            [[[6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03],
              [2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04],
              [2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08],
              [1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05]],
             [[8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02],
              [2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04],
              [2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05],
              [1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08]]],
            dtype=torch.float32,
            device="cuda",
        ),
        torch.tensor(
            [[[6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06],
              [2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04],
              [1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05],
              [1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05]],
             [[2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03],
              [6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05],
              [4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04],
              [7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04]]],
            dtype=torch.float32,
            device="cuda",
        ),
        torch.tensor(
            [[[6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06],
              [1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03],
              [2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05],
              [1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05]],
             [[5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04],
              [1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03],
              [6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03],
              [1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04]]],
            dtype=torch.float32,
            device="cuda",
        ),
    ]
    # Per-step candidate token ids, flattened per step; aligned with score_list.
    token_list = [
        torch.tensor(
            [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]],
            dtype=torch.int64,
            device="cuda",
        ),
        torch.tensor(
            [[29889, 29974, 29945, 29900, 29974, 29922, 29930, 29958,
              29889, 29974, 29930, 29945, 29974, 29922, 29930, 29958],
             [22550, 4136, 16492, 8439, 29871, 2, 3001, 13,
              2, 13, 29906, 29946, 2, 13, 29871, 259]],
            device="cuda",
        ),
        torch.tensor(
            [[29946, 29945, 29953, 29906, 29896, 29945, 29900, 29906,
              29896, 29945, 29906, 29953, 29896, 29945, 29906, 29946],
             [29871, 2, 29901, 29889, 29871, 2, 395, 259,
              29901, 29871, 2, 29889, 3001, 1234, 7146, 2186]],
            device="cuda",
        ),
        torch.tensor(
            [[29946, 29974, 29945, 29930, 29889, 29922, 29974, 29930,
              29974, 29946, 29930, 29922, 29889, 29974, 29945, 29922],
             [29941, 29906, 2, 29946, 29871, 450, 319, 14990,
              29946, 29941, 2, 29906, 29871, 2, 3001, 13]],
            device="cuda",
        ),
    ]
    parents_list = [
        torch.tensor(
            [[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]], dtype=torch.int64, device="cuda"
        ),
        torch.tensor([[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda"),
        torch.tensor(
            [[20, 24, 21, 28], [24, 28, 20, 21]], dtype=torch.int64, device="cuda"
        ),
        torch.tensor(
            [[36, 40, 41, 44], [36, 40, 44, 45]], dtype=torch.int64, device="cuda"
        ),
    ]
    seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda")
    topk = 4
    depth = 4
    num_draft_token = 8

    # --- Legacy kernel: one root-to-leaf path per retrive_index row. ---
    tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
        build_tree_kernel(
            verified_id=verified_id,
            score_list=score_list,
            token_list=token_list,
            parents_list=parents_list,
            seq_lens=seq_lens,
            seq_lens_sum=torch.sum(seq_lens).item(),
            topk=topk,
            spec_steps=depth,
            num_verify_tokens=num_draft_token,
        )
    )

    from sglang.srt.utils import first_rank_print

    first_rank_print("=========== build tree kernel ==========")
    # first_rank_print(f"{tree_mask=}", flush=True)
    first_rank_print(f"{position=}", flush=True)
    first_rank_print(f"{retrive_index=}", flush=True)
    first_rank_print(f"{retrive_cum_len=}", flush=True)
    first_rank_print(f"{draft_tokens=}", flush=True)
    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
    assert retrive_index.tolist() == [
        [0, -1, -1, -1, -1, -1],
        [0, 2, 4, 6, -1, -1],
        [0, 1, 3, 5, 7, -1],
        [8, -1, -1, -1, -1, -1],
        [8, 9, 10, -1, -1, -1],
        [8, 9, 12, -1, -1, -1],
        [8, 9, 13, -1, -1, -1],
        [8, 9, 11, 14, 15, -1],
    ]
    assert retrive_cum_len.tolist() == [0, 3, 8]
    assert draft_tokens.tolist() == [
        29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
        13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
    ]

    # --- Efficient kernel: linked-list style next-token / next-sibling tables. ---
    (
        tree_mask,
        position,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    ) = build_tree_kernel_efficient(
        verified_id=verified_id,
        score_list=score_list,
        token_list=token_list,
        parents_list=parents_list,
        seq_lens=seq_lens,
        seq_lens_sum=torch.sum(seq_lens).item(),
        topk=topk,
        spec_steps=depth,
        num_verify_tokens=num_draft_token,
    )
    first_rank_print("=========== build tree kernel efficient ==========")
    # first_rank_print(f"{tree_mask=}", flush=True)
    first_rank_print(f"{position=}", flush=True)
    first_rank_print(f"{retrive_index=}", flush=True)
    first_rank_print(f"{retrive_next_token=}", flush=True)
    first_rank_print(f"{retrive_next_sibling=}", flush=True)
    first_rank_print(f"{draft_tokens=}", flush=True)
    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
    assert retrive_index.tolist() == [
        [0, 1, 2, 3, 4, 5, 6, 7],
        [8, 9, 10, 11, 12, 13, 14, 15],
    ]
    assert retrive_next_token.tolist() == [
        [1, 3, 4, 5, 6, 7, -1, -1],
        [1, 2, -1, 6, -1, -1, 7, -1],
    ]
    assert retrive_next_sibling.tolist() == [
        [-1, 2, -1, -1, -1, -1, -1, -1],
        [-1, -1, 3, 4, 5, -1, -1, -1],
    ]
    assert draft_tokens.tolist() == [
        29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
        13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
    ]
if __name__ == "__main__":
    # Run both kernel correctness checks when executed directly (requires CUDA).
    test_build_tree_kernel_efficient()
    test_build_tree_kernel()

View File

@@ -258,39 +258,77 @@ class EagleVerifyInput:
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
predict = torch.cat(
[predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
)
draft_token = torch.cat(
[self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
[self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
dim=-1,
)
target_predict = predict[self.retrive_index]
candidates = draft_token[self.retrive_index]
# logits = logits_output.next_token_logits[self.retrive_index]
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
bs = self.retrive_cum_len.numel() - 1
if batch.sampling_info.is_all_greedy:
# temp == 0
bs = self.retrive_cum_len.numel() - 1
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
predict = torch.cat(
[predict, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1
)
target_predict = predict[self.retrive_index]
# logits = logits_output.next_token_logits[self.retrive_index]
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
max_draft_len = self.retrive_index.shape[-1]
accept_index = torch.full(
(bs, max_draft_len), -1, dtype=torch.long, device="cuda"
)
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
eagle_verify_retrive[(bs,)](
self.retrive_index.contiguous(),
accept_mask.contiguous(),
self.retrive_cum_len,
accept_index,
accept_length,
extract_index,
max_draft_len,
self.draft_token_num,
triton.next_power_of_2(max_draft_len),
)
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
max_draft_len = self.retrive_index.shape[-1]
accept_index = torch.full(
(bs, max_draft_len), -1, dtype=torch.int32, device="cuda"
)
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
eagle_verify_retrive[(bs,)](
self.retrive_index.contiguous(),
accept_mask.contiguous(),
self.retrive_cum_len,
accept_index,
accept_length,
extract_index,
max_draft_len,
self.draft_token_num,
triton.next_power_of_2(max_draft_len),
)
else:
# temp > 0
bs = self.retrive_index.shape[0]
predict_shape = list(logits_output.next_token_logits.shape)[:-1]
predict_shape[-1] += 1
target_logits = logits_output.next_token_logits[self.retrive_index]
predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")
accept_index = torch.full(
(bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
)
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
expanded_temperature = batch.sampling_info.temperatures.unsqueeze(1)
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
draft_probs = torch.full_like(
target_probs, 0, dtype=torch.float32, device="cuda"
)
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
tree_speculative_sampling_target_only(
predicts=predict, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable
candidates=candidates.to(torch.int32),
retrive_index=self.retrive_index.to(torch.int32),
retrive_next_token=self.retrive_next_token.to(torch.int32),
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
uniform_samples=coins,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True,
)
new_accept_index = []
unfinished_index = []