Eagle speculative decoding part 4: Add EAGLE2 worker (#2150)
Co-authored-by: kavioyu <kavioyu@tencent.com> Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
37
examples/runtime/engine/EAGLE_offline_batch_inference.py
Normal file
37
examples/runtime/engine/EAGLE_offline_batch_inference.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Sample prompts.
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create a sampling params object.
|
||||||
|
sampling_params = {"temperature": 0, "max_new_tokens": 30}
|
||||||
|
|
||||||
|
# Create an LLM.
|
||||||
|
llm = sgl.Engine(
|
||||||
|
model_path="meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
speculative_algorithm="EAGLE",
|
||||||
|
speculative_draft_model_path="lmzheng/sglang-EAGLE-llama2-chat-7B",
|
||||||
|
speculative_num_steps=3,
|
||||||
|
speculative_eagle_topk=4,
|
||||||
|
speculative_num_draft_tokens=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
# Print the outputs.
|
||||||
|
for prompt, output in zip(prompts, outputs):
|
||||||
|
print("===============================")
|
||||||
|
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
|
||||||
|
|
||||||
|
|
||||||
|
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||||
|
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
347
python/sglang/srt/speculative/build_eagle_tree.py
Normal file
347
python/sglang/srt/speculative/build_eagle_tree.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
import cutex
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# parent_table [bs,topk*depth+)]
|
||||||
|
# selected_index [bs,draft_token_num-1)]
|
||||||
|
# verified_seq_len [bs]
|
||||||
|
# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token]
|
||||||
|
# positions [bs*draft_token]
|
||||||
|
# retrive_index [b, draft_token, depth+2]
|
||||||
|
kernels = cutex.SourceModule(
|
||||||
|
"""
|
||||||
|
//cuda
|
||||||
|
__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
|
||||||
|
Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
|
||||||
|
int bid = blockIdx.x;
|
||||||
|
int tid = threadIdx.x;
|
||||||
|
if (tid >= draft_token_num){
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
int seq_tree_idx = draft_token_num * draft_token_num * bid;
|
||||||
|
for(int i=0; i<bid; i++){
|
||||||
|
seq_tree_idx += verified_seq_len[i] * draft_token_num;
|
||||||
|
}
|
||||||
|
int seq_len = verified_seq_len[bid];
|
||||||
|
int token_tree_idx = seq_tree_idx + (seq_len+draft_token_num)*tid + seq_len + 1;
|
||||||
|
for(int i=0; i<draft_token_num-1; i++){
|
||||||
|
tree_mask[token_tree_idx+i] = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int position = 0;
|
||||||
|
if (tid==0){
|
||||||
|
positions[bid*draft_token_num] = seq_len;
|
||||||
|
retrive_index[bid][0][0] = bid * draft_token_num;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int depends_order[10];
|
||||||
|
|
||||||
|
int cur_position = tid-1;
|
||||||
|
while(true){
|
||||||
|
depends_order[position] = cur_position+1;
|
||||||
|
position += 1;
|
||||||
|
tree_mask[token_tree_idx+cur_position] = true;
|
||||||
|
int parent_tb_idx = selected_index[bid][cur_position]/topk;
|
||||||
|
if(parent_tb_idx==0){
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
int token_idx = parent_list[bid][parent_tb_idx];
|
||||||
|
for(cur_position=0; cur_position<draft_token_num;cur_position++){
|
||||||
|
if(selected_index[bid][cur_position]==token_idx){
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
positions[bid*draft_token_num+tid] = position + seq_len;
|
||||||
|
|
||||||
|
int is_leaf = 0;
|
||||||
|
for(int i=1;i<draft_token_num;i++){
|
||||||
|
if(tree_mask[seq_tree_idx + i * (draft_token_num+seq_len) + seq_len + tid])
|
||||||
|
{
|
||||||
|
is_leaf ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(is_leaf==1){
|
||||||
|
for(int i=0; i<position; i++){
|
||||||
|
retrive_index[bid][tid][position-i] = depends_order[i] + bid * draft_token_num;
|
||||||
|
}
|
||||||
|
retrive_index[bid][tid][0] = bid*draft_token_num;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
//!cuda
|
||||||
|
""",
|
||||||
|
float_bits=16, # change to 16 to use half precision as `float` type in the above source code.
|
||||||
|
boundscheck=True, # turning on for debug and off for performance (to use full threads of a block), default is on.
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
|
||||||
|
bs = seq_lens.numel()
|
||||||
|
device = parent_list.device
|
||||||
|
tree_mask = torch.full(
|
||||||
|
(torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
|
||||||
|
True,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
retrive_index = torch.full(
|
||||||
|
(bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
|
||||||
|
)
|
||||||
|
positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
kernels.build_tree(
|
||||||
|
parent_list,
|
||||||
|
top_score_index,
|
||||||
|
seq_lens.to(torch.int32),
|
||||||
|
tree_mask,
|
||||||
|
positions,
|
||||||
|
retrive_index,
|
||||||
|
topk,
|
||||||
|
depth,
|
||||||
|
draft_token,
|
||||||
|
grid=(bs, 1, 1),
|
||||||
|
block=(64, 1, 1),
|
||||||
|
)
|
||||||
|
index = retrive_index.sum(dim=-1) != -depth - 2
|
||||||
|
cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
|
||||||
|
retrive_cum_len = torch.zeros(
|
||||||
|
(cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
retrive_cum_len[1:] = cum_len
|
||||||
|
retrive_index = retrive_index[index]
|
||||||
|
return tree_mask, positions, retrive_index, retrive_cum_len
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
def findp(p_i, index, parent_list):
|
||||||
|
pos = index // 10
|
||||||
|
index_list = index.tolist()
|
||||||
|
parent_list = parent_list.tolist()
|
||||||
|
res = [p_i]
|
||||||
|
while True:
|
||||||
|
p = pos[p_i]
|
||||||
|
if p == 0:
|
||||||
|
break
|
||||||
|
token_idx = parent_list[p]
|
||||||
|
p_i = index_list.index(token_idx)
|
||||||
|
res.append(p_i)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def create_mask(seq_len, draft_token, index, parent_list, max_depth):
|
||||||
|
mask = []
|
||||||
|
positions = []
|
||||||
|
retrive_index = []
|
||||||
|
for i, lens in enumerate(seq_len.tolist()):
|
||||||
|
first_mask = torch.full((lens + draft_token,), True)
|
||||||
|
first_mask[-(draft_token - 1) :] = False
|
||||||
|
positions.append(lens)
|
||||||
|
mask.append(first_mask)
|
||||||
|
seq_order = []
|
||||||
|
first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
|
||||||
|
r_index = [first_index]
|
||||||
|
for j in range(draft_token - 1):
|
||||||
|
mask.append(torch.full((lens + 1,), True))
|
||||||
|
idx = findp(j, index, parent_list)
|
||||||
|
|
||||||
|
seq_order.append(idx)
|
||||||
|
positions.append(len(idx) + seq_len)
|
||||||
|
t = torch.full((draft_token - 1,), False)
|
||||||
|
t[idx] = True
|
||||||
|
mask.append(t)
|
||||||
|
|
||||||
|
for i in range(1, draft_token - 1):
|
||||||
|
is_leaf = 0
|
||||||
|
for j in range(draft_token - 1):
|
||||||
|
if i in seq_order[j]:
|
||||||
|
is_leaf += 1
|
||||||
|
|
||||||
|
if is_leaf == 1:
|
||||||
|
order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
|
||||||
|
for _ in range(max_depth + 1 - len(seq_order[i])):
|
||||||
|
order_list.append(-1)
|
||||||
|
order = torch.Tensor(order_list).cuda().to(torch.long)
|
||||||
|
r_index.append(order)
|
||||||
|
retrive_index.append(torch.stack(r_index))
|
||||||
|
|
||||||
|
return (
|
||||||
|
torch.cat(mask).cuda(),
|
||||||
|
torch.Tensor(positions).cuda().to(torch.long),
|
||||||
|
torch.stack(retrive_index),
|
||||||
|
)
|
||||||
|
|
||||||
|
index = (
|
||||||
|
torch.Tensor(
|
||||||
|
[
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
10,
|
||||||
|
11,
|
||||||
|
12,
|
||||||
|
13,
|
||||||
|
20,
|
||||||
|
21,
|
||||||
|
22,
|
||||||
|
30,
|
||||||
|
110,
|
||||||
|
130,
|
||||||
|
150,
|
||||||
|
160,
|
||||||
|
210,
|
||||||
|
211,
|
||||||
|
212,
|
||||||
|
213,
|
||||||
|
214,
|
||||||
|
215,
|
||||||
|
216,
|
||||||
|
217,
|
||||||
|
218,
|
||||||
|
219,
|
||||||
|
220,
|
||||||
|
230,
|
||||||
|
310,
|
||||||
|
311,
|
||||||
|
312,
|
||||||
|
313,
|
||||||
|
314,
|
||||||
|
315,
|
||||||
|
316,
|
||||||
|
317,
|
||||||
|
320,
|
||||||
|
321,
|
||||||
|
322,
|
||||||
|
330,
|
||||||
|
360,
|
||||||
|
380,
|
||||||
|
390,
|
||||||
|
410,
|
||||||
|
411,
|
||||||
|
412,
|
||||||
|
413,
|
||||||
|
414,
|
||||||
|
415,
|
||||||
|
416,
|
||||||
|
417,
|
||||||
|
418,
|
||||||
|
419,
|
||||||
|
420,
|
||||||
|
421,
|
||||||
|
422,
|
||||||
|
423,
|
||||||
|
430,
|
||||||
|
431,
|
||||||
|
440,
|
||||||
|
441,
|
||||||
|
460,
|
||||||
|
470,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
.to(torch.long)
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
|
|
||||||
|
parent_list = (
|
||||||
|
torch.Tensor(
|
||||||
|
[
|
||||||
|
-1,
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
5,
|
||||||
|
6,
|
||||||
|
7,
|
||||||
|
8,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
11,
|
||||||
|
12,
|
||||||
|
20,
|
||||||
|
30,
|
||||||
|
21,
|
||||||
|
13,
|
||||||
|
22,
|
||||||
|
40,
|
||||||
|
23,
|
||||||
|
110,
|
||||||
|
130,
|
||||||
|
160,
|
||||||
|
150,
|
||||||
|
190,
|
||||||
|
120,
|
||||||
|
111,
|
||||||
|
121,
|
||||||
|
200,
|
||||||
|
180,
|
||||||
|
210,
|
||||||
|
211,
|
||||||
|
212,
|
||||||
|
213,
|
||||||
|
214,
|
||||||
|
215,
|
||||||
|
216,
|
||||||
|
220,
|
||||||
|
230,
|
||||||
|
217,
|
||||||
|
310,
|
||||||
|
311,
|
||||||
|
312,
|
||||||
|
313,
|
||||||
|
320,
|
||||||
|
314,
|
||||||
|
321,
|
||||||
|
315,
|
||||||
|
316,
|
||||||
|
317,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
.to(torch.long)
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
|
|
||||||
|
verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
|
||||||
|
bs = verified_seq_len.shape[0]
|
||||||
|
topk = 10
|
||||||
|
depth = 5 # depth <= 10
|
||||||
|
draft_token = 64
|
||||||
|
|
||||||
|
tree_mask = torch.full(
|
||||||
|
(
|
||||||
|
torch.sum(verified_seq_len).item() * draft_token
|
||||||
|
+ draft_token * draft_token * bs,
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
).cuda()
|
||||||
|
retrive_index = torch.full(
|
||||||
|
(bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
|
||||||
|
)
|
||||||
|
positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)
|
||||||
|
|
||||||
|
kernels.build_tree(
|
||||||
|
parent_list.unsqueeze(0),
|
||||||
|
index.unsqueeze(0),
|
||||||
|
verified_seq_len,
|
||||||
|
tree_mask,
|
||||||
|
positions,
|
||||||
|
retrive_index,
|
||||||
|
topk,
|
||||||
|
depth,
|
||||||
|
draft_token,
|
||||||
|
grid=(bs, 1, 1),
|
||||||
|
block=(64, 1, 1),
|
||||||
|
)
|
||||||
|
retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
|
||||||
|
|
||||||
|
c_mask, c_positions, c_retive_index = create_mask(
|
||||||
|
verified_seq_len, draft_token, index, parent_list, depth
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(tree_mask, c_mask), "tree mask has error."
|
||||||
|
assert torch.allclose(positions, c_positions), "positions has error."
|
||||||
|
assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
|
||||||
618
python/sglang/srt/speculative/eagle_utils.py
Normal file
618
python/sglang/srt/speculative/eagle_utils.py
Normal file
@@ -0,0 +1,618 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||||
|
create_flashinfer_kv_indices_triton,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
|
||||||
|
from sglang.srt.speculative.spec_info import SpecInfo
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from python.sglang.srt.layers.sampler import SampleOutput
|
||||||
|
from python.sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def eagle_verify_retrive(
|
||||||
|
retrive_index,
|
||||||
|
accept_mask,
|
||||||
|
retrive_cum_len,
|
||||||
|
accept_index,
|
||||||
|
accept_length,
|
||||||
|
extract_index,
|
||||||
|
max_len: tl.constexpr,
|
||||||
|
draft_token_num: tl.constexpr,
|
||||||
|
max_len_upper: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
|
||||||
|
retrive_end = tl.load(retrive_cum_len + pid + 1)
|
||||||
|
retrive_start = tl.load(retrive_cum_len + pid)
|
||||||
|
retrive_len = retrive_end - retrive_start
|
||||||
|
accept_ptr = accept_mask + retrive_start
|
||||||
|
accept_offset = tl.arange(0, draft_token_num)
|
||||||
|
accept_load_mask = accept_offset < retrive_len
|
||||||
|
accept_len_list = tl.load(
|
||||||
|
accept_ptr + accept_offset, mask=accept_load_mask, other=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
accept_len = tl.max(accept_len_list)
|
||||||
|
max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
|
||||||
|
# triton is not support argmax with tie_break_right, so I need implement it by some way
|
||||||
|
mask_max = accept_len_list == accept_len
|
||||||
|
|
||||||
|
count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
|
||||||
|
count = tl.sum(tl.where(mask_max, 1, count_mask))
|
||||||
|
if count > 1:
|
||||||
|
index = tl.arange(0, draft_token_num)
|
||||||
|
mask_left = index != max_index
|
||||||
|
remained_index = tl.where(mask_max and mask_left, index, 0)
|
||||||
|
max_index = tl.max(remained_index)
|
||||||
|
|
||||||
|
tl.store(accept_length + pid, accept_len)
|
||||||
|
retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
|
||||||
|
retrive_offset = tl.arange(0, max_len_upper)
|
||||||
|
retrive_load_mask = retrive_offset < accept_len + 1
|
||||||
|
data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
extract_load_ptr = accept_index + pid * max_len + accept_len
|
||||||
|
if accept_len == max_len - 1:
|
||||||
|
extract_data = tl.load(extract_load_ptr - 1)
|
||||||
|
tl.store(extract_index + pid * 2, extract_data)
|
||||||
|
extract_data = tl.load(extract_load_ptr)
|
||||||
|
tl.store(extract_index + pid * 2 + 1, extract_data)
|
||||||
|
|
||||||
|
else:
|
||||||
|
extract_data = tl.load(extract_load_ptr)
|
||||||
|
tl.store(extract_index + pid * 2, extract_data)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def create_extend_spec_info(
|
||||||
|
verified_id,
|
||||||
|
seq_len,
|
||||||
|
accept_len,
|
||||||
|
accept_len_cum,
|
||||||
|
positions,
|
||||||
|
new_verified_id,
|
||||||
|
accept_len_upper: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
|
||||||
|
seq_length = tl.load(seq_len + pid)
|
||||||
|
accept_length = tl.load(accept_len + pid)
|
||||||
|
positions_ptr = positions + offset
|
||||||
|
data = tl.arange(0, accept_len_upper)
|
||||||
|
mask = data < accept_length
|
||||||
|
tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
|
||||||
|
|
||||||
|
offset = tl.load(accept_len_cum + pid) - 1
|
||||||
|
verified_id_data = tl.load(verified_id + offset)
|
||||||
|
tl.store(new_verified_id + pid, verified_id_data)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def assign_req_to_token_pool(
|
||||||
|
req_pool_indices,
|
||||||
|
req_to_token,
|
||||||
|
start_offset,
|
||||||
|
end_offset,
|
||||||
|
out_cache_loc,
|
||||||
|
pool_len: tl.constexpr,
|
||||||
|
bs_upper: tl.constexpr,
|
||||||
|
):
|
||||||
|
BLOCK_SIZE: tl.constexpr = 32
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
kv_start = tl.load(start_offset + pid)
|
||||||
|
kv_end = tl.load(end_offset + pid)
|
||||||
|
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
||||||
|
|
||||||
|
length_offset = tl.arange(0, bs_upper)
|
||||||
|
start = tl.load(start_offset + length_offset, mask=length_offset < pid)
|
||||||
|
end = tl.load(end_offset + length_offset, mask=length_offset < pid)
|
||||||
|
out_offset = tl.sum(end - start, axis=0)
|
||||||
|
|
||||||
|
out_cache_ptr = out_cache_loc + out_offset
|
||||||
|
|
||||||
|
save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
|
||||||
|
load_offset = tl.arange(0, BLOCK_SIZE)
|
||||||
|
|
||||||
|
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||||
|
for _ in range(num_loop):
|
||||||
|
mask = save_offset < kv_end
|
||||||
|
data = tl.load(out_cache_ptr + load_offset, mask=mask)
|
||||||
|
tl.store(token_pool + save_offset, data, mask=mask)
|
||||||
|
save_offset += BLOCK_SIZE
|
||||||
|
load_offset += BLOCK_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def generate_draft_decode_kv_indices(
|
||||||
|
req_pool_indices,
|
||||||
|
req_to_token,
|
||||||
|
paged_kernel_lens,
|
||||||
|
kv_indices,
|
||||||
|
iters: tl.constexpr,
|
||||||
|
topk: tl.constexpr,
|
||||||
|
pool_len: tl.constexpr,
|
||||||
|
bs_upper: tl.constexpr,
|
||||||
|
iter_upper: tl.constexpr,
|
||||||
|
):
|
||||||
|
BLOCK_SIZE: tl.constexpr = 128
|
||||||
|
bid = tl.program_id(axis=0)
|
||||||
|
topk_id = tl.program_id(axis=1)
|
||||||
|
|
||||||
|
load_offset = tl.arange(0, bs_upper)
|
||||||
|
seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
|
||||||
|
seq_len = tl.load(paged_kernel_lens + bid)
|
||||||
|
cum_seq_len = tl.sum(seq_lens)
|
||||||
|
|
||||||
|
kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
|
||||||
|
kv_ptr = kv_indices + kv_offset
|
||||||
|
token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
|
||||||
|
|
||||||
|
kv_offset = tl.arange(0, BLOCK_SIZE)
|
||||||
|
num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
|
||||||
|
for _ in range(num_loop):
|
||||||
|
mask = kv_offset < seq_len
|
||||||
|
data = tl.load(token_pool_ptr + kv_offset, mask=mask)
|
||||||
|
tl.store(kv_ptr + kv_offset, data, mask=mask)
|
||||||
|
kv_offset += BLOCK_SIZE
|
||||||
|
|
||||||
|
extend_offset = tl.arange(0, iter_upper)
|
||||||
|
extend_data = tl.load(
|
||||||
|
token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
|
||||||
|
mask=extend_offset < iters,
|
||||||
|
)
|
||||||
|
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
|
||||||
|
|
||||||
|
|
||||||
|
class EAGLEDraftInput(SpecInfo):
|
||||||
|
hidden_states: torch.Tensor = None
|
||||||
|
verified_id: torch.Tensor = None
|
||||||
|
positions: torch.Tensor = None
|
||||||
|
accept_length: torch.Tensor = None
|
||||||
|
has_finished: bool = False
|
||||||
|
unfinished_index: List[int] = None
|
||||||
|
|
||||||
|
def init(self, server_args: ServerArgs):
|
||||||
|
self.prev_mode = ForwardMode.DECODE
|
||||||
|
self.sample_output = None
|
||||||
|
self.topk: int = server_args.speculative_eagle_topk
|
||||||
|
self.num_verify_token: int = server_args.speculative_num_draft_tokens
|
||||||
|
self.spec_steps = server_args.speculative_num_steps
|
||||||
|
|
||||||
|
self.scores: torch.Tensor = None
|
||||||
|
self.score_list: List[torch.Tensor] = []
|
||||||
|
self.token_list: List[torch.Tensor] = []
|
||||||
|
self.origin_score_list: List[torch.Tensor] = [] # used for sampling
|
||||||
|
self.parents_list: List[torch.Tensor] = []
|
||||||
|
self.cache_list: List[torch.Tenor] = []
|
||||||
|
self.iter = 0
|
||||||
|
self.root_token: int = None
|
||||||
|
|
||||||
|
assert self.topk <= 10, "topk should <= 10"
|
||||||
|
|
||||||
|
def prepare_for_extend(self, batch: ForwardBatch):
|
||||||
|
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
|
||||||
|
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
||||||
|
batch.out_cache_loc = out_cache_loc
|
||||||
|
|
||||||
|
pt = 0
|
||||||
|
for i, req in enumerate(batch.reqs):
|
||||||
|
req.req_pool_idx = req_pool_indices[i]
|
||||||
|
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||||
|
assert seq_len - pre_len == req.extend_input_len
|
||||||
|
|
||||||
|
if pre_len > 0:
|
||||||
|
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
|
:pre_len
|
||||||
|
] = req.prefix_indices
|
||||||
|
|
||||||
|
batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
|
||||||
|
out_cache_loc[pt : pt + req.extend_input_len]
|
||||||
|
)
|
||||||
|
|
||||||
|
pt += req.extend_input_len
|
||||||
|
|
||||||
|
seq_lens = [0] + batch.extend_lens
|
||||||
|
input_ids = batch.input_ids.tolist()
|
||||||
|
verified_id = batch.spec_info.verified_id.tolist()
|
||||||
|
model_input_ids = []
|
||||||
|
for i in range(len(seq_lens) - 1):
|
||||||
|
model_input_ids.extend(
|
||||||
|
input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
|
||||||
|
)
|
||||||
|
batch.input_ids = torch.tensor(
|
||||||
|
model_input_ids, dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
def capture_for_decode(
|
||||||
|
self,
|
||||||
|
sample_output: SampleOutput,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
prev_mode: ForwardMode,
|
||||||
|
):
|
||||||
|
self.sample_output = sample_output
|
||||||
|
self.prev_mode = prev_mode
|
||||||
|
self.hidden_states = hidden_states
|
||||||
|
|
||||||
|
def prepare_for_decode(self, batch: ScheduleBatch):
|
||||||
|
prob = self.sample_output # b * (1/topk), vocab
|
||||||
|
top = torch.topk(prob, self.topk, dim=-1)
|
||||||
|
topk_index, topk_p = top.indices, top.values # b * (1/topk), topk
|
||||||
|
if self.prev_mode == ForwardMode.DECODE:
|
||||||
|
scores = torch.mul(
|
||||||
|
self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
|
||||||
|
) # (b, topk) mul (b * topk ,topk) -> b, topk, topk
|
||||||
|
topk_cs = torch.topk(
|
||||||
|
scores.flatten(start_dim=1), self.topk, dim=-1
|
||||||
|
) # (b, topk)
|
||||||
|
topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
|
||||||
|
self.scores = topk_cs_p
|
||||||
|
|
||||||
|
selected_input_index = topk_cs_index.flatten() // self.topk # b* topk
|
||||||
|
|
||||||
|
batch.spec_info.hidden_states = batch.spec_info.hidden_states[
|
||||||
|
selected_input_index, :
|
||||||
|
]
|
||||||
|
topk_index = topk_index.reshape(-1, self.topk**2)
|
||||||
|
batch.input_ids = torch.gather(
|
||||||
|
topk_index, index=topk_cs_index, dim=1
|
||||||
|
).flatten()
|
||||||
|
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
||||||
|
self.score_list.append(scores) # b, topk, topk
|
||||||
|
self.token_list.append(topk_index) # b, topk*topk
|
||||||
|
self.origin_score_list.append(topk_p.reshape(topk_index.shape))
|
||||||
|
self.parents_list.append(
|
||||||
|
topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
|
||||||
|
) # b, topk
|
||||||
|
|
||||||
|
elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
|
||||||
|
self.scores = topk_p # b, top_k
|
||||||
|
self.score_list.append(topk_p.unsqueeze(1))
|
||||||
|
self.token_list.append(topk_index)
|
||||||
|
self.origin_score_list.append(topk_p)
|
||||||
|
batch.spec_info.hidden_states = (
|
||||||
|
batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
|
||||||
|
)
|
||||||
|
batch.input_ids = topk_index.flatten()
|
||||||
|
batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
|
||||||
|
self.parents_list.append(
|
||||||
|
torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
|
||||||
|
.unsqueeze(0)
|
||||||
|
.repeat(self.scores.shape[0], 1)
|
||||||
|
) # b, topk+1
|
||||||
|
self.cache_list.append(batch.out_cache_loc)
|
||||||
|
self.positions = (
|
||||||
|
batch.seq_lens[:, None]
|
||||||
|
+ torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
|
||||||
|
).flatten()
|
||||||
|
|
||||||
|
bs = batch.seq_lens.numel()
|
||||||
|
assign_req_to_token_pool[(bs,)](
|
||||||
|
batch.req_pool_indices,
|
||||||
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.seq_lens + self.topk * self.iter,
|
||||||
|
batch.seq_lens + self.topk * (self.iter + 1),
|
||||||
|
batch.out_cache_loc,
|
||||||
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
|
triton.next_power_of_2(bs),
|
||||||
|
)
|
||||||
|
self.iter += 1
|
||||||
|
|
||||||
|
def prepare_extend_after_decode(self, batch: ScheduleBatch):
|
||||||
|
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
|
||||||
|
batch.extend_lens = (self.accept_length + 1).tolist()
|
||||||
|
|
||||||
|
pt = 0
|
||||||
|
seq_lens = batch.seq_lens.tolist()
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
for req in batch.reqs:
|
||||||
|
if req.finished():
|
||||||
|
continue
|
||||||
|
# assert seq_len - pre_len == req.extend_input_len
|
||||||
|
input_len = self.accept_length[i] + 1
|
||||||
|
seq_len = seq_lens[i]
|
||||||
|
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
|
seq_len - input_len : seq_len
|
||||||
|
] = batch.out_cache_loc[pt : pt + input_len]
|
||||||
|
pt += input_len
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
self.positions = torch.empty_like(self.verified_id)
|
||||||
|
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
|
||||||
|
self.accept_length.add_(1)
|
||||||
|
|
||||||
|
create_extend_spec_info[(self.accept_length.numel(),)](
|
||||||
|
self.verified_id,
|
||||||
|
batch.seq_lens,
|
||||||
|
self.accept_length,
|
||||||
|
torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
|
||||||
|
self.positions,
|
||||||
|
new_verified_id,
|
||||||
|
triton.next_power_of_2(self.spec_steps + 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
batch.input_ids = self.verified_id
|
||||||
|
self.verified_id = new_verified_id
|
||||||
|
|
||||||
|
def prepare_for_verify(self, batch: ScheduleBatch):
|
||||||
|
score_list = torch.cat(self.score_list, dim=1).flatten(
|
||||||
|
1
|
||||||
|
) # b, n, topk; n= 1+(self.iter-1)*self.topk
|
||||||
|
ss_token_list = torch.cat(
|
||||||
|
self.token_list, dim=1
|
||||||
|
) # b, (self.topk+(self.iter-1)*self.topk)
|
||||||
|
origin_token_list = torch.cat(self.origin_score_list, dim=1)
|
||||||
|
top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1)
|
||||||
|
top_scores_index = top_scores.indices
|
||||||
|
top_scores_index = torch.sort(top_scores_index).values
|
||||||
|
|
||||||
|
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
|
||||||
|
scores = torch.gather(origin_token_list, index=top_scores_index, dim=1)
|
||||||
|
draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1)
|
||||||
|
parent_list = torch.cat(self.parents_list[:-1], dim=1)
|
||||||
|
|
||||||
|
tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
|
||||||
|
parent_list,
|
||||||
|
top_scores_index,
|
||||||
|
batch.seq_lens,
|
||||||
|
self.topk,
|
||||||
|
self.iter - 1,
|
||||||
|
self.num_verify_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return EagleVerifyInput(
|
||||||
|
draft_tokens.flatten(),
|
||||||
|
scores.flatten(),
|
||||||
|
tree_mask,
|
||||||
|
position,
|
||||||
|
retrive_index,
|
||||||
|
retrive_cum_len,
|
||||||
|
self.num_verify_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_attn_arg_decode(
|
||||||
|
self,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
paged_kernel_lens: torch.Tensor,
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
):
|
||||||
|
seq_num = req_pool_indices.numel()
|
||||||
|
bs = self.topk * req_pool_indices.numel()
|
||||||
|
seq_len = self.positions.reshape(-1).contiguous()
|
||||||
|
|
||||||
|
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0)
|
||||||
|
total_len = torch.sum(paged_kernel_lens).item()
|
||||||
|
|
||||||
|
kv_indices = torch.empty(
|
||||||
|
(total_len * self.topk + seq_num * self.iter * self.topk,),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](
|
||||||
|
req_pool_indices,
|
||||||
|
req_to_token,
|
||||||
|
paged_kernel_lens,
|
||||||
|
kv_indices,
|
||||||
|
self.iter,
|
||||||
|
self.topk,
|
||||||
|
req_to_token.shape[1],
|
||||||
|
triton.next_power_of_2(seq_num),
|
||||||
|
triton.next_power_of_2(self.spec_steps),
|
||||||
|
)
|
||||||
|
return bs, kv_indices, cum_kv_seq_len
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.iter = 0
|
||||||
|
self.score_list.clear()
|
||||||
|
self.positions = None
|
||||||
|
|
||||||
|
def clear_draft_cache(self, batch):
|
||||||
|
draft_cache = torch.cat(self.cache_list, dim=0)
|
||||||
|
batch.token_to_kv_pool.free(draft_cache)
|
||||||
|
|
||||||
|
def generate_attn_arg_prefill(
|
||||||
|
self,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
paged_kernel_lens: torch.Tensor,
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
):
|
||||||
|
bs = self.accept_length.numel()
|
||||||
|
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
||||||
|
|
||||||
|
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
cum_kv_seq_len,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
req_to_token.size(1),
|
||||||
|
)
|
||||||
|
|
||||||
|
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
||||||
|
|
||||||
|
def merge_batch(self, spec_info: EAGLEDraftInput):
|
||||||
|
|
||||||
|
self.hidden_states = torch.cat(
|
||||||
|
[self.hidden_states, spec_info.hidden_states], axis=0
|
||||||
|
)
|
||||||
|
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
|
||||||
|
# self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
|
||||||
|
self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
|
||||||
|
|
||||||
|
|
||||||
|
class EagleVerifyInput(SpecInfo):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
draft_token: torch.Tensor,
|
||||||
|
draft_score: torch.Tensor,
|
||||||
|
tree_mask: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
retrive_index: torch.Tensor,
|
||||||
|
retrive_cum_len: torch.Tensor,
|
||||||
|
draft_token_num: int,
|
||||||
|
):
|
||||||
|
self.draft_token = draft_token
|
||||||
|
self.draft_score = draft_score
|
||||||
|
self.custom_mask = tree_mask
|
||||||
|
self.positions = positions
|
||||||
|
self.retrive_index = retrive_index
|
||||||
|
self.retrive_cum_len = retrive_cum_len
|
||||||
|
self.draft_token_num = draft_token_num
|
||||||
|
|
||||||
|
def prepare_for_verify(self, batch: ScheduleBatch):
|
||||||
|
batch.input_ids = self.draft_token
|
||||||
|
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
||||||
|
bs = batch.seq_lens.numel()
|
||||||
|
assign_req_to_token_pool[(bs,)](
|
||||||
|
batch.req_pool_indices,
|
||||||
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.seq_lens,
|
||||||
|
batch.seq_lens + self.draft_token_num,
|
||||||
|
batch.out_cache_loc,
|
||||||
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
|
triton.next_power_of_2(bs),
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_attn_arg_prefill(
|
||||||
|
self,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
paged_kernel_lens: torch.Tensor,
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
):
|
||||||
|
batch_size = len(req_pool_indices)
|
||||||
|
qo_indptr = torch.arange(
|
||||||
|
0,
|
||||||
|
(1 + batch_size) * self.draft_token_num,
|
||||||
|
step=self.draft_token_num,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
cum_kv_seq_len = torch.zeros(
|
||||||
|
(batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
|
||||||
|
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
|
||||||
|
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
create_flashinfer_kv_indices_triton[(batch_size,)](
|
||||||
|
req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
cum_kv_seq_len,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
req_to_token.size(1),
|
||||||
|
)
|
||||||
|
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
|
||||||
|
|
||||||
|
def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
|
||||||
|
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
||||||
|
predict = torch.cat(
|
||||||
|
[predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
|
||||||
|
)
|
||||||
|
draft_token = torch.cat(
|
||||||
|
[self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
target_predict = predict[self.retrive_index]
|
||||||
|
candidates = draft_token[self.retrive_index]
|
||||||
|
# logits = logits_output.next_token_logits[self.retrive_index]
|
||||||
|
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
|
||||||
|
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
|
||||||
|
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
|
||||||
|
bs = self.retrive_cum_len.numel() - 1
|
||||||
|
|
||||||
|
max_draft_len = self.retrive_index.shape[-1]
|
||||||
|
accept_index = torch.full(
|
||||||
|
(bs, max_draft_len), -1, dtype=torch.long, device="cuda"
|
||||||
|
)
|
||||||
|
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
|
||||||
|
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
|
||||||
|
eagle_verify_retrive[(bs,)](
|
||||||
|
self.retrive_index.contiguous(),
|
||||||
|
accept_mask.contiguous(),
|
||||||
|
self.retrive_cum_len,
|
||||||
|
accept_index,
|
||||||
|
accept_length,
|
||||||
|
extract_index,
|
||||||
|
max_draft_len,
|
||||||
|
self.draft_token_num,
|
||||||
|
triton.next_power_of_2(max_draft_len),
|
||||||
|
)
|
||||||
|
|
||||||
|
accept_index = accept_index[accept_index != -1]
|
||||||
|
# extract_index = extract_index[extract_index != 0]
|
||||||
|
|
||||||
|
draft_input = EAGLEDraftInput()
|
||||||
|
|
||||||
|
accept_length_cpu = accept_length.tolist()
|
||||||
|
verified_id = predict[accept_index]
|
||||||
|
verified_id_cpu = verified_id.tolist()
|
||||||
|
|
||||||
|
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
||||||
|
evict_mask[accept_index] = False
|
||||||
|
mem_need_free_idx = batch.out_cache_loc[evict_mask]
|
||||||
|
batch.token_to_kv_pool.free(mem_need_free_idx)
|
||||||
|
assign_req_to_token_pool[(bs,)](
|
||||||
|
batch.req_pool_indices,
|
||||||
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.seq_lens,
|
||||||
|
batch.seq_lens + accept_length + 1,
|
||||||
|
batch.out_cache_loc[accept_index],
|
||||||
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
|
triton.next_power_of_2(bs),
|
||||||
|
)
|
||||||
|
batch.seq_lens.add_(accept_length + 1)
|
||||||
|
new_accept_index = []
|
||||||
|
unfinished_index = []
|
||||||
|
finished_extend_len = {} # {rid:accept_length + 1}
|
||||||
|
# retracted_reqs, new_token_ratio = batch.retract_decode()
|
||||||
|
|
||||||
|
low = 0
|
||||||
|
for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
|
||||||
|
req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
|
||||||
|
req.check_finished()
|
||||||
|
if req.finished():
|
||||||
|
draft_input.has_finished = True
|
||||||
|
finished_extend_len[req.rid] = verified_len + 1
|
||||||
|
else:
|
||||||
|
new_accept_index.append(accept_index[low : low + verified_len + 1])
|
||||||
|
unfinished_index.append(i)
|
||||||
|
low += verified_len + 1
|
||||||
|
|
||||||
|
if len(new_accept_index) > 0:
|
||||||
|
new_accept_index = torch.cat(new_accept_index, dim=0)
|
||||||
|
draft_input.verified_id = predict[new_accept_index]
|
||||||
|
draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
|
||||||
|
draft_input.accept_length = accept_length[unfinished_index]
|
||||||
|
draft_input.unfinished_index = unfinished_index
|
||||||
|
|
||||||
|
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
|
||||||
|
return draft_input, logits_output, verified_id, finished_extend_len
|
||||||
170
python/sglang/srt/speculative/eagle_worker.py
Normal file
170
python/sglang/srt/speculative/eagle_worker.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
CaptureHiddenMode,
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
|
||||||
|
|
||||||
|
|
||||||
|
class EAGLEWorker(TpModelWorker):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_args: ServerArgs,
|
||||||
|
gpu_id: int,
|
||||||
|
tp_rank: int,
|
||||||
|
dp_rank: Optional[int],
|
||||||
|
nccl_port: int,
|
||||||
|
target_worker: TpModelWorker,
|
||||||
|
):
|
||||||
|
# Do not capture cuda graph in `super().__init__()`
|
||||||
|
# We will capture it later
|
||||||
|
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
||||||
|
server_args.disable_cuda_graph = True
|
||||||
|
super().__init__(
|
||||||
|
gpu_id=gpu_id,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
server_args=server_args,
|
||||||
|
nccl_port=nccl_port,
|
||||||
|
dp_rank=dp_rank,
|
||||||
|
is_draft_worker=True,
|
||||||
|
)
|
||||||
|
self.target_worker = target_worker
|
||||||
|
self.server_args = server_args
|
||||||
|
|
||||||
|
# Share the embedding and lm_head
|
||||||
|
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||||
|
self.model_runner.model.set_embed_and_head(embed, head)
|
||||||
|
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||||
|
self.model_runner.init_cuda_graphs()
|
||||||
|
|
||||||
|
def forward_draft_decode(self, batch: ScheduleBatch):
|
||||||
|
batch.spec_info.prepare_for_decode(batch)
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
|
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
|
self.capture_for_decode(logits_output, forward_batch)
|
||||||
|
|
||||||
|
def forward_draft_extend(self, batch: ScheduleBatch):
|
||||||
|
self._swap_mem_pool(batch, self.model_runner)
|
||||||
|
batch.spec_info.prepare_for_extend(batch)
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
|
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
|
self.capture_for_decode(logits_output, forward_batch)
|
||||||
|
self._swap_mem_pool(batch, self.target_worker.model_runner)
|
||||||
|
|
||||||
|
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
||||||
|
if batch.forward_mode.is_decode():
|
||||||
|
prev_spec_info = batch.spec_info
|
||||||
|
self._swap_mem_pool(batch, self.model_runner)
|
||||||
|
for i in range(self.server_args.speculative_num_steps):
|
||||||
|
self.forward_draft_decode(batch)
|
||||||
|
batch.spec_info.clear_draft_cache(batch)
|
||||||
|
self._swap_mem_pool(batch, self.target_worker.model_runner)
|
||||||
|
(
|
||||||
|
next_draft_input,
|
||||||
|
logits_output,
|
||||||
|
verified_id,
|
||||||
|
self.finish_extend_len,
|
||||||
|
model_worker_batch,
|
||||||
|
) = self.verify(batch)
|
||||||
|
next_draft_input.init(self.server_args)
|
||||||
|
batch.spec_info = next_draft_input
|
||||||
|
# if it is None, means all requsets are finished
|
||||||
|
if batch.spec_info.verified_id is not None:
|
||||||
|
self.forward_extend_after_decode(batch)
|
||||||
|
batch.spec_info = prev_spec_info
|
||||||
|
return logits_output, verified_id, model_worker_batch, next_draft_input
|
||||||
|
|
||||||
|
else:
|
||||||
|
spec_info = EAGLEDraftInput()
|
||||||
|
spec_info.init(self.server_args)
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
model_worker_batch.spec_info = spec_info
|
||||||
|
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
|
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
|
||||||
|
model_worker_batch
|
||||||
|
)
|
||||||
|
model_worker_batch.spec_info.verified_id = next_token_ids
|
||||||
|
model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
|
||||||
|
batch.spec_info = spec_info
|
||||||
|
self.forward_draft_extend(batch)
|
||||||
|
batch.spec_info = None
|
||||||
|
return logits_output, next_token_ids, model_worker_batch, spec_info
|
||||||
|
|
||||||
|
def verify(self, batch: ScheduleBatch):
|
||||||
|
verify_input = batch.spec_info.prepare_for_verify(batch)
|
||||||
|
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||||
|
verify_input.prepare_for_verify(batch)
|
||||||
|
batch.spec_info = verify_input
|
||||||
|
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
logits_output, _ = self.target_worker.forward_batch_generation(
|
||||||
|
model_worker_batch, skip_sample=True
|
||||||
|
)
|
||||||
|
verify_input.hidden_states = logits_output.hidden_states
|
||||||
|
res = verify_input.verify(batch, logits_output)
|
||||||
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
|
return res + (model_worker_batch,)
|
||||||
|
|
||||||
|
def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
|
||||||
|
batch.token_to_kv_pool = runner.token_to_kv_pool
|
||||||
|
batch.req_to_token_pool = runner.req_to_token_pool
|
||||||
|
|
||||||
|
def forward_extend_after_decode(self, batch: ScheduleBatch):
|
||||||
|
self._swap_mem_pool(batch, self.model_runner)
|
||||||
|
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||||
|
if batch.spec_info.has_finished:
|
||||||
|
index = batch.spec_info.unfinished_index
|
||||||
|
seq_lens = batch.seq_lens
|
||||||
|
batch.seq_lens = batch.seq_lens[index]
|
||||||
|
batch.spec_info.prepare_extend_after_decode(batch)
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
|
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
|
batch.spec_info.hidden_states = logits_output.hidden_states
|
||||||
|
self.capture_for_decode(logits_output, forward_batch)
|
||||||
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
|
if batch.spec_info.has_finished:
|
||||||
|
batch.seq_lens = seq_lens
|
||||||
|
self._swap_mem_pool(batch, self.target_worker.model_runner)
|
||||||
|
|
||||||
|
def capture_for_decode(self, logits_output, forward_batch):
|
||||||
|
if isinstance(logits_output, LogitsProcessorOutput):
|
||||||
|
logits = logits_output.next_token_logits
|
||||||
|
sample_output = torch.softmax(
|
||||||
|
logits, dim=-1
|
||||||
|
) # TODO: Support more sampling method @kavioyu
|
||||||
|
forward_batch.spec_info.capture_for_decode(
|
||||||
|
sample_output, logits_output.hidden_states, forward_batch.forward_mode
|
||||||
|
)
|
||||||
|
|
||||||
|
# Don't support prefix share now.
|
||||||
|
def finish_request(self, reqs: Union[Req, List[Req]]):
|
||||||
|
if not isinstance(reqs, List):
|
||||||
|
reqs = [reqs]
|
||||||
|
for req in reqs:
|
||||||
|
req_len = (
|
||||||
|
len(req.origin_input_ids)
|
||||||
|
+ len(req.output_ids)
|
||||||
|
- self.finish_extend_len[req.rid]
|
||||||
|
- 1
|
||||||
|
)
|
||||||
|
kv_indices = self.model_runner.req_to_token_pool.req_to_token[
|
||||||
|
req.req_pool_idx
|
||||||
|
][:req_len]
|
||||||
|
self.model_runner.token_to_kv_pool.free(kv_indices)
|
||||||
|
self.model_runner.req_to_token_pool.free(req.req_pool_idx)
|
||||||
@@ -13,6 +13,7 @@ suites = {
|
|||||||
"test_abort.py",
|
"test_abort.py",
|
||||||
"test_chunked_prefill.py",
|
"test_chunked_prefill.py",
|
||||||
"test_double_sparsity.py",
|
"test_double_sparsity.py",
|
||||||
|
"test_eagle_infer.py",
|
||||||
"test_embedding_openai_server.py",
|
"test_embedding_openai_server.py",
|
||||||
"test_eval_accuracy_mini.py",
|
"test_eval_accuracy_mini.py",
|
||||||
"test_get_weights_by_name.py",
|
"test_get_weights_by_name.py",
|
||||||
|
|||||||
39
test/srt/test_eagle_infer.py
Normal file
39
test/srt/test_eagle_infer.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
|
||||||
|
class TestEAGLEEngine(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_eagle_accuracy(self):
|
||||||
|
prompt = "Today is a sunny day and I like"
|
||||||
|
target_model_path = "meta-llama/Llama-2-7b-chat-hf"
|
||||||
|
speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B"
|
||||||
|
|
||||||
|
sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
||||||
|
|
||||||
|
engine = sgl.Engine(
|
||||||
|
model_path=target_model_path,
|
||||||
|
speculative_draft_model_path=speculative_draft_model_path,
|
||||||
|
speculative_algorithm="EAGLE",
|
||||||
|
speculative_num_steps=3,
|
||||||
|
speculative_eagle_topk=4,
|
||||||
|
speculative_num_draft_tokens=16,
|
||||||
|
)
|
||||||
|
out1 = engine.generate(prompt, sampling_params)["text"]
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
engine = sgl.Engine(model_path=target_model_path)
|
||||||
|
out2 = engine.generate(prompt, sampling_params)["text"]
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
print("==== Answer 1 ====")
|
||||||
|
print(out1)
|
||||||
|
|
||||||
|
print("==== Answer 2 ====")
|
||||||
|
print(out2)
|
||||||
|
self.assertEqual(out1, out2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user