Add treemask mode to build_eagle_tree & release sgl-kernel 0.2.3 (#7756)
Co-authored-by: Pranjal Shankhdhar <pranjal.ssh@gmail.com>
This commit is contained in:
@@ -1,10 +1,12 @@
|
|||||||
# NOTE: Please run this file to make sure the test cases are correct.
|
# NOTE: Please run this file to make sure the test cases are correct.
|
||||||
|
|
||||||
from typing import List
|
import math
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.utils import is_cuda, is_hip, rank0_log
|
from sglang.srt.utils import is_cuda, is_hip
|
||||||
|
|
||||||
if is_cuda() or is_hip():
|
if is_cuda() or is_hip():
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
|
|||||||
return parent_list, top_scores_index, draft_tokens
|
return parent_list, top_scores_index, draft_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class TreeMaskMode(IntEnum):
|
||||||
|
FULL_MASK = 0
|
||||||
|
QLEN_ONLY = 1
|
||||||
|
QLEN_ONLY_BITPACKING = 2
|
||||||
|
|
||||||
|
|
||||||
def build_tree_kernel_efficient(
|
def build_tree_kernel_efficient(
|
||||||
verified_id: torch.Tensor,
|
verified_id: torch.Tensor,
|
||||||
score_list: List[torch.Tensor],
|
score_list: List[torch.Tensor],
|
||||||
@@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
|
|||||||
topk: int,
|
topk: int,
|
||||||
spec_steps: int,
|
spec_steps: int,
|
||||||
num_verify_tokens: int,
|
num_verify_tokens: int,
|
||||||
|
tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
|
||||||
|
tree_mask_buf: Optional[torch.Tensor] = None,
|
||||||
|
position_buf: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
parent_list, top_scores_index, draft_tokens = (
|
parent_list, top_scores_index, draft_tokens = (
|
||||||
build_tree_kernel_efficient_preprocess(
|
build_tree_kernel_efficient_preprocess(
|
||||||
@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
|
|||||||
device = seq_lens.device
|
device = seq_lens.device
|
||||||
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
|
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
|
||||||
# where each row indicates the attending pattern of each draft token
|
# where each row indicates the attending pattern of each draft token
|
||||||
|
# if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
|
||||||
|
if tree_mask_buf is not None:
|
||||||
|
tree_mask = tree_mask_buf
|
||||||
|
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
|
||||||
|
tree_mask = torch.full(
|
||||||
|
(num_verify_tokens * bs * num_verify_tokens,),
|
||||||
|
True,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
|
||||||
|
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
|
||||||
|
packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
|
||||||
|
tree_mask = torch.zeros(
|
||||||
|
(num_verify_tokens * bs,),
|
||||||
|
dtype=packed_dtypes[packed_dtype_idx],
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
|
||||||
|
tree_mask = torch.full(
|
||||||
|
(
|
||||||
|
seq_lens_sum * num_verify_tokens
|
||||||
|
+ num_verify_tokens * num_verify_tokens * bs,
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
|
||||||
|
|
||||||
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
|
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
|
||||||
tree_mask = torch.full(
|
|
||||||
(
|
|
||||||
seq_lens_sum * num_verify_tokens
|
|
||||||
+ num_verify_tokens * num_verify_tokens * bs,
|
|
||||||
),
|
|
||||||
True,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
retrive_index = torch.full(
|
retrive_index = torch.full(
|
||||||
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
|
||||||
)
|
)
|
||||||
@@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
|
|||||||
# position: where each token belongs to
|
# position: where each token belongs to
|
||||||
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
|
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
|
||||||
# then, positions = [7, 8, 8, 9]
|
# then, positions = [7, 8, 8, 9]
|
||||||
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
|
if position_buf is not None:
|
||||||
|
positions = position_buf
|
||||||
|
else:
|
||||||
|
positions = torch.empty(
|
||||||
|
(bs * num_verify_tokens,), device=device, dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
sgl_build_tree_kernel_efficient(
|
sgl_build_tree_kernel_efficient(
|
||||||
parent_list,
|
parent_list,
|
||||||
@@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
|
|||||||
topk,
|
topk,
|
||||||
spec_steps,
|
spec_steps,
|
||||||
num_verify_tokens,
|
num_verify_tokens,
|
||||||
|
tree_mask_mode,
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
tree_mask,
|
tree_mask,
|
||||||
@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
|
|||||||
num_verify_tokens=num_draft_token,
|
num_verify_tokens=num_draft_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
rank0_log("=========== build tree kernel efficient ==========")
|
print("=========== build tree kernel efficient ==========")
|
||||||
# rank0_log(f"{tree_mask=}")
|
print(f"{tree_mask=}")
|
||||||
rank0_log(f"{position=}")
|
print(f"{position=}")
|
||||||
rank0_log(f"{retrive_index=}")
|
print(f"{retrive_index=}")
|
||||||
rank0_log(f"{retrive_next_token=}")
|
print(f"{retrive_next_token=}")
|
||||||
rank0_log(f"{retrive_next_sibling=}")
|
print(f"{retrive_next_sibling=}")
|
||||||
rank0_log(f"{draft_tokens=}")
|
print(f"{draft_tokens=}")
|
||||||
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
|
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
|
||||||
assert retrive_index.tolist() == [
|
assert retrive_index.tolist() == [
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7],
|
[0, 1, 2, 3, 4, 5, 6, 7],
|
||||||
|
|||||||
@@ -232,7 +232,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.def(
|
m.def(
|
||||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
||||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
|
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
|
||||||
|
"()");
|
||||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
|
|||||||
@@ -23,6 +23,8 @@
|
|||||||
#include "pytorch_extension_utils_rocm.h"
|
#include "pytorch_extension_utils_rocm.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode;
|
||||||
|
|
||||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||||
// selected_index [bs, draft_token_num - 1]
|
// selected_index [bs, draft_token_num - 1]
|
||||||
// verified_seq_len [bs]
|
// verified_seq_len [bs]
|
||||||
@@ -40,7 +42,8 @@ __global__ void build_tree_efficient(
|
|||||||
int64_t* retrive_next_sibling,
|
int64_t* retrive_next_sibling,
|
||||||
int topk,
|
int topk,
|
||||||
int depth,
|
int depth,
|
||||||
int draft_token_num) {
|
int draft_token_num,
|
||||||
|
int tree_mask_mode) {
|
||||||
int bid = blockIdx.x;
|
int bid = blockIdx.x;
|
||||||
int tid = threadIdx.x;
|
int tid = threadIdx.x;
|
||||||
|
|
||||||
@@ -52,7 +55,13 @@ __global__ void build_tree_efficient(
|
|||||||
seq_tree_idx += verified_seq_len[i] * draft_token_num;
|
seq_tree_idx += verified_seq_len[i] * draft_token_num;
|
||||||
}
|
}
|
||||||
int seq_len = verified_seq_len[bid];
|
int seq_len = verified_seq_len[bid];
|
||||||
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
|
int token_tree_idx;
|
||||||
|
if (tree_mask_mode == FULL_MASK) {
|
||||||
|
token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
|
||||||
|
} else {
|
||||||
|
token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1;
|
||||||
|
}
|
||||||
|
tree_mask[token_tree_idx - 1] = true;
|
||||||
for (int i = 0; i < draft_token_num - 1; i++) {
|
for (int i = 0; i < draft_token_num - 1; i++) {
|
||||||
tree_mask[token_tree_idx + i] = false;
|
tree_mask[token_tree_idx + i] = false;
|
||||||
}
|
}
|
||||||
@@ -124,7 +133,8 @@ void build_tree_kernel_efficient(
|
|||||||
at::Tensor retrive_next_sibling,
|
at::Tensor retrive_next_sibling,
|
||||||
int64_t topk,
|
int64_t topk,
|
||||||
int64_t depth,
|
int64_t depth,
|
||||||
int64_t draft_token_num) {
|
int64_t draft_token_num,
|
||||||
|
int64_t tree_mask_mode) {
|
||||||
// TODO (ying) check shape
|
// TODO (ying) check shape
|
||||||
// TODO (ying) check type
|
// TODO (ying) check type
|
||||||
int bs = parent_list.size(0);
|
int bs = parent_list.size(0);
|
||||||
@@ -132,18 +142,29 @@ void build_tree_kernel_efficient(
|
|||||||
dim3 block(draft_token_num);
|
dim3 block(draft_token_num);
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
build_tree_efficient<<<grid, block, 0, stream>>>(
|
if (tree_mask_mode == QLEN_ONLY_BITPACKING) {
|
||||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
size_t num_bytes_per_item = 1;
|
||||||
static_cast<int64_t*>(selected_index.data_ptr()),
|
if (draft_token_num > 16) {
|
||||||
static_cast<int64_t*>(verified_seq_len.data_ptr()),
|
num_bytes_per_item = 4;
|
||||||
static_cast<bool*>(tree_mask.data_ptr()),
|
} else if (draft_token_num > 8) {
|
||||||
static_cast<int64_t*>(positions.data_ptr()),
|
num_bytes_per_item = 2;
|
||||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
}
|
||||||
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
throw std::runtime_error("Not implemented");
|
||||||
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
} else {
|
||||||
int32_t(topk),
|
build_tree_efficient<<<grid, block, 0, stream>>>(
|
||||||
int32_t(depth),
|
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||||
int32_t(draft_token_num));
|
static_cast<int64_t*>(selected_index.data_ptr()),
|
||||||
|
static_cast<int64_t*>(verified_seq_len.data_ptr()),
|
||||||
|
static_cast<bool*>(tree_mask.data_ptr()),
|
||||||
|
static_cast<int64_t*>(positions.data_ptr()),
|
||||||
|
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||||
|
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||||
|
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||||
|
int32_t(topk),
|
||||||
|
int32_t(depth),
|
||||||
|
int32_t(draft_token_num),
|
||||||
|
int32_t(tree_mask_mode));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename IdType, typename IdType2>
|
template <typename IdType, typename IdType2>
|
||||||
|
|||||||
@@ -78,7 +78,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
|||||||
m.def(
|
m.def(
|
||||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
||||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
|
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
|
||||||
|
"()");
|
||||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -374,7 +374,8 @@ void build_tree_kernel_efficient(
|
|||||||
at::Tensor retrive_next_sibling,
|
at::Tensor retrive_next_sibling,
|
||||||
int64_t topk,
|
int64_t topk,
|
||||||
int64_t depth,
|
int64_t depth,
|
||||||
int64_t draft_token_num);
|
int64_t draft_token_num,
|
||||||
|
int64_t tree_mask_mode);
|
||||||
|
|
||||||
void segment_packbits(
|
void segment_packbits(
|
||||||
at::Tensor x,
|
at::Tensor x,
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ def build_tree_kernel_efficient(
|
|||||||
topk: int,
|
topk: int,
|
||||||
depth: int,
|
depth: int,
|
||||||
draft_token_num: int,
|
draft_token_num: int,
|
||||||
|
tree_mask_mode: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
|
torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
|
||||||
parent_list,
|
parent_list,
|
||||||
@@ -85,6 +86,7 @@ def build_tree_kernel_efficient(
|
|||||||
topk,
|
topk,
|
||||||
depth,
|
depth,
|
||||||
draft_token_num,
|
draft_token_num,
|
||||||
|
tree_mask_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user