support speculative decoding kernel in sgl-kernel (#3373)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -1,124 +1,175 @@
|
||||
import cutex
|
||||
# NOTE: Please run this file to make sure the test cases are correct.
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
# parent_table [bs,topk*depth+)]
|
||||
# selected_index [bs,draft_token_num-1)]
|
||||
# verified_seq_len [bs]
|
||||
# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token]
|
||||
# positions [bs*draft_token]
|
||||
# retrive_index [b, draft_token, depth+2]
|
||||
kernels = cutex.SourceModule(
|
||||
"""
|
||||
//cuda
|
||||
__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
|
||||
Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
|
||||
int bid = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
if (tid >= draft_token_num){
|
||||
return;
|
||||
}
|
||||
int seq_tree_idx = draft_token_num * draft_token_num * bid;
|
||||
for(int i=0; i<bid; i++){
|
||||
seq_tree_idx += verified_seq_len[i] * draft_token_num;
|
||||
}
|
||||
int seq_len = verified_seq_len[bid];
|
||||
int token_tree_idx = seq_tree_idx + (seq_len+draft_token_num)*tid + seq_len + 1;
|
||||
for(int i=0; i<draft_token_num-1; i++){
|
||||
tree_mask[token_tree_idx+i] = false;
|
||||
}
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
|
||||
int position = 0;
|
||||
if (tid==0){
|
||||
positions[bid*draft_token_num] = seq_len;
|
||||
retrive_index[bid][0][0] = bid * draft_token_num;
|
||||
return;
|
||||
}
|
||||
|
||||
int depends_order[10];
|
||||
|
||||
int cur_position = tid-1;
|
||||
while(true){
|
||||
depends_order[position] = cur_position+1;
|
||||
position += 1;
|
||||
tree_mask[token_tree_idx+cur_position] = true;
|
||||
int parent_tb_idx = selected_index[bid][cur_position]/topk;
|
||||
if(parent_tb_idx==0){
|
||||
break;
|
||||
}
|
||||
|
||||
int token_idx = parent_list[bid][parent_tb_idx];
|
||||
for(cur_position=0; cur_position<draft_token_num;cur_position++){
|
||||
if(selected_index[bid][cur_position]==token_idx){
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
positions[bid*draft_token_num+tid] = position + seq_len;
|
||||
|
||||
int is_leaf = 0;
|
||||
for(int i=1;i<draft_token_num;i++){
|
||||
if(tree_mask[seq_tree_idx + i * (draft_token_num+seq_len) + seq_len + tid])
|
||||
{
|
||||
is_leaf ++;
|
||||
}
|
||||
}
|
||||
if(is_leaf==1){
|
||||
for(int i=0; i<position; i++){
|
||||
retrive_index[bid][tid][position-i] = depends_order[i] + bid * draft_token_num;
|
||||
}
|
||||
retrive_index[bid][tid][0] = bid*draft_token_num;
|
||||
}
|
||||
if is_cuda_available():
|
||||
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
|
||||
from sgl_kernel import (
|
||||
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
|
||||
)
|
||||
|
||||
|
||||
|
||||
}
|
||||
//!cuda
|
||||
""",
|
||||
float_bits=16, # change to 16 to use half precision as `float` type in the above source code.
|
||||
boundscheck=True, # turning on for debug and off for performance (to use full threads of a block), default is on.
|
||||
)
|
||||
|
||||
|
||||
def build_tree_kernel(
|
||||
parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token
|
||||
def build_tree_kernel_efficient_preprocess(
|
||||
verified_id: torch.Tensor,
|
||||
score_list: List[torch.Tensor],
|
||||
token_list: List[torch.Tensor],
|
||||
parents_list: List[torch.Tensor],
|
||||
num_verify_tokens: int,
|
||||
):
|
||||
score_list = torch.cat(score_list, dim=1).flatten(
|
||||
1
|
||||
) # b, n, topk; n= 1 + (num_steps-1) * self.topk
|
||||
ss_token_list = torch.cat(
|
||||
token_list, dim=1
|
||||
) # b, (self.topk + (num_steps-1) * self.topk)
|
||||
top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
|
||||
top_scores_index = top_scores.indices
|
||||
top_scores_index = torch.sort(top_scores_index).values
|
||||
|
||||
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
|
||||
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
|
||||
parent_list = torch.cat(parents_list[:-1], dim=1)
|
||||
|
||||
return parent_list, top_scores_index, draft_tokens
|
||||
|
||||
|
||||
def build_tree_kernel_efficient(
|
||||
verified_id: torch.Tensor,
|
||||
score_list: List[torch.Tensor],
|
||||
token_list: List[torch.Tensor],
|
||||
parents_list: List[torch.Tensor],
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
topk: int,
|
||||
spec_steps: int,
|
||||
num_verify_tokens: int,
|
||||
):
|
||||
parent_list, top_scores_index, draft_tokens = (
|
||||
build_tree_kernel_efficient_preprocess(
|
||||
verified_id,
|
||||
score_list,
|
||||
token_list,
|
||||
parents_list,
|
||||
num_verify_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
# seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
|
||||
bs = seq_lens.numel()
|
||||
device = parent_list.device
|
||||
device = seq_lens.device
|
||||
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
|
||||
# where each row indicates the attending pattern of each draft token
|
||||
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
|
||||
tree_mask = torch.full(
|
||||
(seq_lens_sum * draft_token + draft_token * draft_token * bs,),
|
||||
(
|
||||
seq_lens_sum * num_verify_tokens
|
||||
+ num_verify_tokens * num_verify_tokens * bs,
|
||||
),
|
||||
True,
|
||||
device=device,
|
||||
)
|
||||
retrive_index = torch.full(
|
||||
(bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
|
||||
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
||||
)
|
||||
positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)
|
||||
retrive_next_token = torch.full(
|
||||
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
||||
)
|
||||
retrive_next_sibling = torch.full(
|
||||
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
||||
)
|
||||
# position: where each token belongs to
|
||||
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
|
||||
# then, positions = [7, 8, 8, 9]
|
||||
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
|
||||
|
||||
kernels.build_tree(
|
||||
sgl_build_tree_kernel_efficient(
|
||||
parent_list,
|
||||
top_score_index,
|
||||
top_scores_index,
|
||||
seq_lens.to(torch.int32),
|
||||
tree_mask,
|
||||
positions,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
topk,
|
||||
spec_steps,
|
||||
num_verify_tokens,
|
||||
)
|
||||
return (
|
||||
tree_mask,
|
||||
positions,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
draft_tokens,
|
||||
)
|
||||
|
||||
|
||||
def build_tree_kernel(
|
||||
verified_id: torch.Tensor,
|
||||
score_list: List[torch.Tensor],
|
||||
token_list: List[torch.Tensor],
|
||||
parents_list: List[torch.Tensor],
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
topk: int,
|
||||
spec_steps: int,
|
||||
num_verify_tokens: int,
|
||||
):
|
||||
parent_list, top_scores_index, draft_tokens = (
|
||||
build_tree_kernel_efficient_preprocess(
|
||||
verified_id,
|
||||
score_list,
|
||||
token_list,
|
||||
parents_list,
|
||||
num_verify_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
bs = seq_lens.numel()
|
||||
device = seq_lens.device
|
||||
|
||||
tree_mask = torch.full(
|
||||
(
|
||||
seq_lens_sum * num_verify_tokens
|
||||
+ num_verify_tokens * num_verify_tokens * bs,
|
||||
),
|
||||
True,
|
||||
device=device,
|
||||
)
|
||||
retrive_index = torch.full(
|
||||
(bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
|
||||
)
|
||||
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
|
||||
|
||||
sgl_build_tree_kernel(
|
||||
parent_list,
|
||||
top_scores_index,
|
||||
seq_lens.to(torch.int32),
|
||||
tree_mask,
|
||||
positions,
|
||||
retrive_index,
|
||||
topk,
|
||||
depth,
|
||||
draft_token,
|
||||
grid=(bs, 1, 1),
|
||||
block=(64, 1, 1),
|
||||
spec_steps,
|
||||
num_verify_tokens,
|
||||
)
|
||||
index = retrive_index.sum(dim=-1) != -depth - 2
|
||||
|
||||
index = retrive_index.sum(dim=-1) != -spec_steps - 2
|
||||
cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
|
||||
retrive_cum_len = torch.zeros(
|
||||
(cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
retrive_cum_len[1:] = cum_len
|
||||
# TODO: this indexing cause a synchronization, optimize this
|
||||
retrive_index = retrive_index[index]
|
||||
return tree_mask, positions, retrive_index, retrive_cum_len
|
||||
return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def test_build_tree_kernel():
|
||||
def findp(p_i, index, parent_list):
|
||||
pos = index // 10
|
||||
index_list = index.tolist()
|
||||
@@ -311,21 +362,21 @@ if __name__ == "__main__":
|
||||
bs = verified_seq_len.shape[0]
|
||||
topk = 10
|
||||
depth = 5 # depth <= 10
|
||||
draft_token = 64
|
||||
num_draft_token = 64
|
||||
|
||||
tree_mask = torch.full(
|
||||
(
|
||||
torch.sum(verified_seq_len).item() * draft_token
|
||||
+ draft_token * draft_token * bs,
|
||||
torch.sum(verified_seq_len).item() * num_draft_token
|
||||
+ num_draft_token * num_draft_token * bs,
|
||||
),
|
||||
True,
|
||||
).cuda()
|
||||
retrive_index = torch.full(
|
||||
(bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
|
||||
(bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
|
||||
)
|
||||
positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)
|
||||
positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
|
||||
|
||||
kernels.build_tree(
|
||||
sgl_build_tree_kernel(
|
||||
parent_list.unsqueeze(0),
|
||||
index.unsqueeze(0),
|
||||
verified_seq_len,
|
||||
@@ -334,16 +385,345 @@ if __name__ == "__main__":
|
||||
retrive_index,
|
||||
topk,
|
||||
depth,
|
||||
draft_token,
|
||||
grid=(bs, 1, 1),
|
||||
block=(64, 1, 1),
|
||||
num_draft_token,
|
||||
)
|
||||
|
||||
retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
|
||||
|
||||
c_mask, c_positions, c_retive_index = create_mask(
|
||||
verified_seq_len, draft_token, index, parent_list, depth
|
||||
verified_seq_len, num_draft_token, index, parent_list, depth
|
||||
)
|
||||
|
||||
assert torch.allclose(tree_mask, c_mask), "tree mask has error."
|
||||
assert torch.allclose(positions, c_positions), "positions has error."
|
||||
assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
|
||||
|
||||
|
||||
def test_build_tree_kernel_efficient():
|
||||
verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
|
||||
score_list = [
|
||||
torch.tensor(
|
||||
[
|
||||
[[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]],
|
||||
[[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
[6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03],
|
||||
[2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04],
|
||||
[2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08],
|
||||
[1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05],
|
||||
],
|
||||
[
|
||||
[8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02],
|
||||
[2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04],
|
||||
[2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05],
|
||||
[1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08],
|
||||
],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
[6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06],
|
||||
[2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04],
|
||||
[1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05],
|
||||
[1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05],
|
||||
],
|
||||
[
|
||||
[2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03],
|
||||
[6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05],
|
||||
[4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04],
|
||||
[7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04],
|
||||
],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
[6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06],
|
||||
[1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03],
|
||||
[2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05],
|
||||
[1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05],
|
||||
],
|
||||
[
|
||||
[5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04],
|
||||
[1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03],
|
||||
[6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03],
|
||||
[1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04],
|
||||
],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
),
|
||||
]
|
||||
token_list = [
|
||||
torch.tensor(
|
||||
[[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]],
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
29889,
|
||||
29974,
|
||||
29945,
|
||||
29900,
|
||||
29974,
|
||||
29922,
|
||||
29930,
|
||||
29958,
|
||||
29889,
|
||||
29974,
|
||||
29930,
|
||||
29945,
|
||||
29974,
|
||||
29922,
|
||||
29930,
|
||||
29958,
|
||||
],
|
||||
[
|
||||
22550,
|
||||
4136,
|
||||
16492,
|
||||
8439,
|
||||
29871,
|
||||
2,
|
||||
3001,
|
||||
13,
|
||||
2,
|
||||
13,
|
||||
29906,
|
||||
29946,
|
||||
2,
|
||||
13,
|
||||
29871,
|
||||
259,
|
||||
],
|
||||
],
|
||||
device="cuda",
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
29946,
|
||||
29945,
|
||||
29953,
|
||||
29906,
|
||||
29896,
|
||||
29945,
|
||||
29900,
|
||||
29906,
|
||||
29896,
|
||||
29945,
|
||||
29906,
|
||||
29953,
|
||||
29896,
|
||||
29945,
|
||||
29906,
|
||||
29946,
|
||||
],
|
||||
[
|
||||
29871,
|
||||
2,
|
||||
29901,
|
||||
29889,
|
||||
29871,
|
||||
2,
|
||||
395,
|
||||
259,
|
||||
29901,
|
||||
29871,
|
||||
2,
|
||||
29889,
|
||||
3001,
|
||||
1234,
|
||||
7146,
|
||||
2186,
|
||||
],
|
||||
],
|
||||
device="cuda",
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
29946,
|
||||
29974,
|
||||
29945,
|
||||
29930,
|
||||
29889,
|
||||
29922,
|
||||
29974,
|
||||
29930,
|
||||
29974,
|
||||
29946,
|
||||
29930,
|
||||
29922,
|
||||
29889,
|
||||
29974,
|
||||
29945,
|
||||
29922,
|
||||
],
|
||||
[
|
||||
29941,
|
||||
29906,
|
||||
2,
|
||||
29946,
|
||||
29871,
|
||||
450,
|
||||
319,
|
||||
14990,
|
||||
29946,
|
||||
29941,
|
||||
2,
|
||||
29906,
|
||||
29871,
|
||||
2,
|
||||
3001,
|
||||
13,
|
||||
],
|
||||
],
|
||||
device="cuda",
|
||||
),
|
||||
]
|
||||
parents_list = [
|
||||
torch.tensor(
|
||||
[[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]], dtype=torch.int64, device="cuda"
|
||||
),
|
||||
torch.tensor([[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda"),
|
||||
torch.tensor(
|
||||
[[20, 24, 21, 28], [24, 28, 20, 21]], dtype=torch.int64, device="cuda"
|
||||
),
|
||||
torch.tensor(
|
||||
[[36, 40, 41, 44], [36, 40, 44, 45]], dtype=torch.int64, device="cuda"
|
||||
),
|
||||
]
|
||||
seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda")
|
||||
topk = 4
|
||||
depth = 4
|
||||
num_draft_token = 8
|
||||
|
||||
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
|
||||
build_tree_kernel(
|
||||
verified_id=verified_id,
|
||||
score_list=score_list,
|
||||
token_list=token_list,
|
||||
parents_list=parents_list,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_sum=torch.sum(seq_lens).item(),
|
||||
topk=topk,
|
||||
spec_steps=depth,
|
||||
num_verify_tokens=num_draft_token,
|
||||
)
|
||||
)
|
||||
|
||||
from sglang.srt.utils import first_rank_print
|
||||
|
||||
first_rank_print("=========== build tree kernel ==========")
|
||||
# first_rank_print(f"{tree_mask=}", flush=True)
|
||||
first_rank_print(f"{position=}", flush=True)
|
||||
first_rank_print(f"{retrive_index=}", flush=True)
|
||||
first_rank_print(f"{retrive_cum_len=}", flush=True)
|
||||
first_rank_print(f"{draft_tokens=}", flush=True)
|
||||
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
|
||||
assert retrive_index.tolist() == [
|
||||
[0, -1, -1, -1, -1, -1],
|
||||
[0, 2, 4, 6, -1, -1],
|
||||
[0, 1, 3, 5, 7, -1],
|
||||
[8, -1, -1, -1, -1, -1],
|
||||
[8, 9, 10, -1, -1, -1],
|
||||
[8, 9, 12, -1, -1, -1],
|
||||
[8, 9, 13, -1, -1, -1],
|
||||
[8, 9, 11, 14, 15, -1],
|
||||
]
|
||||
assert retrive_cum_len.tolist() == [0, 3, 8]
|
||||
assert draft_tokens.tolist() == [
|
||||
29974,
|
||||
29896,
|
||||
29906,
|
||||
29889,
|
||||
29974,
|
||||
29946,
|
||||
29896,
|
||||
29946,
|
||||
13,
|
||||
13,
|
||||
22550,
|
||||
4136,
|
||||
16492,
|
||||
8439,
|
||||
29871,
|
||||
29941,
|
||||
]
|
||||
|
||||
(
|
||||
tree_mask,
|
||||
position,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
draft_tokens,
|
||||
) = build_tree_kernel_efficient(
|
||||
verified_id=verified_id,
|
||||
score_list=score_list,
|
||||
token_list=token_list,
|
||||
parents_list=parents_list,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_sum=torch.sum(seq_lens).item(),
|
||||
topk=topk,
|
||||
spec_steps=depth,
|
||||
num_verify_tokens=num_draft_token,
|
||||
)
|
||||
|
||||
first_rank_print("=========== build tree kernel efficient ==========")
|
||||
# first_rank_print(f"{tree_mask=}", flush=True)
|
||||
first_rank_print(f"{position=}", flush=True)
|
||||
first_rank_print(f"{retrive_index=}", flush=True)
|
||||
first_rank_print(f"{retrive_next_token=}", flush=True)
|
||||
first_rank_print(f"{retrive_next_sibling=}", flush=True)
|
||||
first_rank_print(f"{draft_tokens=}", flush=True)
|
||||
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
|
||||
assert retrive_index.tolist() == [
|
||||
[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
[8, 9, 10, 11, 12, 13, 14, 15],
|
||||
]
|
||||
assert retrive_next_token.tolist() == [
|
||||
[1, 3, 4, 5, 6, 7, -1, -1],
|
||||
[1, 2, -1, 6, -1, -1, 7, -1],
|
||||
]
|
||||
assert retrive_next_sibling.tolist() == [
|
||||
[-1, 2, -1, -1, -1, -1, -1, -1],
|
||||
[-1, -1, 3, 4, 5, -1, -1, -1],
|
||||
]
|
||||
assert draft_tokens.tolist() == [
|
||||
29974,
|
||||
29896,
|
||||
29906,
|
||||
29889,
|
||||
29974,
|
||||
29946,
|
||||
29896,
|
||||
29946,
|
||||
13,
|
||||
13,
|
||||
22550,
|
||||
4136,
|
||||
16492,
|
||||
8439,
|
||||
29871,
|
||||
29941,
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_build_tree_kernel_efficient()
|
||||
test_build_tree_kernel()
|
||||
|
||||
@@ -258,39 +258,77 @@ class EagleVerifyInput:
|
||||
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
|
||||
|
||||
def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
|
||||
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
||||
predict = torch.cat(
|
||||
[predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
|
||||
)
|
||||
draft_token = torch.cat(
|
||||
[self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
|
||||
[self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
|
||||
dim=-1,
|
||||
)
|
||||
target_predict = predict[self.retrive_index]
|
||||
candidates = draft_token[self.retrive_index]
|
||||
# logits = logits_output.next_token_logits[self.retrive_index]
|
||||
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
|
||||
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
|
||||
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
|
||||
bs = self.retrive_cum_len.numel() - 1
|
||||
if batch.sampling_info.is_all_greedy:
|
||||
# temp == 0
|
||||
bs = self.retrive_cum_len.numel() - 1
|
||||
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
||||
predict = torch.cat(
|
||||
[predict, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1
|
||||
)
|
||||
target_predict = predict[self.retrive_index]
|
||||
# logits = logits_output.next_token_logits[self.retrive_index]
|
||||
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
|
||||
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
|
||||
|
||||
max_draft_len = self.retrive_index.shape[-1]
|
||||
accept_index = torch.full(
|
||||
(bs, max_draft_len), -1, dtype=torch.long, device="cuda"
|
||||
)
|
||||
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
|
||||
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
|
||||
eagle_verify_retrive[(bs,)](
|
||||
self.retrive_index.contiguous(),
|
||||
accept_mask.contiguous(),
|
||||
self.retrive_cum_len,
|
||||
accept_index,
|
||||
accept_length,
|
||||
extract_index,
|
||||
max_draft_len,
|
||||
self.draft_token_num,
|
||||
triton.next_power_of_2(max_draft_len),
|
||||
)
|
||||
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
|
||||
max_draft_len = self.retrive_index.shape[-1]
|
||||
accept_index = torch.full(
|
||||
(bs, max_draft_len), -1, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
|
||||
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
|
||||
eagle_verify_retrive[(bs,)](
|
||||
self.retrive_index.contiguous(),
|
||||
accept_mask.contiguous(),
|
||||
self.retrive_cum_len,
|
||||
accept_index,
|
||||
accept_length,
|
||||
extract_index,
|
||||
max_draft_len,
|
||||
self.draft_token_num,
|
||||
triton.next_power_of_2(max_draft_len),
|
||||
)
|
||||
else:
|
||||
# temp > 0
|
||||
bs = self.retrive_index.shape[0]
|
||||
predict_shape = list(logits_output.next_token_logits.shape)[:-1]
|
||||
predict_shape[-1] += 1
|
||||
target_logits = logits_output.next_token_logits[self.retrive_index]
|
||||
predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")
|
||||
accept_index = torch.full(
|
||||
(bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
||||
expanded_temperature = batch.sampling_info.temperatures.unsqueeze(1)
|
||||
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
|
||||
draft_probs = torch.full_like(
|
||||
target_probs, 0, dtype=torch.float32, device="cuda"
|
||||
)
|
||||
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
|
||||
tree_speculative_sampling_target_only(
|
||||
predicts=predict, # mutable
|
||||
accept_index=accept_index, # mutable
|
||||
accept_token_num=accept_length, # mutable
|
||||
candidates=candidates.to(torch.int32),
|
||||
retrive_index=self.retrive_index.to(torch.int32),
|
||||
retrive_next_token=self.retrive_next_token.to(torch.int32),
|
||||
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
|
||||
uniform_samples=coins,
|
||||
target_probs=target_probs,
|
||||
draft_probs=draft_probs,
|
||||
threshold_single=global_server_args_dict[
|
||||
"speculative_accept_threshold_single"
|
||||
],
|
||||
threshold_acc=global_server_args_dict[
|
||||
"speculative_accept_threshold_acc"
|
||||
],
|
||||
deterministic=True,
|
||||
)
|
||||
|
||||
new_accept_index = []
|
||||
unfinished_index = []
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "sgl-kernel"
|
||||
version = "0.0.3.post1"
|
||||
version = "0.0.3.post2"
|
||||
description = "Kernel Library for SGLang"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
|
||||
@@ -99,6 +99,8 @@ sources = [
|
||||
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||
"src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/eagle_utils.cu",
|
||||
"src/sgl-kernel/csrc/speculative_sampling.cu",
|
||||
"3rdparty/flashinfer/csrc/activation.cu",
|
||||
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
||||
"3rdparty/flashinfer/csrc/norm.cu",
|
||||
|
||||
@@ -10,6 +10,8 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
|
||||
from sgl_kernel.ops import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
bmm_fp8,
|
||||
build_tree_kernel,
|
||||
build_tree_kernel_efficient,
|
||||
custom_dispose,
|
||||
custom_reduce,
|
||||
fp8_scaled_mm,
|
||||
@@ -31,6 +33,7 @@ from sgl_kernel.ops import (
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
tree_speculative_sampling_target_only,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -57,4 +60,7 @@ __all__ = [
|
||||
"top_k_renorm_prob",
|
||||
"top_k_top_p_sampling_from_probs",
|
||||
"top_p_renorm_prob",
|
||||
"tree_speculative_sampling_target_only",
|
||||
"build_tree_kernel_efficient",
|
||||
"build_tree_kernel",
|
||||
]
|
||||
|
||||
209
sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu
Normal file
209
sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu
Normal file
@@ -0,0 +1,209 @@
|
||||
/*
|
||||
* Copyright (c) 2025 by SGLang team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||
// selected_index [bs, draft_token_num - 1]
|
||||
// verified_seq_len [bs]
|
||||
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
|
||||
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
|
||||
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
|
||||
__global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len,
|
||||
bool* tree_mask, int64_t* positions, int64_t* retrive_index,
|
||||
int64_t* retrive_next_token, int64_t* retrive_next_sibling, int topk, int depth,
|
||||
int draft_token_num) {
|
||||
int bid = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
if (tid >= draft_token_num) {
|
||||
return;
|
||||
}
|
||||
int seq_tree_idx = draft_token_num * draft_token_num * bid;
|
||||
for (int i = 0; i < bid; i++) {
|
||||
seq_tree_idx += verified_seq_len[i] * draft_token_num;
|
||||
}
|
||||
int seq_len = verified_seq_len[bid];
|
||||
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
|
||||
for (int i = 0; i < draft_token_num - 1; i++) {
|
||||
tree_mask[token_tree_idx + i] = false;
|
||||
}
|
||||
|
||||
int position = 0;
|
||||
if (tid == 0) {
|
||||
positions[bid * draft_token_num] = seq_len;
|
||||
|
||||
int retrive_index_offset = bid * draft_token_num;
|
||||
for (int i = draft_token_num - 1; i > 0; --i) {
|
||||
int current_token_idx = retrive_index_offset + i;
|
||||
retrive_index[bid * draft_token_num + i] = current_token_idx;
|
||||
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
|
||||
int parent_position = 0;
|
||||
if (parent_tb_idx > 0) {
|
||||
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||
for (; parent_position < draft_token_num; ++parent_position) {
|
||||
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
|
||||
++parent_position;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (parent_position == draft_token_num) {
|
||||
printf(
|
||||
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
|
||||
"will be dropped.");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
|
||||
retrive_next_token[bid * draft_token_num + parent_position] = i;
|
||||
} else {
|
||||
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
|
||||
retrive_next_token[bid * draft_token_num + parent_position] = i;
|
||||
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
|
||||
}
|
||||
}
|
||||
retrive_index[bid * draft_token_num] = bid * draft_token_num;
|
||||
} else {
|
||||
int cur_position = tid - 1;
|
||||
while (true) {
|
||||
position += 1;
|
||||
tree_mask[token_tree_idx + cur_position] = true;
|
||||
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
|
||||
if (parent_tb_idx == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
|
||||
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
positions[bid * draft_token_num + tid] = position + seq_len;
|
||||
}
|
||||
}
|
||||
|
||||
void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
|
||||
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
|
||||
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
|
||||
int64_t depth, int64_t draft_token_num) {
|
||||
// TODO (ying) check shape
|
||||
// TODO (ying) check type
|
||||
int bs = parent_list.size(0);
|
||||
dim3 grid(bs);
|
||||
dim3 block(draft_token_num);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
build_tree_efficient<<<grid, block, 0, stream>>>(
|
||||
static_cast<int64_t*>(parent_list.data_ptr()), static_cast<int64_t*>(selected_index.data_ptr()),
|
||||
static_cast<int32_t*>(verified_seq_len.data_ptr()), static_cast<bool*>(tree_mask.data_ptr()),
|
||||
static_cast<int64_t*>(positions.data_ptr()), static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_token.data_ptr()), static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||
int32_t(topk), int32_t(depth), int32_t(draft_token_num));
|
||||
}
|
||||
|
||||
// parent_list [bs, topk * (depth - 1) + 1]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
//
// Builds, for every draft token, its causal attention-mask row, its absolute
// position, and (for leaf tokens only) the root-to-leaf retrieval path.
// Launch configuration: one block per batch element, one thread per draft token.
__global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len, bool* tree_mask,
                           int64_t* positions, int64_t* retrive_index, int topk, int depth, int draft_token_num) {
  int bid = blockIdx.x;   // batch index
  int tid = threadIdx.x;  // draft-token index within this batch element

  if (tid >= draft_token_num) {
    return;
  }
  // Byte offset of this sequence's region in the flattened, ragged tree_mask
  // layout: each sequence contributes draft_token_num rows of
  // (verified_seq_len[i] + draft_token_num) entries.
  int seq_tree_idx = draft_token_num * draft_token_num * bid;
  for (int i = 0; i < bid; i++) {
    seq_tree_idx += verified_seq_len[i] * draft_token_num;
  }
  int seq_len = verified_seq_len[bid];
  // Start of the draft-token part of this token's mask row. The first
  // seq_len + 1 entries (verified prefix plus tree root) are left visible.
  int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
  // Hide all other draft tokens; ancestors are re-enabled in the walk below.
  for (int i = 0; i < draft_token_num - 1; i++) {
    tree_mask[token_tree_idx + i] = false;
  }

  int position = 0;
  if (tid == 0) {
    // Tree root: its position is the verified prefix length and its
    // retrieval path contains only itself.
    positions[bid * draft_token_num] = seq_len;
    retrive_index[bid * draft_token_num * (depth + 2)] = bid * draft_token_num;
    return;
  }

  // Ancestor chain of this token, collected leaf-first.
  // NOTE(review): fixed capacity assumes the tree depth is < 10 — confirm
  // against the callers' `depth` argument.
  int depends_order[10];

  int cur_position = tid - 1;
  while (true) {
    depends_order[position] = cur_position + 1;
    position += 1;
    tree_mask[token_tree_idx + cur_position] = true;  // make this ancestor visible
    // topk-table slot of the parent of the current token.
    int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
    if (parent_tb_idx == 0) {
      break;  // reached a direct child of the root
    }

    // Map the parent's topk-table slot back to its index among the selected
    // draft tokens by linear search over selected_index.
    int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
    for (cur_position = 0; cur_position < draft_token_num; cur_position++) {
      if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
        break;
      }
    }
    if (cur_position == draft_token_num) {
      // Parent was never selected: the tree is inconsistent; stop the walk.
      printf(
          "ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
          "will be dropped.");
      break;
    }
  }
  // Depth of this token within the draft tree, offset by the verified prefix.
  positions[bid * draft_token_num + tid] = position + seq_len;

  // A token is a leaf iff exactly one mask row (its own) marks it visible.
  // NOTE(review): this reads mask entries written by sibling threads without
  // a __syncthreads(); confirm the intended intra-block ordering guarantees.
  int is_leaf = 0;
  for (int i = 1; i < draft_token_num; i++) {
    if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) {
      is_leaf++;
    }
  }
  if (is_leaf == 1) {
    // Store the path root -> ... -> leaf: ancestors were collected leaf-first,
    // so write them reversed after the root entry at offset 0.
    for (int i = 0; i < position; i++) {
      retrive_index[(bid * (draft_token_num) + tid) * (depth + 2) + position - i] =
          depends_order[i] + bid * draft_token_num;
    }
    retrive_index[(bid * (draft_token_num) + tid) * (depth + 2)] = bid * draft_token_num;
  }
}
|
||||
|
||||
// Host-side launcher for the `build_tree` CUDA kernel: one block per batch
// element, one thread per draft token, enqueued on the current torch stream.
void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
                       at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
                       int64_t depth, int64_t draft_token_num) {
  // TODO (ying) check shape
  // TODO (ying) check type
  const int batch_size = parent_list.size(0);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const dim3 grid_dim(batch_size);
  const dim3 block_dim(draft_token_num);

  build_tree<<<grid_dim, block_dim, 0, stream>>>(
      static_cast<int64_t*>(parent_list.data_ptr()),
      static_cast<int64_t*>(selected_index.data_ptr()),
      static_cast<int32_t*>(verified_seq_len.data_ptr()),
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));
}
|
||||
120
sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu
Normal file
120
sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu
Normal file
@@ -0,0 +1,120 @@
|
||||
/*
|
||||
* Copyright (c) 2025 by SGLang team.
|
||||
* Copyright (c) 2025 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <speculative_sampling.cuh>
|
||||
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
// predicts: [tot_num_draft_tokens]
|
||||
// accept_index: [bs, num_spec_step]
|
||||
// accept_token_num: [bs]
|
||||
// candidates: [bs, num_draft_tokens]
|
||||
// retrive_index: [bs, num_draft_tokens]
|
||||
// retrive_next_token: [bs, num_draft_tokens]
|
||||
// retrive_next_sibling: [bs, num_draft_tokens]
|
||||
// uniform_samples: [bs, num_draft_tokens]
|
||||
// target_probs: [bs, num_draft_tokens, vocab_size]
|
||||
void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index,
|
||||
at::Tensor accept_token_num, // mutable
|
||||
at::Tensor candidates, at::Tensor retrive_index,
|
||||
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
|
||||
at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
|
||||
bool deterministic, int64_t cuda_stream = 0) {
|
||||
CHECK_INPUT(candidates);
|
||||
CHECK_INPUT(retrive_index);
|
||||
CHECK_INPUT(retrive_next_token);
|
||||
CHECK_INPUT(retrive_next_sibling);
|
||||
CHECK_INPUT(uniform_samples);
|
||||
CHECK_INPUT(target_probs);
|
||||
auto device = target_probs.device();
|
||||
CHECK_EQ(candidates.device(), device);
|
||||
CHECK_EQ(retrive_index.device(), device);
|
||||
CHECK_EQ(retrive_next_token.device(), device);
|
||||
CHECK_EQ(retrive_next_sibling.device(), device);
|
||||
CHECK_EQ(uniform_samples.device(), device);
|
||||
CHECK_EQ(target_probs.device(), device);
|
||||
CHECK_DIM(1, predicts);
|
||||
CHECK_DIM(2, accept_index);
|
||||
CHECK_DIM(1, accept_token_num);
|
||||
CHECK_DIM(2, candidates);
|
||||
CHECK_DIM(2, retrive_index);
|
||||
CHECK_DIM(2, retrive_next_token);
|
||||
CHECK_DIM(2, retrive_next_sibling);
|
||||
CHECK_DIM(2, uniform_samples);
|
||||
CHECK_DIM(3, target_probs);
|
||||
CHECK_DIM(3, draft_probs);
|
||||
unsigned int batch_size = uniform_samples.size(0);
|
||||
unsigned int num_spec_step = accept_index.size(1);
|
||||
unsigned int num_draft_tokens = candidates.size(1);
|
||||
unsigned int vocab_size = target_probs.size(2);
|
||||
CHECK_EQ(batch_size, candidates.size(0));
|
||||
CHECK_EQ(batch_size, retrive_index.size(0));
|
||||
CHECK_EQ(batch_size, retrive_next_token.size(0));
|
||||
CHECK_EQ(batch_size, retrive_next_sibling.size(0));
|
||||
CHECK_EQ(batch_size, target_probs.size(0));
|
||||
CHECK_EQ(num_draft_tokens, retrive_index.size(1));
|
||||
CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
|
||||
CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
|
||||
CHECK_EQ(num_draft_tokens, uniform_samples.size(1));
|
||||
CHECK_EQ(num_draft_tokens, target_probs.size(1));
|
||||
CHECK_EQ(vocab_size, target_probs.size(2));
|
||||
CHECK_EQ(batch_size, accept_index.size(0));
|
||||
CHECK_EQ(batch_size, accept_token_num.size(0));
|
||||
if (predicts.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
|
||||
}
|
||||
if (accept_index.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
|
||||
}
|
||||
if (accept_token_num.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
|
||||
}
|
||||
if (candidates.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
|
||||
}
|
||||
if (retrive_index.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
|
||||
}
|
||||
if (retrive_next_token.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
|
||||
}
|
||||
if (retrive_next_sibling.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
|
||||
}
|
||||
if (uniform_samples.scalar_type() != at::kFloat) {
|
||||
throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32).");
|
||||
}
|
||||
if (target_probs.scalar_type() != at::kFloat) {
|
||||
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
|
||||
}
|
||||
if (draft_probs.scalar_type() != at::kFloat) {
|
||||
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
|
||||
}
|
||||
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
|
||||
static_cast<int*>(predicts.data_ptr()), static_cast<int*>(accept_index.data_ptr()),
|
||||
static_cast<int*>(accept_token_num.data_ptr()), static_cast<int*>(candidates.data_ptr()),
|
||||
static_cast<int*>(retrive_index.data_ptr()), static_cast<int*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int*>(retrive_next_sibling.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
|
||||
static_cast<float*>(target_probs.data_ptr()), static_cast<float*>(draft_probs.data_ptr()), batch_size,
|
||||
num_spec_step, num_draft_tokens, vocab_size, deterministic, stream);
|
||||
|
||||
TORCH_CHECK(status == cudaSuccess,
|
||||
"TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
|
||||
}
|
||||
184
sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cuh
Normal file
184
sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cuh
Normal file
@@ -0,0 +1,184 @@
|
||||
/*
|
||||
* Copyright (c) 2025 by SGLang team.
|
||||
* Copyright (c) 2024-2025 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef SPECULATIVE_SAMPLING_CUH_
|
||||
#define SPECULATIVE_SAMPLING_CUH_
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <flashinfer/sampling.cuh>
|
||||
|
||||
namespace flashinfer {
|
||||
|
||||
namespace sampling {
|
||||
|
||||
using namespace cub;
|
||||
|
||||
// One thread block per batch element. Walks the draft tree with chain-based
// speculative rejection sampling against the target-model probabilities, then
// samples the first rejected token from the residual distribution
// relu(target_probs - draft_probs).
//
// The acceptance walk below depends only on `bx`, so every thread of the block
// executes it redundantly and writes identical values; threads cooperate only
// in the block-wide residual sampling at the end.
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM, BlockReduceAlgorithm REDUCE_ALGORITHM,
          uint32_t VEC_SIZE, bool DETERMINISTIC, typename DType, typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* accept_index,
                                                  IdType* accept_token_num,  // mutable
                                                  IdType* candidates, IdType* retrive_index, IdType* retrive_next_token,
                                                  IdType* retrive_next_sibling, DType* uniform_samples,
                                                  DType* target_probs, DType* draft_probs, uint32_t batch_size,
                                                  uint32_t num_speculative_tokens, uint32_t num_draft_tokens,
                                                  uint32_t d) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;

  // Dynamic shared memory reused by the block-wide reduce/scan primitives.
  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
      uint8_t smem_sampling[];
  auto& temp_storage =
      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

  // Running sum of target probabilities of the rejected siblings at the
  // current tree node (standard multi-candidate rejection sampling).
  DType prob_acc = 0.0;
  uint32_t cur_prob_offset = bx * num_draft_tokens * d;
  DType coin = uniform_samples[bx * num_draft_tokens];
  // The root (verified) token is always accepted.
  IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
  accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
  uint32_t num_accepted_tokens = 0;
  IdType cur_index = 0;

  // Descend at most num_speculative_tokens - 1 levels; at each level try the
  // children of the last accepted node in sibling order.
  for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
    cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
    while (cur_index != -1) {
      IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
      IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
      prob_acc += target_probs[cur_prob_offset + draft_token_id];

      if (coin < prob_acc) {
        // accept token
        prob_acc = 0.;
        cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;
        coin = uniform_samples[bx * num_draft_tokens + cur_index];
        predicts[last_accepted_retrive_idx] = draft_token_id;
        ++num_accepted_tokens;
        accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
        last_accepted_retrive_idx = draft_index;
        break;
      } else {
        // FIXME: leverage draft probs
        // Rejected sibling: record its mass so the residual distribution
        // below excludes it, then try the next sibling.
        draft_probs[cur_prob_offset + draft_token_id] = target_probs[cur_prob_offset + draft_token_id];
        cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
      }
    }
    if (cur_index == -1) break;  // no child accepted at this level
  }
  accept_token_num[bx] = num_accepted_tokens;

  // sample from relu(target_probs - draft_probs)
  // First pass: block-wide sum of the residual mass (the normalizer).
  DType sum_relu_q_minus_p(0);
  vec_t<DType, VEC_SIZE> q_vec, p_vec;
  DType relu_q_minus_p[VEC_SIZE];
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    q_vec.fill(DType(0));
    p_vec.fill(DType(0));
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      if (num_accepted_tokens != num_speculative_tokens - 1) {
        // there is no draft_probs for the bonus token
        p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0));
    }
    sum_relu_q_minus_p += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                              .Sum<VEC_SIZE>(relu_q_minus_p);
    __syncthreads();
  }
  if (tx == 0) {
    temp_storage.block_aggregate.value = sum_relu_q_minus_p;
  }
  // init the first rejected token to (d - 1)
  temp_storage.sampled_id = d - 1;
  __syncthreads();
  sum_relu_q_minus_p = temp_storage.block_aggregate.value;
  // Inverse-CDF threshold for sampling from the (unnormalized) residual.
  DType u = coin * sum_relu_q_minus_p;

  // Second pass: walk the residual CDF until it crosses `u`.
  DType aggregate_relu_q_minus_p(0);
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    q_vec.fill(DType(0));
    p_vec.fill(DType(0));
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      if (num_accepted_tokens != num_speculative_tokens - 1) {
        // there is no draft_probs for the bonus token
        p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
    }

    vec_t<DType, VEC_SIZE> relu_q_minus_p_vec;
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
    }

    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
        i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
    if (aggregate_relu_q_minus_p > u) {
      break;
    }
  }
  __syncthreads();
  // set the first rejected token
  predicts[last_accepted_retrive_idx] = temp_storage.sampled_id;
  // value at not used indices are undefined
}
|
||||
|
||||
// Host-side dispatcher: picks the vector width and determinism variant at
// runtime, sets the kernel's dynamic shared-memory limit, and launches one
// block per batch element. Returns cudaSuccess or the first CUDA error
// (FLASHINFER_CUDA_CALL returns early on failure).
template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* output_token_ids,
                                              IdType* output_accepted_token_num,  // mutable
                                              IdType* candidates, IdType* retrive_index, IdType* retrive_next_token,
                                              IdType* retrive_next_sibling, DType* uniform_samples, DType* target_probs,
                                              DType* draft_probs, uint32_t batch_size, uint32_t num_speculative_tokens,
                                              uint32_t num_draft_tokens, uint32_t d, bool deterministic,
                                              cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  // Widest 16-byte-aligned vector load that still divides the vocab size d.
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);

  const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  // Kernel arguments in declaration order, as required by cudaLaunchKernel.
  void* args[] = {&predicts,
                  &output_token_ids,
                  &output_accepted_token_num,
                  &candidates,
                  &retrive_index,
                  &retrive_next_token,
                  &retrive_next_sibling,
                  &uniform_samples,
                  &target_probs,
                  &draft_probs,
                  &batch_size,
                  &num_speculative_tokens,
                  &num_draft_tokens,
                  &d};
  DISPATCH_ALIGNED_VEC_SIZE(
      vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
        auto kernel = TreeSpeculativeSamplingTargetOnly<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, DETERMINISTIC,
                                                        DType, IdType>;
        FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
      })});
  return cudaSuccess;
}
|
||||
|
||||
} // namespace sampling
|
||||
|
||||
} // namespace flashinfer
|
||||
|
||||
#endif // SPECULATIVE_SAMPLING_CUH_
|
||||
@@ -127,3 +127,19 @@ void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at:
|
||||
// Applies rotary position embeddings to q/k using a precomputed cos/sin
// cache; `q_rope` and `k_rope` receive the results.
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
                                      at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
                                      int64_t cuda_stream);

// Tree-based speculative-decoding verification against target-model
// probabilities; `predicts`, `accept_index` and `accept_token_num` are
// mutated in place.
void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index,
                                           at::Tensor accept_token_num,  // mutable
                                           at::Tensor candidates, at::Tensor retrive_index,
                                           at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
                                           at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
                                           bool deterministic = true, int64_t cuda_stream = 0);

// Builds EAGLE draft-tree metadata (attention mask, positions, retrieval
// paths) including explicit next-token / next-sibling link tables.
void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
                                 at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
                                 at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
                                 int64_t depth, int64_t draft_token_num);

// Builds EAGLE draft-tree metadata (attention mask, positions, retrieval
// paths) without the link tables.
void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
                       at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
                       int64_t depth, int64_t draft_token_num);
|
||||
|
||||
@@ -495,3 +495,87 @@ def min_p_sampling_from_probs(
|
||||
return _min_p_sampling_from_probs_internal(
|
||||
probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
|
||||
)
|
||||
|
||||
|
||||
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    """Verify a tree of draft tokens against target-model probabilities.

    Thin wrapper over the ``sgl_kernels`` CUDA op. Results are written in
    place into ``predicts``, ``accept_index`` and ``accept_token_num``.
    """
    with predicts.device as device:
        op_args = (
            predicts,
            accept_index,
            accept_token_num,
            candidates,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            uniform_samples,
            target_probs,
            draft_probs,
            deterministic,
            _get_cuda_stream(device),
        )
        torch.ops.sgl_kernels.tree_speculative_sampling_target_only(*op_args)
|
||||
|
||||
|
||||
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build EAGLE draft-tree metadata on the GPU (efficient variant).

    Thin wrapper over the ``sgl_kernels`` CUDA op; ``tree_mask``,
    ``positions``, ``retrive_index``, ``retrive_next_token`` and
    ``retrive_next_sibling`` are filled in place.
    """
    with parent_list.device as device:
        op_args = (
            parent_list,
            selected_index,
            verified_seq_len,
            tree_mask,
            positions,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            topk,
            depth,
            draft_token_num,
        )
        torch.ops.sgl_kernels.build_tree_kernel_efficient(*op_args)
|
||||
|
||||
|
||||
def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build EAGLE draft-tree metadata on the GPU.

    Thin wrapper over the ``sgl_kernels`` CUDA op; ``tree_mask``,
    ``positions`` and ``retrive_index`` are filled in place.
    """
    with parent_list.device as device:
        op_args = (
            parent_list,
            selected_index,
            verified_seq_len,
            tree_mask,
            positions,
            retrive_index,
            topk,
            depth,
            draft_token_num,
        )
        torch.ops.sgl_kernels.build_tree_kernel(*op_args)
|
||||
|
||||
@@ -130,6 +130,29 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
|
||||
// tree spec decode
|
||||
m.def(
|
||||
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||
"Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
|
||||
"bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
|
||||
|
||||
// eagle build tree
|
||||
m.def(
|
||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! "
|
||||
"retrive_next_sibling, "
|
||||
"int topk, int depth, int draft_token_num) -> ()");
|
||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||
|
||||
// eagle build tree
|
||||
m.def(
|
||||
"build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
|
||||
"int topk, int depth, int draft_token_num) -> ()");
|
||||
m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_kernels)
|
||||
|
||||
104
sgl-kernel/tests/test_speculative_sampling.py
Normal file
104
sgl-kernel/tests/test_speculative_sampling.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import tree_speculative_sampling_target_only
|
||||
|
||||
|
||||
def test_tree_speculative_sampling_target_only():
    """End-to-end check of ``tree_speculative_sampling_target_only`` on a
    hand-built two-sequence draft tree.

    With near-zero temperature the target distribution is effectively
    one-hot, so the accepted path and final predictions are fully determined
    by where the ``10`` logits are placed.
    """
    candidates = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [7, 8, 9, 10, 11, 12],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    retrive_index = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [6, 7, 8, 9, 10, 11],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    retrive_next_token = torch.tensor(
        [
            [1, 2, -1, 4, 5, -1],
            [4, 2, 3, -1, 5, -1],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    retrive_next_sibling = torch.tensor(
        [
            [-1, 3, -1, -1, -1, -1],
            [-1, -1, -1, -1, 1, -1],
        ],
        dtype=torch.int32,
        device="cuda",
    )

    target_logits = torch.zeros((2, 6, 20), dtype=torch.float32, device="cuda")
    target_logits[0, 0, 3] = 10
    target_logits[0, 3, 4] = 10
    target_logits[0, 4, 5] = 10
    target_logits[1, 0, 11] = 10
    target_logits[1, 4, 12] = 10
    # Every (batch, position) row without an explicit winner defaults to token 18.
    needs_default = target_logits.max(dim=-1).values < 10
    target_logits[needs_default, 18] = 10

    temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device="cuda")
    predict_shape = (12,)

    bs, num_draft_tokens = candidates.shape
    num_spec_step = 4

    # Output buffers; mutated in place by the kernel.
    predicts = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")
    accept_index = torch.full(
        (bs, num_spec_step), -1, dtype=torch.int32, device="cuda"
    )
    accept_token_num = torch.zeros(bs, dtype=torch.int32, device="cuda")

    expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1)
    target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
    draft_probs = torch.zeros_like(target_probs)

    coins = torch.rand(bs, num_draft_tokens, device="cuda").to(torch.float32)
    print(f"{candidates=}")
    print(f"{retrive_index=}")
    print(f"{retrive_next_token=}")
    print(f"{retrive_next_sibling=}")
    print(f"{coins=}")

    tree_speculative_sampling_target_only(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        uniform_samples=coins,
        target_probs=target_probs,
        draft_probs=draft_probs,
        deterministic=True,
    )

    print(f"{predicts=}")
    print(f"{accept_index=}")
    print(f"{accept_token_num=}")

    assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
    assert accept_index.tolist() == [
        [0, 3, 4, 5],
        [6, 10, 11, -1],
    ]
    assert accept_token_num.tolist() == [3, 2]


if __name__ == "__main__":
    test_tree_speculative_sampling_target_only()
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.0.3.post1"
|
||||
__version__ = "0.0.3.post2"
|
||||
|
||||
Reference in New Issue
Block a user