Add treemask mode to build_eagle_tree & release sgl-kernel 0.2.3 (#7756)
Co-authored-by: Pranjal Shankhdhar <pranjal.ssh@gmail.com>
This commit is contained in:
@@ -232,7 +232,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def(
|
||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
|
||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
|
||||
"()");
|
||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||
|
||||
m.def(
|
||||
|
||||
@@ -23,6 +23,8 @@
|
||||
#include "pytorch_extension_utils_rocm.h"
|
||||
#endif
|
||||
|
||||
typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode;
|
||||
|
||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||
// selected_index [bs, draft_token_num - 1]
|
||||
// verified_seq_len [bs]
|
||||
@@ -40,7 +42,8 @@ __global__ void build_tree_efficient(
|
||||
int64_t* retrive_next_sibling,
|
||||
int topk,
|
||||
int depth,
|
||||
int draft_token_num) {
|
||||
int draft_token_num,
|
||||
int tree_mask_mode) {
|
||||
int bid = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
@@ -52,7 +55,13 @@ __global__ void build_tree_efficient(
|
||||
seq_tree_idx += verified_seq_len[i] * draft_token_num;
|
||||
}
|
||||
int seq_len = verified_seq_len[bid];
|
||||
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
|
||||
int token_tree_idx;
|
||||
if (tree_mask_mode == FULL_MASK) {
|
||||
token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
|
||||
} else {
|
||||
token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1;
|
||||
}
|
||||
tree_mask[token_tree_idx - 1] = true;
|
||||
for (int i = 0; i < draft_token_num - 1; i++) {
|
||||
tree_mask[token_tree_idx + i] = false;
|
||||
}
|
||||
@@ -124,7 +133,8 @@ void build_tree_kernel_efficient(
|
||||
at::Tensor retrive_next_sibling,
|
||||
int64_t topk,
|
||||
int64_t depth,
|
||||
int64_t draft_token_num) {
|
||||
int64_t draft_token_num,
|
||||
int64_t tree_mask_mode) {
|
||||
// TODO (ying) check shape
|
||||
// TODO (ying) check type
|
||||
int bs = parent_list.size(0);
|
||||
@@ -132,18 +142,29 @@ void build_tree_kernel_efficient(
|
||||
dim3 block(draft_token_num);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
build_tree_efficient<<<grid, block, 0, stream>>>(
|
||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||
static_cast<int64_t*>(selected_index.data_ptr()),
|
||||
static_cast<int64_t*>(verified_seq_len.data_ptr()),
|
||||
static_cast<bool*>(tree_mask.data_ptr()),
|
||||
static_cast<int64_t*>(positions.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||
int32_t(topk),
|
||||
int32_t(depth),
|
||||
int32_t(draft_token_num));
|
||||
if (tree_mask_mode == QLEN_ONLY_BITPACKING) {
|
||||
size_t num_bytes_per_item = 1;
|
||||
if (draft_token_num > 16) {
|
||||
num_bytes_per_item = 4;
|
||||
} else if (draft_token_num > 8) {
|
||||
num_bytes_per_item = 2;
|
||||
}
|
||||
throw std::runtime_error("Not implemented");
|
||||
} else {
|
||||
build_tree_efficient<<<grid, block, 0, stream>>>(
|
||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||
static_cast<int64_t*>(selected_index.data_ptr()),
|
||||
static_cast<int64_t*>(verified_seq_len.data_ptr()),
|
||||
static_cast<bool*>(tree_mask.data_ptr()),
|
||||
static_cast<int64_t*>(positions.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||
int32_t(topk),
|
||||
int32_t(depth),
|
||||
int32_t(draft_token_num),
|
||||
int32_t(tree_mask_mode));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IdType, typename IdType2>
|
||||
|
||||
@@ -78,7 +78,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
m.def(
|
||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
|
||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
|
||||
"()");
|
||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||
}
|
||||
|
||||
|
||||
@@ -374,7 +374,8 @@ void build_tree_kernel_efficient(
|
||||
at::Tensor retrive_next_sibling,
|
||||
int64_t topk,
|
||||
int64_t depth,
|
||||
int64_t draft_token_num);
|
||||
int64_t draft_token_num,
|
||||
int64_t tree_mask_mode);
|
||||
|
||||
void segment_packbits(
|
||||
at::Tensor x,
|
||||
|
||||
@@ -72,6 +72,7 @@ def build_tree_kernel_efficient(
|
||||
topk: int,
|
||||
depth: int,
|
||||
draft_token_num: int,
|
||||
tree_mask_mode: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
|
||||
parent_list,
|
||||
@@ -85,6 +86,7 @@ def build_tree_kernel_efficient(
|
||||
topk,
|
||||
depth,
|
||||
draft_token_num,
|
||||
tree_mask_mode,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user