[Feature] Speculative decoding support lookahead (#9873)
Co-authored-by: a4zhangfei <a4zhangfei@qq.com> Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
@@ -318,6 +318,7 @@ set(SOURCES
|
||||
"csrc/kvcacheio/transfer.cu"
|
||||
|
||||
"csrc/speculative/eagle_utils.cu"
|
||||
"csrc/speculative/lookahead_utils.cu"
|
||||
"csrc/speculative/packbit.cu"
|
||||
"csrc/speculative/speculative_sampling.cu"
|
||||
|
||||
|
||||
@@ -291,6 +291,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"Tensor target_predict, int cuda_stream) -> ()");
|
||||
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
|
||||
|
||||
// Schema for the lookahead index-reconstruction op. `positions` and the three
// `retrive_*` tensors are written in place by the CUDA implementation, so they are
// annotated as mutable (`Tensor!`), matching how build_tree_kernel_efficient declares
// its output buffers. `tree_mask` and `verified_seq_len` are only read.
m.def(
    "reconstruct_indices_from_tree_mask(Tensor tree_mask, Tensor verified_seq_len, Tensor! positions, "
    "Tensor! retrive_index, Tensor! retrive_next_token, Tensor! retrive_next_sibling, "
    "int batch_size, int draft_token_num) -> ()");
m.impl("reconstruct_indices_from_tree_mask", torch::kCUDA, &reconstruct_indices_from_tree_mask);
|
||||
|
||||
m.def(
|
||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
||||
|
||||
105
sgl-kernel/csrc/speculative/lookahead_utils.cu
Normal file
105
sgl-kernel/csrc/speculative/lookahead_utils.cu
Normal file
@@ -0,0 +1,105 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include "pytorch_extension_utils.h"
|
||||
#else
|
||||
#include "pytorch_extension_utils_rocm.h"
|
||||
#endif
|
||||
|
||||
// Rebuilds per-token tree traversal indices from a dense draft-tree mask.
//
// Launch layout: one block per request (blockIdx.x < batch_size) and one thread per
// draft token (threadIdx.x < draft_token_num); surplus blocks/threads return
// immediately. The kernel uses no shared memory and no intra-block synchronization.
//
// Buffers (flat, densely packed):
//   tree_mask:            [bs * draft_token_num * draft_token_num]  (read only)
//   verified_seq_len:     [bs]                                      (read only)
//   positions:            [bs * draft_token_num]                    (written)
//   retrive_index:        [bs, draft_token_num]                     (written)
//   retrive_next_token:   [bs, draft_token_num]                     (written)
//   retrive_next_sibling: [bs, draft_token_num]                     (written)
//
// With D = draft_token_num, entry (row, col) of request `bid` is read from
// tree_mask[bid * D * D + row * D + col]. For the token handled by thread `tid`:
//   - depth  = number of set entries in row `tid` at columns < tid
//   - parent = largest column < tid that is set in row `tid` (-1 if none)
//   - positions            = depth + verified_seq_len[bid]
//   - retrive_index        = flat token index bid * D + tid
//   - retrive_next_token   = first row > tid whose column `tid` is set, else -1
//   - retrive_next_sibling = first row > tid that has column `parent` set and no set
//                            entry strictly between columns `parent` and that row's
//                            own index, else -1 (always -1 when parent == -1)
// NOTE(review): this reads the mask as "(row, col) set => col is an ancestor of row",
// with ancestors stored at lower indices than descendants -- confirm with the producer.
__global__ void reconstructIndicesFromTreeMask(
    const bool* tree_mask,
    const int64_t* verified_seq_len,
    int64_t* positions,
    int64_t* retrive_index,
    int64_t* retrive_next_token,
    int64_t* retrive_next_sibling,
    int batch_size,
    int draft_token_num) {
  const int bid = blockIdx.x;
  const int tid = threadIdx.x;

  if (bid >= batch_size || tid >= draft_token_num) {
    return;
  }

  // All flat offsets are computed in 64 bits: bid * D * D can exceed INT_MAX for
  // large batch / draft sizes and would silently wrap with 32-bit arithmetic.
  const int64_t num_tokens = draft_token_num;
  // First flat token slot of this request: [bid * D, (bid + 1) * D)
  const int64_t token_base = bid * num_tokens;
  // Start of this request's D x D mask: [bid * D * D, (bid + 1) * D * D)
  const bool* mask = tree_mask + token_base * num_tokens;
  const bool* own_row = mask + tid * num_tokens;

  // Scan this token's row from right to left so the first hit is the nearest
  // (highest-index) set column, which is taken as the parent.
  int depth = 0;
  int parent_idx = -1;
  for (int col = tid - 1; col >= 0; --col) {
    if (own_row[col]) {
      ++depth;
      if (parent_idx == -1) {
        parent_idx = col;
      }
    }
  }
  retrive_index[token_base + tid] = token_base + tid;
  positions[token_base + tid] = depth + verified_seq_len[bid];

  // First later token whose row has this token's column set.
  int next_token_idx = -1;
  for (int row = tid + 1; row < draft_token_num; ++row) {
    if (mask[row * num_tokens + tid]) {
      next_token_idx = row;
      break;
    }
  }
  retrive_next_token[token_base + tid] = next_token_idx;

  // First later token that shares this token's parent: its row must have the parent
  // column set and nothing set between the parent column and its own diagonal entry.
  int next_sibling_idx = -1;
  if (parent_idx != -1) {
    for (int row = tid + 1; row < draft_token_num; ++row) {
      const bool* candidate_row = mask + row * num_tokens;
      if (!candidate_row[parent_idx]) {
        continue;
      }
      bool is_sibling = true;
      for (int col = parent_idx + 1; col < row; ++col) {
        if (candidate_row[col]) {
          is_sibling = false;
          break;
        }
      }
      if (is_sibling) {
        next_sibling_idx = row;
        break;
      }
    }
  }
  retrive_next_sibling[token_base + tid] = next_sibling_idx;
}
|
||||
|
||||
// Host entry point: validates the buffers and launches reconstructIndicesFromTreeMask
// on the current CUDA stream with one block per request and one thread per draft token.
//
// Inputs (read):
//   tree_mask        - bool,  at least batch_size * draft_token_num^2 elements
//   verified_seq_len - int64, at least batch_size elements
// Outputs (written in place):
//   positions, retrive_index, retrive_next_token, retrive_next_sibling
//                    - int64, at least batch_size * draft_token_num elements each
//
// All tensors must be contiguous CUDA tensors on the same device, because the kernel
// indexes them as flat arrays. draft_token_num must not exceed the device's maximum
// threads per block. A zero batch_size or draft_token_num is a no-op.
// The launch is asynchronous; only launch-configuration errors are reported here.
void reconstruct_indices_from_tree_mask(
    at::Tensor tree_mask,
    at::Tensor verified_seq_len,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t batch_size,
    int64_t draft_token_num) {
  TORCH_CHECK(batch_size >= 0, "batch_size must be non-negative, got ", batch_size);
  TORCH_CHECK(draft_token_num >= 0, "draft_token_num must be non-negative, got ", draft_token_num);
  if (batch_size == 0 || draft_token_num == 0) {
    // Nothing to compute; a zero-sized grid or block would be an invalid launch.
    return;
  }

  // gridDim.x is limited to 2^31 - 1 and the kernel takes both sizes as int.
  constexpr int64_t kMaxGridDimX = 2147483647;
  TORCH_CHECK(batch_size <= kMaxGridDimX, "batch_size ", batch_size, " exceeds the maximum grid size");
  const int64_t max_threads = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
  TORCH_CHECK(
      draft_token_num <= max_threads,
      "draft_token_num ",
      draft_token_num,
      " exceeds the maximum threads per block (",
      max_threads,
      ")");

  const auto check_buffer = [&](const at::Tensor& t, const char* name, at::ScalarType dtype, int64_t min_numel) {
    TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
    TORCH_CHECK(t.device() == tree_mask.device(), name, " must be on the same device as tree_mask");
    TORCH_CHECK(t.scalar_type() == dtype, name, " must have dtype ", dtype, ", got ", t.scalar_type());
    TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
    TORCH_CHECK(t.numel() >= min_numel, name, " must have at least ", min_numel, " elements, got ", t.numel());
  };

  const int64_t tokens_total = batch_size * draft_token_num;
  check_buffer(tree_mask, "tree_mask", at::kBool, tokens_total * draft_token_num);
  check_buffer(verified_seq_len, "verified_seq_len", at::kLong, batch_size);
  check_buffer(positions, "positions", at::kLong, tokens_total);
  check_buffer(retrive_index, "retrive_index", at::kLong, tokens_total);
  check_buffer(retrive_next_token, "retrive_next_token", at::kLong, tokens_total);
  check_buffer(retrive_next_sibling, "retrive_next_sibling", at::kLong, tokens_total);

  const dim3 grid(static_cast<unsigned int>(batch_size));
  const dim3 block(static_cast<unsigned int>(draft_token_num));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  reconstructIndicesFromTreeMask<<<grid, block, 0, stream>>>(
      tree_mask.data_ptr<bool>(),
      verified_seq_len.data_ptr<int64_t>(),
      positions.data_ptr<int64_t>(),
      retrive_index.data_ptr<int64_t>(),
      retrive_next_token.data_ptr<int64_t>(),
      retrive_next_sibling.data_ptr<int64_t>(),
      static_cast<int>(batch_size),
      static_cast<int>(draft_token_num));
  // Surface invalid launch configurations immediately instead of at a later CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
|
||||
@@ -457,6 +457,16 @@ void verify_tree_greedy(
|
||||
at::Tensor target_predict,
|
||||
int64_t cuda_stream = 0);
|
||||
|
||||
// Declaration of the op registered as "reconstruct_indices_from_tree_mask".
// Reads tree_mask and verified_seq_len and fills the four tensors marked
// "mutable" in place; batch_size and draft_token_num give the launch extents.
void reconstruct_indices_from_tree_mask(
|
||||
at::Tensor tree_mask,
|
||||
at::Tensor verified_seq_len,
|
||||
at::Tensor positions, // mutable
|
||||
at::Tensor retrive_index, // mutable
|
||||
at::Tensor retrive_next_token, // mutable
|
||||
at::Tensor retrive_next_sibling, // mutable
|
||||
int64_t batch_size,
|
||||
int64_t draft_token_num);
|
||||
|
||||
void build_tree_kernel_efficient(
|
||||
at::Tensor parent_list,
|
||||
at::Tensor selected_index,
|
||||
|
||||
@@ -126,6 +126,7 @@ from sgl_kernel.sampling import (
|
||||
)
|
||||
from sgl_kernel.speculative import (
|
||||
build_tree_kernel_efficient,
|
||||
reconstruct_indices_from_tree_mask,
|
||||
segment_packbits,
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
|
||||
@@ -90,6 +90,28 @@ def build_tree_kernel_efficient(
|
||||
)
|
||||
|
||||
|
||||
def reconstruct_indices_from_tree_mask(
    tree_mask: torch.Tensor,
    verified_seq_len: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    batch_size: int,
    draft_token_num: int,
) -> None:
    """Forward all arguments to the ``sgl_kernel`` op of the same name.

    The op fills ``positions``, ``retrive_index``, ``retrive_next_token`` and
    ``retrive_next_sibling`` in place; nothing is returned.

    Args:
        tree_mask: Draft-tree mask, read by the op.
        verified_seq_len: Per-request verified sequence lengths, read by the op.
        positions: Output buffer, written in place.
        retrive_index: Output buffer, written in place.
        retrive_next_token: Output buffer, written in place.
        retrive_next_sibling: Output buffer, written in place.
        batch_size: Number of requests.
        draft_token_num: Number of draft tokens per request.
    """
    op_args = (
        tree_mask,
        verified_seq_len,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        batch_size,
        draft_token_num,
    )
    # Arguments are passed positionally, in the order declared by the op schema.
    torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask.default(*op_args)
|
||||
|
||||
|
||||
def segment_packbits(
|
||||
x: torch.Tensor,
|
||||
input_indptr: torch.Tensor,
|
||||
|
||||
76
sgl-kernel/tests/speculative/test_lookahead_utils.py
Normal file
76
sgl-kernel/tests/speculative/test_lookahead_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import reconstruct_indices_from_tree_mask
|
||||
|
||||
|
||||
def test_reconstruct_indices_from_tree_mask():
    """Check the reconstructed indices for one request with a 4-token draft tree.

    The mask rows below encode: token 0 has no earlier entries set, tokens 1 and 2
    have only column 0 set, and token 3 has columns 0 and 2 set.
    """
    bs = 1
    num_branch_token = 4
    seq_lens = torch.tensor([12], device="cuda", dtype=torch.int64)

    def _filled_with_minus_one():
        # Output buffers start at -1 so untouched slots would be visible.
        return torch.full((bs, num_branch_token), -1, device="cuda", dtype=torch.int64)

    retrive_index = _filled_with_minus_one()
    retrive_next_token = _filled_with_minus_one()
    retrive_next_sibling = _filled_with_minus_one()
    positions = torch.empty(bs * num_branch_token, device="cuda", dtype=torch.int64)

    # One row per draft token; flattened row-major before being handed to the op.
    mask_rows = [
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 1, 0],
        [1, 0, 1, 1],
    ]
    flat_mask = [bit for row in mask_rows for bit in row]
    tree_mask = torch.tensor(flat_mask, device="cuda", dtype=torch.int32).to(torch.bool)

    reconstruct_indices_from_tree_mask(
        tree_mask,
        seq_lens,
        positions,  # mutable
        retrive_index,  # mutable
        retrive_next_token,  # mutable
        retrive_next_sibling,  # mutable
        bs,
        num_branch_token,
    )

    expected_index = [[0, 1, 2, 3]]
    expected_next_token = [[1, -1, 3, -1]]
    expected_next_sibling = [[-1, 2, -1, -1]]
    expected_positions = [12, 13, 13, 14]

    assert retrive_index.tolist() == expected_index, f"{retrive_index=}"
    assert retrive_next_token.tolist() == expected_next_token, f"{retrive_next_token=}"
    assert (
        retrive_next_sibling.tolist() == expected_next_sibling
    ), f"{retrive_next_sibling=}"
    assert positions.tolist() == expected_positions, f"{positions=}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # pytest collects and runs the test in this file; calling the test function
    # directly beforehand would execute it twice and bypass pytest's reporting.
    pytest.main([__file__])
|
||||
Reference in New Issue
Block a user