[Feature] Speculative decoding support lookahead (#9873)

Co-authored-by: a4zhangfei <a4zhangfei@qq.com>
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
Zhihao Zhang
2025-09-19 07:42:41 +08:00
committed by GitHub
parent 2a2ff9a840
commit e7bc600304
30 changed files with 2058 additions and 32 deletions

View File

@@ -318,6 +318,7 @@ set(SOURCES
"csrc/kvcacheio/transfer.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/lookahead_utils.cu"
"csrc/speculative/packbit.cu"
"csrc/speculative/speculative_sampling.cu"

View File

@@ -291,6 +291,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor target_predict, int cuda_stream) -> ()");
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
// positions / retrive_index / retrive_next_token / retrive_next_sibling are
// written in place by the kernel, so mark them `Tensor!` (mutable) in the
// schema — matching the build_tree_kernel_efficient registration below.
m.def(
    "reconstruct_indices_from_tree_mask(Tensor tree_mask, Tensor verified_seq_len, Tensor! positions, "
    "Tensor! retrive_index, Tensor! retrive_next_token, Tensor! retrive_next_sibling, "
    "int batch_size, int draft_token_num) -> ()");
m.impl("reconstruct_indices_from_tree_mask", torch::kCUDA, &reconstruct_indices_from_tree_mask);
m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "

View File

@@ -0,0 +1,105 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#endif
// tree_mask: [bs * draft_token_num * draft_token_num]
// verified_seq_len: [bs]
// positions: [bs * draft_token_num]
// retrive_index: [bs, draft_token_num]
// retrive_next_token: [bs, draft_token_num]
// retrive_next_sibling: [bs, draft_token_num]
// Reconstructs per-token tree linkage arrays from a flattened ancestor mask.
//
// Launch config: grid = (batch_size), block = (draft_token_num) — one thread
// per draft token of one request.
//
// tree_mask:            [bs * draft_token_num * draft_token_num]; within each
//                       request's slice, row i marks the ancestors of token i
//                       (self bit included)
// verified_seq_len:     [bs] length of the already-verified prefix
// positions (out):      [bs * draft_token_num] absolute position of each token
// retrive_index (out):  [bs, draft_token_num] flat index of each token
// retrive_next_token (out):   [bs, draft_token_num] first child, or -1
// retrive_next_sibling (out): [bs, draft_token_num] next token with the same
//                             parent, or -1
__global__ void reconstructIndicesFromTreeMask(
    const bool* __restrict__ tree_mask,            // read-only: enable RO cache
    const int64_t* __restrict__ verified_seq_len,  // read-only
    int64_t* __restrict__ positions,
    int64_t* __restrict__ retrive_index,
    int64_t* __restrict__ retrive_next_token,
    int64_t* __restrict__ retrive_next_sibling,
    int batch_size,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;
  if (bid >= batch_size || tid >= draft_token_num) {
    return;
  }
  int base_offset = draft_token_num * draft_token_num;
  // token_idx: [bid * draft_token_num, (bid + 1) * draft_token_num)
  int token_idx = bid * draft_token_num;
  // tree_mask_idx: [bid * base_offset, (bid + 1) * base_offset)
  int tree_mask_offset = bid * base_offset;

  // Scan this token's mask row leftwards: every set bit below `tid` is an
  // ancestor, so the popcount is the token's depth and the highest-index
  // ancestor is its direct parent.
  int depth = 0;
  int parent_idx = -1;
  for (int i = tid - 1, start_idx = tree_mask_offset + tid * draft_token_num; i >= 0; i--) {
    if (tree_mask[start_idx + i]) {
      depth++;
      if (parent_idx == -1) {
        parent_idx = i;
      }
    }
  }
  retrive_index[token_idx + tid] = token_idx + tid;
  // Absolute position = depth in the draft tree + verified prefix length.
  positions[token_idx + tid] = depth + verified_seq_len[bid];

  // First later token that lists `tid` as an ancestor is `tid`'s first child.
  int next_token_idx = -1;
  for (int i = tid + 1; i < draft_token_num; i++) {
    if (tree_mask[tree_mask_offset + i * draft_token_num + tid]) {
      next_token_idx = i;
      break;
    }
  }
  retrive_next_token[token_idx + tid] = next_token_idx;

  // Next sibling: the first later token whose *closest* ancestor is also
  // `parent_idx`, i.e. its row has the parent bit set and no ancestor bit
  // strictly between parent_idx and itself. The root (parent_idx == -1) has
  // no siblings.
  int next_sibling_idx = -1;
  if (parent_idx != -1) {
    for (int i = tid + 1; i < draft_token_num; i++) {
      int start_idx = tree_mask_offset + i * draft_token_num + parent_idx;
      if (tree_mask[start_idx]) {
        bool is_sibling = true;
        int end_idx = tree_mask_offset + i * draft_token_num + i;
        // Any set bit between columns (parent_idx, i) means token i hangs
        // below a deeper ancestor — not a direct sibling of `tid`.
        for (int j = start_idx + 1; j < end_idx; ++j) {
          if (tree_mask[j]) {
            is_sibling = false;
            break;
          }
        }
        if (is_sibling) {
          next_sibling_idx = i;
          break;
        }
      }
    }
  }
  retrive_next_sibling[token_idx + tid] = next_sibling_idx;
}
// Host launcher: validates inputs, then runs one block per request with one
// thread per draft token on the current CUDA stream. The four retrive_*/
// positions tensors are written in place.
void reconstruct_indices_from_tree_mask(
    at::Tensor tree_mask,
    at::Tensor verified_seq_len,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t batch_size,
    int64_t draft_token_num) {
  TORCH_CHECK(
      tree_mask.is_cuda() && verified_seq_len.is_cuda() && positions.is_cuda() && retrive_index.is_cuda() &&
          retrive_next_token.is_cuda() && retrive_next_sibling.is_cuda(),
      "reconstruct_indices_from_tree_mask: all tensors must be CUDA tensors");
  TORCH_CHECK(tree_mask.scalar_type() == at::kBool, "reconstruct_indices_from_tree_mask: tree_mask must be bool");
  TORCH_CHECK(
      verified_seq_len.scalar_type() == at::kLong && positions.scalar_type() == at::kLong &&
          retrive_index.scalar_type() == at::kLong && retrive_next_token.scalar_type() == at::kLong &&
          retrive_next_sibling.scalar_type() == at::kLong,
      "reconstruct_indices_from_tree_mask: index/position tensors must be int64");
  // draft_token_num becomes blockDim.x; exceeding the 1024-thread limit would
  // otherwise fail the launch silently without the error check below.
  TORCH_CHECK(
      draft_token_num > 0 && draft_token_num <= 1024,
      "reconstruct_indices_from_tree_mask: draft_token_num must be in (0, 1024], got ",
      draft_token_num);
  if (batch_size == 0) {
    return;  // empty batch: launching a zero-sized grid is an error
  }
  dim3 grid(batch_size);
  dim3 block(draft_token_num);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  reconstructIndicesFromTreeMask<<<grid, block, 0, stream>>>(
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(verified_seq_len.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      static_cast<int64_t*>(retrive_next_token.data_ptr()),
      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
      int(batch_size),
      int(draft_token_num));
  // Surface launch-configuration errors immediately (kernel launches do not
  // return a status themselves).
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(
      err == cudaSuccess, "reconstruct_indices_from_tree_mask launch failed: ", cudaGetErrorString(err));
}

View File

@@ -457,6 +457,16 @@ void verify_tree_greedy(
at::Tensor target_predict,
int64_t cuda_stream = 0);
// Rebuilds positions and the retrive_index / retrive_next_token /
// retrive_next_sibling linkage arrays in place from a flattened draft-token
// tree mask (CUDA implementation; registered under torch::kCUDA).
void reconstruct_indices_from_tree_mask(
    at::Tensor tree_mask,
    at::Tensor verified_seq_len,
    at::Tensor positions,  // mutable
    at::Tensor retrive_index,  // mutable
    at::Tensor retrive_next_token,  // mutable
    at::Tensor retrive_next_sibling,  // mutable
    int64_t batch_size,
    int64_t draft_token_num);
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,

View File

@@ -126,6 +126,7 @@ from sgl_kernel.sampling import (
)
from sgl_kernel.speculative import (
build_tree_kernel_efficient,
reconstruct_indices_from_tree_mask,
segment_packbits,
tree_speculative_sampling_target_only,
verify_tree_greedy,

View File

@@ -90,6 +90,28 @@ def build_tree_kernel_efficient(
)
def reconstruct_indices_from_tree_mask(
    tree_mask: torch.Tensor,
    verified_seq_len: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    batch_size: int,
    draft_token_num: int,
) -> None:
    """Rebuild per-token tree linkage arrays from a flattened ancestor mask.

    Thin wrapper around the CUDA op
    ``torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask``; all work
    happens in the kernel and the output tensors are written in place.

    Args:
        tree_mask: Bool tensor with ``batch_size * draft_token_num *
            draft_token_num`` elements; within each request's slice, row ``i``
            marks the ancestors of draft token ``i`` (self included).
        verified_seq_len: int64 ``[batch_size]`` — already-verified prefix
            length per request.
        positions: int64 output, written in place — absolute position of each
            draft token.
        retrive_index: int64 output, written in place — flat index per token.
        retrive_next_token: int64 output, written in place — first child of
            each token, or -1.
        retrive_next_sibling: int64 output, written in place — next token
            sharing the same parent, or -1.
        batch_size: number of requests.
        draft_token_num: draft tokens per request.
    """
    torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask.default(
        tree_mask,
        verified_seq_len,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        batch_size,
        draft_token_num,
    )
def segment_packbits(
x: torch.Tensor,
input_indptr: torch.Tensor,

View File

@@ -0,0 +1,76 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import reconstruct_indices_from_tree_mask
def test_reconstruct_indices_from_tree_mask():
    """Check linkage reconstruction on a single 4-token draft tree."""
    bs = 1
    num_branch_token = 4
    seq_lens = torch.tensor([12], device="cuda", dtype=torch.int64)

    # Output buffers; the kernel fills them in place.
    out_shape = (bs, num_branch_token)
    retrive_index = torch.full(out_shape, -1, device="cuda", dtype=torch.int64)
    retrive_next_token = torch.full(out_shape, -1, device="cuda", dtype=torch.int64)
    retrive_next_sibling = torch.full(out_shape, -1, device="cuda", dtype=torch.int64)
    positions = torch.empty((bs * num_branch_token), device="cuda", dtype=torch.int64)

    # Row i marks the ancestors of token i (self included). Encoded tree:
    #   token 0 (root) -> children {1, 2}; token 2 -> child {3}
    mask_rows = [
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 1, 0],
        [1, 0, 1, 1],
    ]
    tree_mask = (
        torch.tensor(mask_rows, device="cuda", dtype=torch.int32)
        .flatten()
        .to(torch.bool)
    )

    reconstruct_indices_from_tree_mask(
        tree_mask,
        seq_lens,
        positions,  # mutable
        retrive_index,  # mutable
        retrive_next_token,  # mutable
        retrive_next_sibling,  # mutable
        bs,
        num_branch_token,
    )

    # Flat indices are just 0..3 for a single request.
    assert retrive_index.tolist() == [
        [0, 1, 2, 3],
    ], f"{retrive_index=}"
    # First child of 0 is 1; 2's first child is 3; leaves have none.
    assert retrive_next_token.tolist() == [
        [1, -1, 3, -1],
    ], f"{retrive_next_token=}"
    # 1 and 2 share parent 0, so 1's next sibling is 2.
    assert retrive_next_sibling.tolist() == [
        [-1, 2, -1, -1],
    ], f"{retrive_next_sibling=}"
    # Positions = verified prefix length (12) + depth in the tree.
    assert positions.tolist() == [
        12,
        13,
        13,
        14,
    ], f"{positions=}"
if __name__ == "__main__":
    # Allow running this file directly, then hand off to pytest for reporting.
    test_reconstruct_indices_from_tree_mask()
    pytest.main([__file__])