[Eagle] Remove the greedy branch and some redundant code (#4363)
Co-authored-by: Sehoon Kim <sehoon@x.ai>
@@ -43,7 +43,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.5.post1",
+    "sgl-kernel==0.0.5.post2",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "vllm>=0.6.4.post1,<=0.7.2",
@@ -283,7 +283,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.post("/flush_cache")
+@app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
     _global_state.tokenizer_manager.flush_cache()
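Note: the hunk above swaps @app.post for @app.api_route so a single handler serves both GET and POST. A minimal, self-contained FastAPI sketch of the same pattern (the app and response body here are illustrative stand-ins, not the server's real state):

    from fastapi import FastAPI

    app = FastAPI()

    # One handler exposed under both verbs, mirroring the change above.
    @app.api_route("/flush_cache", methods=["GET", "POST"])
    async def flush_cache():
        # Stand-in for _global_state.tokenizer_manager.flush_cache().
        return {"message": "Cache flushed."}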
@@ -895,7 +895,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
         spec_accept_length = 0
@@ -913,7 +912,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"accept len: {spec_accept_length:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )

@@ -117,7 +117,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-        capture_bs = list(range(1, 33))
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture less.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]

     if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
@@ -125,16 +127,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs = list(
-            sorted(
-                set(
-                    capture_bs
-                    + [model_runner.req_to_token_pool.size - 1]
-                    + [model_runner.req_to_token_pool.size]
-                )
-            )
-        )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]

+    capture_bs = list(sorted(set(capture_bs)))
     capture_bs = [
         bs
         for bs in capture_bs
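Note: the two hunks above first append the request-pool bounds in place and then deduplicate once with list(sorted(set(...))), instead of building one nested sorted(set(...)) expression. A standalone sketch of the resulting logic (the pool size is a made-up value, and the final filter condition is elided in the hunk, so a plausible one is shown):

    req_to_token_pool_size = 24  # hypothetical pool size / --max-running-requests bound

    capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
    if max(capture_bs) > req_to_token_pool_size:
        # Make sure the maximum runnable batch size gets captured.
        capture_bs += [req_to_token_pool_size - 1] + [req_to_token_pool_size]

    capture_bs = list(sorted(set(capture_bs)))  # one dedup/sort pass
    capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
    print(capture_bs)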
@@ -508,7 +505,9 @@ class CudaGraphRunner:
         self.raw_num_token = raw_num_token
         self.bs = bs

-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.replay_prepare(forward_batch)
         else:
@@ -285,7 +285,6 @@ class ServerArgs:
         if self.speculative_algorithm == "EAGLE":
             if self.max_running_requests is None:
                 self.max_running_requests = 32
-            self.disable_cuda_graph_padding = True
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
@@ -3,8 +3,13 @@
 from typing import List

 import torch
-from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
+
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )


 def build_tree_kernel_efficient_preprocess(
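Note: the import hunk above moves the sgl_kernel import behind an is_cuda_available() check so CPU-only environments can still import the module. A generic sketch of that guarded-import pattern (the probe below is an assumption, not sglang's actual implementation, and the guarded import assumes a CUDA build with sgl-kernel installed):

    def is_cuda_available() -> bool:
        # Stand-in for sglang.srt.utils.is_cuda_available.
        try:
            import torch
            return torch.cuda.is_available()
        except ImportError:
            return False

    if is_cuda_available():
        # CUDA-only kernels are imported lazily, mirroring the hunk above.
        from sgl_kernel import (
            build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
        )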
@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
     top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
     top_scores_index = top_scores.indices
     top_scores_index = torch.sort(top_scores_index).values
-
     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
     )


-def build_tree_kernel(
-    verified_id: torch.Tensor,
-    score_list: List[torch.Tensor],
-    token_list: List[torch.Tensor],
-    parents_list: List[torch.Tensor],
-    seq_lens: torch.Tensor,
-    seq_lens_sum: int,
-    topk: int,
-    spec_steps: int,
-    num_verify_tokens: int,
-):
-    parent_list, top_scores_index, draft_tokens = (
-        build_tree_kernel_efficient_preprocess(
-            verified_id,
-            score_list,
-            token_list,
-            parents_list,
-            num_verify_tokens,
-        )
-    )
-
-    bs = seq_lens.numel()
-    device = seq_lens.device
-
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
-    retrive_index = torch.full(
-        (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
-    )
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list,
-        top_scores_index,
-        seq_lens.to(torch.int32),
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        spec_steps,
-        num_verify_tokens,
-    )
-
-    index = retrive_index.sum(dim=-1) != -spec_steps - 2
-    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
-    retrive_cum_len = torch.zeros(
-        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
-    )
-    retrive_cum_len[1:] = cum_len
-    # TODO: this indexing cause a synchronization, optimize this
-    retrive_index = retrive_index[index]
-    return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
-
-
-def test_build_tree_kernel():
-    def findp(p_i, index, parent_list):
-        pos = index // 10
-        index_list = index.tolist()
-        parent_list = parent_list.tolist()
-        res = [p_i]
-        while True:
-            p = pos[p_i]
-            if p == 0:
-                break
-            token_idx = parent_list[p]
-            p_i = index_list.index(token_idx)
-            res.append(p_i)
-        return res
-
-    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
-        mask = []
-        positions = []
-        retrive_index = []
-        for i, lens in enumerate(seq_len.tolist()):
-            first_mask = torch.full((lens + draft_token,), True)
-            first_mask[-(draft_token - 1) :] = False
-            positions.append(lens)
-            mask.append(first_mask)
-            seq_order = []
-            first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
-            r_index = [first_index]
-            for j in range(draft_token - 1):
-                mask.append(torch.full((lens + 1,), True))
-                idx = findp(j, index, parent_list)
-
-                seq_order.append(idx)
-                positions.append(len(idx) + seq_len)
-                t = torch.full((draft_token - 1,), False)
-                t[idx] = True
-                mask.append(t)
-
-            for i in range(1, draft_token - 1):
-                is_leaf = 0
-                for j in range(draft_token - 1):
-                    if i in seq_order[j]:
-                        is_leaf += 1
-
-                if is_leaf == 1:
-                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
-                    for _ in range(max_depth + 1 - len(seq_order[i])):
-                        order_list.append(-1)
-                    order = torch.Tensor(order_list).cuda().to(torch.long)
-                    r_index.append(order)
-            retrive_index.append(torch.stack(r_index))
-
-        return (
-            torch.cat(mask).cuda(),
-            torch.Tensor(positions).cuda().to(torch.long),
-            torch.stack(retrive_index),
-        )
-
-    index = (
-        torch.Tensor(
-            [
-                0,
-                1,
-                2,
-                3,
-                10,
-                11,
-                12,
-                13,
-                20,
-                21,
-                22,
-                30,
-                110,
-                130,
-                150,
-                160,
-                210,
-                211,
-                212,
-                213,
-                214,
-                215,
-                216,
-                217,
-                218,
-                219,
-                220,
-                230,
-                310,
-                311,
-                312,
-                313,
-                314,
-                315,
-                316,
-                317,
-                320,
-                321,
-                322,
-                330,
-                360,
-                380,
-                390,
-                410,
-                411,
-                412,
-                413,
-                414,
-                415,
-                416,
-                417,
-                418,
-                419,
-                420,
-                421,
-                422,
-                423,
-                430,
-                431,
-                440,
-                441,
-                460,
-                470,
-            ]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    parent_list = (
-        torch.Tensor(
-            [
-                -1,
-                0,
-                1,
-                2,
-                3,
-                4,
-                5,
-                6,
-                7,
-                8,
-                9,
-                10,
-                11,
-                12,
-                20,
-                30,
-                21,
-                13,
-                22,
-                40,
-                23,
-                110,
-                130,
-                160,
-                150,
-                190,
-                120,
-                111,
-                121,
-                200,
-                180,
-                210,
-                211,
-                212,
-                213,
-                214,
-                215,
-                216,
-                220,
-                230,
-                217,
-                310,
-                311,
-                312,
-                313,
-                320,
-                314,
-                321,
-                315,
-                316,
-                317,
-            ]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
-    bs = verified_seq_len.shape[0]
-    topk = 10
-    depth = 5  # depth <= 10
-    num_draft_token = 64
-
-    tree_mask = torch.full(
-        (
-            torch.sum(verified_seq_len).item() * num_draft_token
-            + num_draft_token * num_draft_token * bs,
-        ),
-        True,
-    ).cuda()
-    retrive_index = torch.full(
-        (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
-    )
-    positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list.unsqueeze(0),
-        index.unsqueeze(0),
-        verified_seq_len,
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        depth,
-        num_draft_token,
-    )
-
-    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
-
-    c_mask, c_positions, c_retive_index = create_mask(
-        verified_seq_len, num_draft_token, index, parent_list, depth
-    )
-
-    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
-    assert torch.allclose(positions, c_positions), "positions has error."
-    assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
-
-
 def test_build_tree_kernel_efficient():
     verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
     score_list = [
@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
     depth = 4
     num_draft_token = 8

-    tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-        build_tree_kernel(
-            verified_id=verified_id,
-            score_list=score_list,
-            token_list=token_list,
-            parents_list=parents_list,
-            seq_lens=seq_lens,
-            seq_lens_sum=torch.sum(seq_lens).item(),
-            topk=topk,
-            spec_steps=depth,
-            num_verify_tokens=num_draft_token,
-        )
-    )
-
-    from sglang.srt.utils import first_rank_print
-
-    first_rank_print("=========== build tree kernel ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_cum_len=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
-    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
-    assert retrive_index.tolist() == [
-        [0, -1, -1, -1, -1, -1],
-        [0, 2, 4, 6, -1, -1],
-        [0, 1, 3, 5, 7, -1],
-        [8, -1, -1, -1, -1, -1],
-        [8, 9, 10, -1, -1, -1],
-        [8, 9, 12, -1, -1, -1],
-        [8, 9, 13, -1, -1, -1],
-        [8, 9, 11, 14, 15, -1],
-    ]
-    assert retrive_cum_len.tolist() == [0, 3, 8]
-    assert draft_tokens.tolist() == [
-        29974,
-        29896,
-        29906,
-        29889,
-        29974,
-        29946,
-        29896,
-        29946,
-        13,
-        13,
-        22550,
-        4136,
-        16492,
-        8439,
-        29871,
-        29941,
-    ]
-
     (
         tree_mask,
         position,
@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient():

 if __name__ == "__main__":
     test_build_tree_kernel_efficient()
-    test_build_tree_kernel()
@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker

+import logging
+
+logger = logging.getLogger(__name__)
+

 class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.tp_size = self.model_runner.tp_size
-        self.dp_size = model_runner.server_args.dp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         server_args = model_runner.server_args

-        assert self.disable_padding
-
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.num_tokens_per_bs = server_args.speculative_eagle_topk
@@ -169,6 +170,13 @@ class EAGLEDraftCudaGraphRunner:
             set_global_graph_memory_pool(graph.pool())
             return graph, out

+    def _postprocess_output_to_raw_bs(self, out, raw_bs):
+        score_list, token_list, parents_list = out
+        score_list = [x[:raw_bs] for x in score_list]
+        token_list = [x[:raw_bs] for x in token_list]
+        parents_list = [x[:raw_bs] for x in parents_list]
+        return (score_list, token_list, parents_list)
+
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
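Note: _postprocess_output_to_raw_bs trims graph outputs captured at a padded batch size back to the real one, which is what makes cuda-graph padding safe for the draft runner. A toy sketch of the same slicing with fabricated shapes (two draft steps, top-2):

    import torch

    def postprocess_output_to_raw_bs(out, raw_bs):
        # Mirrors the helper above: slice every per-step tensor back to raw_bs rows.
        score_list, token_list, parents_list = out
        score_list = [x[:raw_bs] for x in score_list]
        token_list = [x[:raw_bs] for x in token_list]
        parents_list = [x[:raw_bs] for x in parents_list]
        return (score_list, token_list, parents_list)

    padded_bs, raw_bs, topk = 4, 3, 2
    out = (
        [torch.rand(padded_bs, topk) for _ in range(2)],   # scores per draft step
        [torch.zeros(padded_bs, topk, dtype=torch.long)],  # draft tokens
        [torch.zeros(padded_bs, topk, dtype=torch.long)],  # parent indices
    )
    scores, tokens, parents = postprocess_output_to_raw_bs(out, raw_bs)
    assert scores[0].shape[0] == raw_bs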
@@ -180,6 +188,9 @@ class EAGLEDraftCudaGraphRunner:
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
+
+        num_tokens = bs * self.num_tokens_per_bs

         # Common inputs
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
@@ -193,11 +204,25 @@ class EAGLEDraftCudaGraphRunner:
             self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

         # Attention backend
+        if bs != raw_bs:
+            forward_batch.batch_size = bs
+            forward_batch.seq_lens = self.seq_lens[:bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
+            forward_batch.positions = self.positions[:num_tokens]
+
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch, forward_batch.batch_size
+            forward_batch, bs
         )

         # Replay
         self.graphs[bs].replay()
+        out = self.output_buffers[bs]

-        return self.output_buffers[bs]
+        if bs != raw_bs:
+            out = self._postprocess_output_to_raw_bs(out, raw_bs)
+            forward_batch.batch_size = raw_bs
+            forward_batch.positions = self.positions[:raw_num_token]
+            forward_batch.seq_lens = self.seq_lens[:raw_bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
+
+        return out
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

 import torch
 import torch.nn.functional as F
@@ -13,18 +13,24 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
-from sglang.srt.speculative.build_eagle_tree import (
-    build_tree_kernel,
-    build_tree_kernel_efficient,
-)
+from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
 from sglang.srt.utils import is_cuda_available

 if is_cuda_available():
-    from sgl_kernel import tree_speculative_sampling_target_only
+    from sgl_kernel import (
+        top_k_renorm_prob,
+        top_p_renorm_prob,
+        tree_speculative_sampling_target_only,
+        verify_tree_greedy,
+    )

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch

+import logging
+
+logger = logging.getLogger(__name__)
+

 @dataclass
 class EagleDraftInput:
@@ -47,12 +53,9 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None

-    # indices of unfinished requests during extend-after-decode
-    # e.g. [0, 2, 3, 4] if only the 1st request is finished
-    keep_indices: List[int] = None
+    all_padding_lens: Optional[torch.Tensor] = None

     def prepare_for_extend(self, batch: ScheduleBatch):
-        assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)

@@ -64,27 +67,18 @@ class EagleDraftInput:
             )
             pt += extend_len

-    def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
-        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
+    def prepare_extend_after_decode(
+        self,
+        batch: ScheduleBatch,
+        speculative_num_steps: int,
+    ):
+        assert len(self.verified_id) == len(batch.out_cache_loc)
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
-        assert len(batch.req_pool_indices) == len(batch.reqs)
-
-        pt = 0
-        i = 0
-        self.keep_indices = []
-        for idx, req in enumerate(batch.reqs):
-            if req.finished():
-                continue
-            self.keep_indices.append(idx)
-            # assert seq_len - pre_len == req.extend_input_len
-            input_len = batch.extend_lens[i]
-            seq_len = seq_lens_cpu[i]
-            pt += input_len
-            i += 1

         self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
         new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
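Note: prepare_extend_after_decode now reads the surviving requests' metadata straight from spec_info instead of looping over batch.reqs with keep_indices. The extend bookkeeping it retains is tiny; a toy sketch with made-up accept lengths:

    accept_length_cpu = [2, 0, 3]  # accepted draft tokens per request
    extend_lens = [x + 1 for x in accept_length_cpu]  # +1 for the bonus token
    extend_num_tokens = sum(extend_lens)
    print(extend_lens, extend_num_tokens)  # [3, 1, 4] 8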
@@ -112,10 +106,6 @@ class EagleDraftInput:
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
-        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
-        req_pool_indices = req_pool_indices[keep_indices]
-        assert req_pool_indices.shape[0] == bs
-        assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]

         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
@@ -172,7 +162,7 @@ class EagleVerifyOutput:
     # Accepeted token length per sequence in a batch in CPU.
     accept_length_per_req_cpu: List[int]
     # Accepeted indices from logits_output.next_token_logits
-    accepeted_indices_cpu: List[int]
+    accepeted_indices: torch.Tensor


 @dataclass
@@ -200,67 +190,38 @@ class EagleVerifyInput:
         topk: int,
         spec_steps: int,
         num_verify_tokens: int,
-        is_all_greedy: bool,
     ):
-        if is_all_greedy:
-            tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-                build_tree_kernel(
-                    verified_id,
-                    score_list,  # b, n, topk; n= 1 + (num_steps-1) * self.topk
-                    token_list,
-                    parents_list,
-                    seq_lens,
-                    seq_lens_sum,
-                    topk,
-                    spec_steps,
-                    num_verify_tokens,
-                )
-            )
+        (
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            draft_tokens,
+        ) = build_tree_kernel_efficient(
+            verified_id,
+            score_list,
+            token_list,
+            parents_list,
+            seq_lens,
+            seq_lens_sum,
+            topk,
+            spec_steps,
+            num_verify_tokens,
+        )

-            return cls(
-                draft_tokens,
-                tree_mask,
-                position,
-                retrive_index,
-                None,
-                None,
-                retrive_cum_len,
-                num_verify_tokens,
-                spec_steps,
-                CaptureHiddenMode.FULL,
-            )
-        else:
-            (
-                tree_mask,
-                position,
-                retrive_index,
-                retrive_next_token,
-                retrive_next_sibling,
-                draft_tokens,
-            ) = build_tree_kernel_efficient(
-                verified_id,
-                score_list,
-                token_list,
-                parents_list,
-                seq_lens,
-                seq_lens_sum,
-                topk,
-                spec_steps,
-                num_verify_tokens,
-            )
-
-            return cls(
-                draft_tokens,
-                tree_mask,
-                position,
-                retrive_index,
-                retrive_next_token,
-                retrive_next_sibling,
-                None,
-                num_verify_tokens,
-                spec_steps,
-                CaptureHiddenMode.FULL,
-            )
+        return cls(
+            draft_tokens,
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            None,
+            num_verify_tokens,
+            spec_steps,
+            CaptureHiddenMode.FULL,
+        )

     def prepare_for_verify(self, batch: ScheduleBatch):
         batch.input_ids = self.draft_token
@@ -291,7 +252,6 @@ class EagleVerifyInput:
             dtype=torch.int32,
             device="cuda",
         )
-
         cum_kv_seq_len = torch.zeros(
             (batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -304,7 +264,6 @@ class EagleVerifyInput:
             dtype=torch.int32,
             device="cuda",
         )
-
         create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
             req_pool_indices,
@@ -322,65 +281,79 @@ class EagleVerifyInput:
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ) -> torch.Tensor:
-        """WARNING: This API in-place modifies the states of logits_output
+        """

         Verify and find accepted tokens based on logits output and batch
         (which contains spec decoding information).

+        WARNING: This API in-place modifies the states of logits_output
+
         This API updates values inside logits_output based on the accepted
         tokens. I.e., logits_output.next_token_logits only contains
         accepeted token logits.
         """
-        draft_token = torch.cat(
-            [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
-            dim=-1,
-        )
-        candidates = draft_token[self.retrive_index]
-        if batch.sampling_info.is_all_greedy:
-            # temp == 0
-            bs = self.retrive_cum_len.numel() - 1
-            predict = torch.argmax(logits_output.next_token_logits, dim=-1)
-            predict = torch.cat(
-                [predict, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1
-            )
-            target_predict = predict[self.retrive_index]
-            # logits = logits_output.next_token_logits[self.retrive_index]
-            # target_predict = torch.argmax(logits[:, :-1], dim=-1)
-            accept_mask = candidates[:, 1:] == target_predict[:, :-1]
-
-            accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
-            max_draft_len = self.retrive_index.shape[-1]
-            accept_index = torch.full(
-                (bs, max_draft_len), -1, dtype=torch.int32, device="cuda"
+        bs = self.retrive_index.shape[0]
+        candidates = self.draft_token.reshape(bs, self.draft_token_num)
+        sampling_info = batch.sampling_info
+
+        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
+        predict_shape[-1] += 1
+        predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
+        accept_index = torch.full(
+            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+        )
+        accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
+
+        if sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            linear_penalty = torch.zeros(
+                (bs, logits_output.next_token_logits.shape[1]),
+                dtype=torch.float32,
+                device="cuda",
             )
-            accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
-            extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
-            eagle_verify_retrive[(bs,)](
-                self.retrive_index.contiguous(),
-                accept_mask.contiguous(),
-                self.retrive_cum_len,
-                accept_index,
-                accept_length,
-                extract_index,
-                max_draft_len,
-                self.draft_token_num,
-                triton.next_power_of_2(max_draft_len),
+            sampling_info.apply_logits_bias(linear_penalty)
+            logits_output.next_token_logits.add_(
+                torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
+            )
+
+        if batch.sampling_info.is_all_greedy:
+            target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
+            target_predict = target_predict.reshape(bs, self.draft_token_num)
+
+            verify_tree_greedy(
+                predicts=predict,  # mutable
+                accept_index=accept_index,  # mutable
+                accept_token_num=accept_length,  # mutable
+                candidates=candidates.to(torch.int32),
+                retrive_index=self.retrive_index.to(torch.int32),
+                retrive_next_token=self.retrive_next_token.to(torch.int32),
+                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
+                target_predict=target_predict.to(torch.int32),
             )
         else:
-            # temp > 0
-            bs = self.retrive_index.shape[0]
-            predict_shape = list(logits_output.next_token_logits.shape)[:-1]
-            predict_shape[-1] += 1
-            target_logits = logits_output.next_token_logits[self.retrive_index]
-            predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")
-            accept_index = torch.full(
-                (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+            # apply temperature and get target probs
+            expanded_temperature = torch.repeat_interleave(
+                sampling_info.temperatures, self.draft_token_num, dim=0
+            )  # (bs * draft_token_num, 1)
+            target_probs = F.softmax(
+                logits_output.next_token_logits / expanded_temperature, dim=-1
+            )  # (bs * draft_token_num, vocab_size)
+            target_probs = top_k_renorm_prob(
+                target_probs,
+                torch.repeat_interleave(
+                    sampling_info.top_ks, self.draft_token_num, dim=0
+                ),
+            )  # (bs * draft_token_num, vocab_size)
+            target_probs = top_p_renorm_prob(
+                target_probs,
+                torch.repeat_interleave(
+                    sampling_info.top_ps, self.draft_token_num, dim=0
+                ),
             )
-            accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
-            expanded_temperature = batch.sampling_info.temperatures.unsqueeze(1)
-            target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
-            draft_probs = torch.full_like(
-                target_probs, 0, dtype=torch.float32, device="cuda"
+            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
+
+            draft_probs = torch.zeros(
+                target_probs.shape, dtype=torch.float32, device="cuda"
             )
             coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
             tree_speculative_sampling_target_only(
@@ -394,6 +367,12 @@ class EagleVerifyInput:
                 uniform_samples=coins,
                 target_probs=target_probs,
                 draft_probs=draft_probs,
+                threshold_single=global_server_args_dict[
+                    "speculative_accept_threshold_single"
+                ],
+                threshold_acc=global_server_args_dict[
+                    "speculative_accept_threshold_acc"
+                ],
                 deterministic=True,
             )

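Note: the non-greedy path above applies temperature, then renormalizes the target distribution with top-k and top-p before tree sampling. A pure-PyTorch sketch of that renormalization math (standing in for the sgl_kernel top_k_renorm_prob / top_p_renorm_prob GPU kernels; these reference implementations are assumptions, not the kernels' code):

    import torch
    import torch.nn.functional as F

    def top_k_renorm_prob(probs: torch.Tensor, k: int) -> torch.Tensor:
        # Zero out everything outside the top-k, then renormalize to sum to 1.
        topk_vals, _ = probs.topk(k, dim=-1)
        thresh = topk_vals[..., -1:].expand_as(probs)
        masked = torch.where(probs >= thresh, probs, torch.zeros_like(probs))
        return masked / masked.sum(dim=-1, keepdim=True)

    def top_p_renorm_prob(probs: torch.Tensor, p: float) -> torch.Tensor:
        # Keep the smallest prefix of sorted probs whose mass reaches p, renormalize.
        sorted_probs, idx = probs.sort(dim=-1, descending=True)
        cum = sorted_probs.cumsum(dim=-1)
        keep_sorted = cum - sorted_probs < p  # the first token is always kept
        keep = torch.zeros_like(keep_sorted).scatter(-1, idx, keep_sorted)
        masked = probs * keep
        return masked / masked.sum(dim=-1, keepdim=True)

    logits = torch.randn(2, 8)
    target_probs = F.softmax(logits / 0.7, dim=-1)  # temperature = 0.7
    target_probs = top_k_renorm_prob(target_probs, k=4)
    target_probs = top_p_renorm_prob(target_probs, p=0.9)
    print(target_probs.sum(dim=-1))  # ~1.0 per row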
@@ -425,119 +404,94 @@ class EagleVerifyInput:
                     new_accept_index.extend(new_accept_index_)
                     unfinished_index.append(i)
                 req.spec_verify_ct += 1
-        accept_length = (accept_index != -1).sum(dim=1) - 1

-        accept_index = accept_index[accept_index != -1]
-        accept_length_cpu = accept_length.tolist()
-        verified_id = predict[accept_index]
-        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-        evict_mask[accept_index] = False
-        mem_need_free_idx = batch.out_cache_loc[evict_mask]
-        token_to_kv_pool_allocator.free(mem_need_free_idx)
-        assign_req_to_token_pool[(bs,)](
-            batch.req_pool_indices,
-            batch.req_to_token_pool.req_to_token,
-            batch.seq_lens,
-            batch.seq_lens + accept_length + 1,
-            batch.out_cache_loc[accept_index],
-            batch.req_to_token_pool.req_to_token.shape[1],
-            triton.next_power_of_2(bs),
-        )
-        batch.seq_lens.add_(accept_length + 1)
+        if not has_finished:
+            accept_index = accept_index[accept_index != -1]
+            verified_id = predict[accept_index]
+            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+            evict_mask[accept_index] = False
+            mem_need_free_idx = batch.out_cache_loc[evict_mask]
+            token_to_kv_pool_allocator.free(mem_need_free_idx)
+            batch.out_cache_loc = batch.out_cache_loc[accept_index]
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + accept_length + 1,
+                batch.out_cache_loc,
+                batch.req_to_token_pool.req_to_token.shape[1],
+                triton.next_power_of_2(bs),
+            )
+            batch.seq_lens.add_(accept_length + 1)
+            accept_length_cpu = accept_length.tolist()

-        draft_input = EagleDraftInput()
-        if len(new_accept_index) > 0:
-            new_accept_index = torch.tensor(new_accept_index, device="cuda")
-            draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
-            draft_input.verified_id = predict[new_accept_index]
-            draft_input.accept_length = accept_length[unfinished_index]
-            draft_input.accept_length_cpu = [
-                accept_length_cpu[i] for i in unfinished_index
-            ]
-        if has_finished:
-            draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
-        else:
-            draft_input.seq_lens_for_draft_extend = batch.seq_lens
-        batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
+            draft_input = EagleDraftInput()
+            draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
+            draft_input.verified_id = verified_id
+            draft_input.accept_length = accept_length
+            draft_input.accept_length_cpu = accept_length_cpu
+            draft_input.seq_lens_for_draft_extend = batch.seq_lens
+            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices

-        return EagleVerifyOutput(
-            draft_input=draft_input,
-            logits_output=logits_output,
-            verified_id=verified_id,
-            accept_length_per_req_cpu=accept_length_cpu,
-            accepeted_indices_cpu=accept_index,
-        )
+            return EagleVerifyOutput(
+                draft_input=draft_input,
+                logits_output=logits_output,
+                verified_id=verified_id,
+                accept_length_per_req_cpu=accept_length_cpu,
+                accepeted_indices=accept_index,
+            )
+        else:
+            accept_length = (accept_index != -1).sum(dim=1) - 1
+            accept_index = accept_index[accept_index != -1]
+            verified_id = predict[accept_index]
+            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+            evict_mask[accept_index] = False
+            mem_need_free_idx = batch.out_cache_loc[evict_mask]
+            token_to_kv_pool_allocator.free(mem_need_free_idx)
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + accept_length + 1,
+                batch.out_cache_loc[accept_index],
+                batch.req_to_token_pool.req_to_token.shape[1],
+                triton.next_power_of_2(bs),
+            )
+            batch.seq_lens.add_(accept_length + 1)
+            accept_length_cpu = accept_length.tolist()

+            draft_input = EagleDraftInput()
+            if len(new_accept_index) > 0:
+                new_accept_index = torch.tensor(new_accept_index, device="cuda")
+                draft_input.hidden_states = batch.spec_info.hidden_states[
+                    new_accept_index
+                ]
+                draft_input.verified_id = predict[new_accept_index]
+                draft_input.accept_length = accept_length[unfinished_index]
+                draft_input.accept_length_cpu = [
+                    accept_length_cpu[i] for i in unfinished_index
+                ]
+                if has_finished:
+                    draft_input.seq_lens_for_draft_extend = batch.seq_lens[
+                        unfinished_index
+                    ]
+                    draft_input.req_pool_indices_for_draft_extend = (
+                        batch.req_pool_indices[unfinished_index]
+                    )
+                else:
+                    draft_input.seq_lens_for_draft_extend = batch.seq_lens
+                    draft_input.req_pool_indices_for_draft_extend = (
+                        batch.req_pool_indices
+                    )
+                batch.out_cache_loc = batch.out_cache_loc[new_accept_index]

-@triton.jit
-def eagle_verify_retrive(
-    retrive_index,
-    accept_mask,
-    retrive_cum_len,
-    accept_index,
-    accept_length,
-    extract_index,
-    max_len: tl.constexpr,
-    draft_token_num: tl.constexpr,
-    max_len_upper: tl.constexpr,
-):
-    """
-    Args:
-        retrive_index: Pointer to indices of draft tokens
-        accept_mask: Mask indicating which tokens were accepted
-        retrive_cum_len: Cumulative lengths of token sequences in a batch
-        accept_index (out): Accept token indices
-        accept_length (out): Length of accepted tokens per sequence in a batch
-        extract_index (out): Index for last accepted tokens
-        max_len: Maximum length in a batch
-        draft_token_num: Number of tokens speculatively generated
-        max_len_upper: An upper bound for token sequence length
-    """
-    pid = tl.program_id(axis=0)
-
-    retrive_end = tl.load(retrive_cum_len + pid + 1)
-    retrive_start = tl.load(retrive_cum_len + pid)
-    retrive_len = retrive_end - retrive_start
-    accept_ptr = accept_mask + retrive_start
-    accept_offset = tl.arange(0, draft_token_num)
-    accept_load_mask = accept_offset < retrive_len
-    accept_len_list = tl.load(
-        accept_ptr + accept_offset, mask=accept_load_mask, other=-1
-    )
-
-    accept_len = tl.max(accept_len_list)
-    max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
-    # triton is not support argmax with tie_break_right, so I need implement it by some way
-    mask_max = accept_len_list == accept_len
-
-    count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
-    count = tl.sum(tl.where(mask_max, 1, count_mask))
-    if count > 1:
-        index = tl.arange(0, draft_token_num)
-        mask_left = index != max_index
-        remained_index = tl.where(mask_max and mask_left, index, 0)
-        max_index = tl.max(remained_index)
-
-    tl.store(accept_length + pid, accept_len)
-    retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
-    retrive_offset = tl.arange(0, max_len_upper)
-    retrive_load_mask = retrive_offset < accept_len + 1
-    data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
-
-    tl.store(
-        accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
-    )
-
-    extract_load_ptr = accept_index + pid * max_len + accept_len
-    if accept_len == max_len - 1:
-        extract_data = tl.load(extract_load_ptr - 1)
-        tl.store(extract_index + pid * 2, extract_data)
-        extract_data = tl.load(extract_load_ptr)
-        tl.store(extract_index + pid * 2 + 1, extract_data)
-
-    else:
-        extract_data = tl.load(extract_load_ptr)
-        tl.store(extract_index + pid * 2, extract_data)
+            return EagleVerifyOutput(
+                draft_input=draft_input,
+                logits_output=logits_output,
+                verified_id=verified_id,
+                accept_length_per_req_cpu=accept_length_cpu,
+                accepeted_indices=accept_index,
+            )


 @triton.jit
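Note: verify_tree_greedy replaces the removed eagle_verify_retrive Triton kernel. Conceptually, greedy verification accepts a draft token only when it equals the target model's argmax at its parent position, and stops at the first mismatch. A small CPU sketch of that rule on a chain (a tree with a single path) with made-up token ids -- this illustrates the cumprod acceptance rule from the removed branch, not the kernel itself:

    import torch

    def greedy_chain_accept(candidates: torch.Tensor, target_predict: torch.Tensor) -> int:
        # candidates[0] is the verified root; candidates[i] (i >= 1) is accepted
        # only if every earlier draft matched and the target's argmax at
        # position i-1 equals the draft token.
        match = (candidates[1:] == target_predict[:-1]).to(torch.int64)
        return int(torch.cumprod(match, dim=0).sum())

    candidates = torch.tensor([11, 42, 7, 99])      # root + 3 draft tokens
    target_predict = torch.tensor([42, 7, 13, 5])   # target argmax per position
    print(greedy_chain_accept(candidates, target_predict))  # 2 tokens accepted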
@@ -1,11 +1,14 @@
 import logging
 import os
 import time
+from contextlib import contextmanager
 from typing import List, Optional, Tuple

 import torch
 from huggingface_hub import snapshot_download

+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -27,11 +30,22 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.utils import get_available_gpu_memory
+from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)


+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+        yield
+
+
 class EAGLEWorker(TpModelWorker):

     def __init__(
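Note: draft_tp_context stacks two context managers so the draft model runs with dp disabled inside its own tp group. A generic, runnable sketch of the same pattern with stand-in managers (the prints are placeholders for the real state patching):

    from contextlib import contextmanager

    @contextmanager
    def disable_dp_size():
        print("dp disabled")
        yield
        print("dp restored")

    @contextmanager
    def patch_tensor_parallel_group(tp_group):
        print(f"tp group patched to {tp_group}")
        yield
        print("tp group restored")

    @contextmanager
    def draft_tp_context(tp_group):
        # Same shape as the helper above: both managers enter before the body runs.
        with disable_dp_size(), patch_tensor_parallel_group(tp_group):
            yield

    with draft_tp_context("draft-tp"):
        print("draft forward here")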
@@ -76,16 +90,17 @@ class EAGLEWorker(TpModelWorker):
         self.hot_token_id = None

         # Init draft worker
-        super().__init__(
-            gpu_id=gpu_id,
-            tp_rank=tp_rank,
-            server_args=server_args,
-            nccl_port=nccl_port,
-            dp_rank=dp_rank,
-            is_draft_worker=True,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-        )
+        with empty_context():
+            super().__init__(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=nccl_port,
+                dp_rank=dp_rank,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )

         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -94,12 +109,17 @@ class EAGLEWorker(TpModelWorker):
             self.hot_token_id = self.hot_token_id.to(head.device)
             head.data = head.data[self.hot_token_id]
         self.draft_model_runner.model.set_embed_and_head(embed, head)

+        # Init attention backend and cuda graphs
         self.draft_model_runner.server_args.disable_cuda_graph = (
             backup_disable_cuda_graph
         )
-        self.init_attention_backend()
-        self.init_cuda_graphs()
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()

     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
@@ -109,52 +129,70 @@ class EAGLEWorker(TpModelWorker):
             )

             self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )

             self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashinfer_mla":
             from sglang.srt.layers.attention.flashinfer_mla_backend import (
                 FlashInferMLAMultiStepDraftBackend,
             )

             self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )

         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_runner_for_draft_extend = None

         if self.server_args.disable_cuda_graph:
             return

+        # Capture draft
         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

+        # Capture extend
+        if self.draft_extend_attn_backend:
+            raise NotImplementedError()
+
     @property
     def draft_model_runner(self):
         return self.model_runner
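Note: capture logging now snapshots free memory before and after building the graphs so the log can report graph memory usage rather than just the remaining headroom. A sketch of that bookkeeping with a stand-in memory probe (the constant value is fabricated):

    import time

    def available_gb() -> float:
        # Stand-in for get_available_gpu_memory(self.device, self.gpu_id).
        return 40.0

    tic = time.time()
    before_mem = available_gb()
    # ... EAGLEDraftCudaGraphRunner(self) capture would run here ...
    after_mem = available_gb()
    print(
        f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
        f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
    )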
@@ -164,8 +202,8 @@ class EAGLEWorker(TpModelWorker):
     ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
         """Run speculative decoding forward.

-        NOTE: Many states of batch is modified as you go through. It is not guaranteed
-        the final output batch doesn't have the same state as the input.
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed that
+        the final output batch have the same state as the input.

         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
@@ -173,30 +211,42 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
A tuple of the final logit output of the target model, next tokens accepeted,
|
A tuple of the final logit output of the target model, next tokens accepeted,
|
||||||
the batch id (used for overlap schedule), and number of accepeted tokens.
|
the batch id (used for overlap schedule), and number of accepeted tokens.
|
||||||
"""
|
"""
|
||||||
assert not batch.spec_algorithm.is_none()
|
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
spec_info, to_free_cache_loc = self.draft(batch)
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||||
|
spec_info, to_free_cache_loc = self.draft(batch)
|
||||||
logits_output, verify_output, model_worker_batch = self.verify(
|
logits_output, verify_output, model_worker_batch = self.verify(
|
||||||
batch, spec_info
|
batch, spec_info
|
||||||
)
|
)
|
||||||
|
|
||||||
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
|
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
|
||||||
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
|
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
|
||||||
# if it is None, means all requests are finished
|
|
||||||
if batch.spec_info.verified_id is not None:
|
|
||||||
self.forward_draft_extend_after_decode(batch)
|
|
||||||
|
|
||||||
|
# If it is None, it means all requests are finished
|
||||||
|
if batch.spec_info.verified_id is not None:
|
||||||
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||||
|
self.forward_draft_extend_after_decode(batch)
|
||||||
return (
|
return (
|
||||||
logits_output,
|
logits_output,
|
||||||
verify_output.verified_id,
|
verify_output.verified_id,
|
||||||
model_worker_batch.bid,
|
model_worker_batch.bid,
|
||||||
sum(verify_output.accept_length_per_req_cpu),
|
sum(verify_output.accept_length_per_req_cpu),
|
||||||
)
|
)
|
||||||
|
elif batch.forward_mode.is_idle():
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
logits_output, next_token_ids, _ = (
|
||||||
|
self.target_worker.forward_batch_generation(
|
||||||
|
ForwardBatch.init_new(
|
||||||
|
model_worker_batch, self.target_worker.model_runner
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
|
||||||
else:
|
else:
|
||||||
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
|
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
|
||||||
self.forward_draft_extend(
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||||
batch, logits_output.hidden_states, next_token_ids
|
self.forward_draft_extend(
|
||||||
)
|
batch, logits_output.hidden_states, next_token_ids
|
||||||
|
)
|
||||||
return logits_output, next_token_ids, bid, 0
|
return logits_output, next_token_ids, bid, 0
|
||||||
|
|
||||||
def forward_target_extend(
|
def forward_target_extend(
|
||||||
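For orientation, the decode branch above is a draft → verify → free → draft-extend pipeline. The toy sketch below shows the control flow of one such round under greedy sampling; it is a self-contained illustration, not sglang's implementation (draft_next and target_next are hypothetical stand-ins for the draft and target models, and the real worker scores all draft positions in one batched target forward rather than a loop):

import torch


def speculative_round(tokens, draft_next, target_next, num_steps=4):
    # Draft proposes `num_steps` tokens autoregressively (cheap model).
    proposal, ctx = [], tokens
    for _ in range(num_steps):
        t = draft_next(ctx)
        proposal.append(t)
        ctx = torch.cat([ctx, t.view(1)])

    # Verify: accept while the target agrees; on the first mismatch keep
    # the target's correction and stop, so every round emits >= 1 token.
    accepted, ctx = [], tokens
    for t in proposal:
        target_t = target_next(ctx)
        accepted.append(target_t)
        if target_t.item() != t.item():
            break
        ctx = torch.cat([ctx, target_t.view(1)])
    return torch.stack(accepted)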
@@ -226,6 +276,13 @@ class EAGLEWorker(TpModelWorker):
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info

+        # Accumulate penalty
+        if batch.sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                spec_info.verified_id.to(torch.int64)
+            )
+
         # Allocate cache locations
         out_cache_loc = batch.alloc_token_slots(
             num_seqs * self.topk * self.speculative_num_steps
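The relaxation the new comment refers to: instead of re-applying penalties token by token inside the draft tree, all tokens accepted in a round are folded into the running counts at once. A minimal sketch of that bookkeeping with plain tensors (output_counts and the penalty values are illustrative, not the orchestrator's actual API):

import torch


def cumulate_output_tokens(output_counts, verified_ids):
    # Fold every accepted token from this round into the running counts.
    ones = torch.ones_like(verified_ids, dtype=output_counts.dtype)
    output_counts.scatter_add_(0, verified_ids.to(torch.int64), ones)


def apply_penalties(logits, output_counts, frequency_penalty=0.2, presence_penalty=0.1):
    # Standard frequency/presence penalties over the accumulated counts.
    return (
        logits
        - frequency_penalty * output_counts
        - presence_penalty * (output_counts > 0).to(logits.dtype)
    )


vocab = 8
counts = torch.zeros(vocab)
cumulate_output_tokens(counts, torch.tensor([3, 3, 5]))
print(apply_penalties(torch.zeros(vocab), counts))  # tokens 3 and 5 penalized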
@@ -275,9 +332,7 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
-            batch.sampling_info.is_all_greedy,
         )

         return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
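The allocation a few lines above sizes one speculative round: each sequence expands into `topk` candidate branches at each of `speculative_num_steps` draft levels, and every candidate token needs its own KV-cache slot until verification frees the losers. A worked example of the arithmetic (values are illustrative):

num_seqs, topk, num_steps = 2, 4, 3

# Per sequence: topk candidates at each draft level, one slot per candidate.
slots_per_seq = topk * num_steps        # 12
total_slots = num_seqs * slots_per_seq  # 24 slots reserved for this round
assert total_slots == num_seqs * topk * num_steps

# After verify(), rejected branches' slots come back through
# token_to_kv_pool_allocator.free(to_free_cache_loc) in the decode branch.
print(total_slots)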
@@ -307,7 +362,7 @@ class EAGLEWorker(TpModelWorker):
             token_list.append(tree_info[1])
             parents_list.append(tree_info[2])

-            # we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
+            # We don't need to run the last forward. We get 1 token from the draft prefill and (#spec steps - 1) tokens here.
             if i == self.speculative_num_steps - 1:
                 break

@@ -322,7 +377,7 @@ class EAGLEWorker(TpModelWorker):
             spec_info.hidden_states = hidden_states

             # Run forward
-            logits_output = self.model_runner.model.forward(
+            logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
             self._detect_nan_if_needed(logits_output)
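The early break encodes a token-count invariant: a round needs speculative_num_steps draft tokens per branch, one of which already comes from the draft prefill/extend, so the loop only has to run the draft model speculative_num_steps - 1 times. A tiny sketch of that accounting (illustrative only, not the worker's code):

def draft_forwards_needed(num_steps: int) -> int:
    tokens, forwards = 1, 0   # 1 token comes from the draft prefill/extend
    for i in range(num_steps):
        if i == num_steps - 1:
            break             # this level's token is already selected
        forwards += 1         # one draft forward scores the next level
        tokens += 1
    assert tokens == num_steps
    return forwards


assert draft_forwards_needed(4) == 3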
@@ -351,11 +406,10 @@ class EAGLEWorker(TpModelWorker):
         # Post process based on verified outputs.
         # Pick the indices that we care about (the accepted ones)
         logits_output.next_token_logits = logits_output.next_token_logits[
-            res.accepeted_indices_cpu
-        ]
-        logits_output.hidden_states = logits_output.hidden_states[
-            res.accepeted_indices_cpu
+            res.accepeted_indices
         ]
+        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]

         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input
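The switch from accepeted_indices_cpu to accepeted_indices means the gather is driven by a tensor that stays with the data instead of indices materialized on the host. A small sketch of the same selection (toy shapes; the field names above are sglang's, the tensors here are illustrative):

import torch

num_draft_positions, vocab, hidden = 6, 32000, 4096
next_token_logits = torch.randn(num_draft_positions, vocab)
hidden_states = torch.randn(num_draft_positions, hidden)

# Positions the target model accepted this round (same device as the data).
accepted_indices = torch.tensor([0, 2, 3])

# Advanced indexing with an index tensor gathers the accepted rows only.
next_token_logits = next_token_logits[accepted_indices]
hidden_states = hidden_states[accepted_indices]
assert next_token_logits.shape == (3, vocab)
assert hidden_states.shape == (3, hidden)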
@@ -407,7 +461,7 @@ class EAGLEWorker(TpModelWorker):
             batch_next_token_ids,
         ]

-        # Add output logprobs to the request.
+        # Add output logprobs to the request
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
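The pt cursor initialized here presumably walks a flattened list: with speculative decoding, each request contributes a variable number of accepted tokens per round, so per-request logprobs are consecutive slices of next_token_logprobs. A toy illustration of that slicing (accept_lengths is a hypothetical stand-in for the per-request accept counts; the real loop is outside this hunk):

next_token_logprobs = [-0.11, -0.52, -0.23, -0.94, -0.35]
verified_ids = [101, 734, 25, 9012, 88]
accept_lengths = [2, 3]  # hypothetical: tokens accepted per request this round

pt = 0
for req_idx, n in enumerate(accept_lengths):
    # Each request owns the next `n` entries of the flattened outputs.
    req_slice = list(zip(verified_ids[pt : pt + n], next_token_logprobs[pt : pt + n]))
    pt += n
    print(f"req {req_idx}: {req_slice}")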
@@ -456,27 +510,38 @@ class EAGLEWorker(TpModelWorker):
         self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        seq_lens_backup = batch.seq_lens
+        # Back up fields that will be modified in-place
+        seq_lens_backup = batch.seq_lens.clone()
+        req_pool_indices_backup = batch.req_pool_indices
+        accept_length_backup = batch.spec_info.accept_length
+        return_logprob_backup = batch.return_logprob

+        # Prepare metadata
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        # We don't need logprob for this extend.
-        original_return_logprob = batch.return_logprob
         batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )

+        # Run
         logits_output = self.draft_model_runner.forward(forward_batch)

         self._detect_nan_if_needed(logits_output)
-        assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
+        batch.req_pool_indices = req_pool_indices_backup
+        batch.spec_info.accept_length = accept_length_backup
+        batch.return_logprob = return_logprob_backup

     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
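The .clone() on seq_lens is the load-bearing detail of this hunk: the other fields are rebound to new objects, so keeping a reference is a valid backup, but seq_lens is mutated in place by prepare_extend_after_decode, so a bare reference would silently track the mutation. A minimal demonstration of the difference:

import torch

# A plain reference is not a snapshot under in-place mutation:
seq_lens = torch.tensor([5, 7, 9])
backup_ref = seq_lens
seq_lens.add_(3)  # in-place update
assert torch.equal(backup_ref, torch.tensor([8, 10, 12]))  # "backup" changed too

# .clone() copies the storage, so restoring really rolls back:
seq_lens = torch.tensor([5, 7, 9])
backup = seq_lens.clone()
seq_lens.add_(3)
seq_lens = backup  # restore, as the method above does
assert torch.equal(seq_lens, torch.tensor([5, 7, 9]))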
@@ -489,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
         if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
-                logger.warning("Detected errors during sampling! NaN in the logits.")
+                logger.error("Detected errors during sampling! NaN in the logits.")
                 raise ValueError("Detected errors during sampling! NaN in the logits.")

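Raising the log level from warning to error matches the ValueError that follows: the condition aborts the request rather than degrading it. The check itself is a cheap full-tensor reduction; a standalone version for reference (not the worker's method, just the same technique):

import torch


def detect_nan_if_needed(logits: torch.Tensor, enabled: bool = True) -> None:
    # torch.isnan is elementwise; torch.any reduces it to a single bool.
    if enabled and torch.any(torch.isnan(logits)):
        raise ValueError("Detected errors during sampling! NaN in the logits.")


detect_nan_if_needed(torch.randn(4, 8))  # passes silently
# detect_nan_if_needed(torch.tensor([float("nan")]))  # would raise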

@@ -36,6 +36,7 @@ import tempfile
 import threading
 import time
 import warnings
+from contextlib import contextmanager
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
@@ -1577,6 +1578,16 @@ def next_power_of_2(n: int):
     setattr(triton, "next_power_of_2", next_power_of_2)


+@contextmanager
+def empty_context(*args, **kwargs):
+    try:
+        # Setup code goes here
+        yield
+    finally:
+        # Cleanup code goes here
+        pass


 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.
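empty_context is a no-op stand-in so call sites can enter a context manager unconditionally: the EAGLE worker above wraps draft forwards in draft_tp_context(...), and this placeholder keeps the same code path when no TP-group swap is needed. A plausible usage sketch (run_draft and step_fn are hypothetical names, not sglang's):

from contextlib import contextmanager


@contextmanager
def empty_context(*args, **kwargs):
    yield


def run_draft(step_fn, tp_context=empty_context):
    # Callers may pass a real context manager (e.g. one that patches the
    # current TP group); the no-op default costs nothing and keeps one path.
    with tp_context():
        return step_fn()


assert run_draft(lambda: "ok") == "ok"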

@@ -24,6 +24,3 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa

 # For compiling xgrammar kernels
 pip install cuda-python nvidia-cuda-nvrtc-cu12
-
-# reinstall sgl-kernel
-pip install sgl-kernel==0.0.5.post1 --force-reinstall --no-deps

@@ -36,8 +36,8 @@ template <
     typename DType,
     typename IdType>
 __global__ void TreeSpeculativeSamplingTargetOnly(
-    IdType* predicts,
-    IdType* accept_index,
+    IdType* predicts,  // mutable
+    IdType* accept_index,  // mutable
     IdType* accept_token_num,  // mutable
     IdType* candidates,
     IdType* retrive_index,
@@ -158,8 +158,8 @@ __global__ void TreeSpeculativeSamplingTargetOnly(

 template <typename DType, typename IdType>
 cudaError_t TreeSpeculativeSamplingTargetOnly(
-    IdType* predicts,
-    IdType* output_token_ids,
+    IdType* predicts,  // mutable
+    IdType* output_token_ids,  // mutable
     IdType* output_accepted_token_num,  // mutable
     IdType* candidates,
     IdType* retrive_index,
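For readers who don't want to parse the CUDA, the kernel's job can be stated in a few lines of host-side Python. This is a simplified, greedy-only rendition of target-only tree verification (all names are illustrative: paths are root-to-leaf index lists into the flattened candidate tree, and target_predicts[j] is the token the target model emits after tree position j):

def verify_tree_greedy(paths, candidates, target_predicts):
    # A draft token is accepted iff it equals the target's prediction at its
    # parent position; the longest accepted path wins.
    best_accepted = []
    for path in paths:
        accepted = [candidates[path[0]]]  # the root token was already verified
        for parent, child in zip(path, path[1:]):
            if candidates[child] != target_predicts[parent]:
                break
            accepted.append(candidates[child])
        if len(accepted) > len(best_accepted):
            best_accepted = accepted
    return best_accepted


# Tiny tree: position 0 is the root; both paths share it.
candidates = [11, 22, 33, 44]      # flattened tree tokens
target_predicts = [22, 99, 55, 7]  # target's next token after each position
paths = [[0, 1, 3], [0, 2]]
assert verify_tree_greedy(paths, candidates, target_predicts) == [11, 22]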

@@ -122,8 +122,8 @@ class TestEAGLEEngine(unittest.TestCase):

     def _test_acc_length(self, engine):
         prompt = [
-            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
-        ] * 5
+            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
+        ] * 5  # test batched generation
         sampling_params = {"temperature": 0, "max_new_tokens": 512}
         output = engine.generate(prompt, sampling_params)
         output = output[0]
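Outside the test harness, the same batched EAGLE path is reachable through the offline engine API. A sketch of what that looks like (model paths and speculative settings are example values, not the ones this test pins):

import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # example target model
    speculative_algorithm="EAGLE",
    speculative_draft_model_path="lmsys/sglang-EAGLE-llama2-chat-7B",  # example draft
    speculative_num_steps=5,
    speculative_eagle_topk=8,
    speculative_num_draft_tokens=64,
)

prompts = ["Human: Write a haiku about GPUs.\n\nAssistant:"] * 5  # batched
outputs = engine.generate(prompts, {"temperature": 0, "max_new_tokens": 64})
print(outputs[0]["text"])
engine.shutdown()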

@@ -67,7 +67,7 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
             "--enable-torch-compile",
             "--disable-cuda-graph",
             "--cuda-graph-max-bs",
-            "2",
+            "4",
             "--enable-flashinfer-mla",
             "--flashinfer-mla-disable-ragged",
         ]
@@ -109,7 +109,7 @@ class TestFlashinferMLAMTP(unittest.TestCase):
         other_args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--disable-radix",
                 "--enable-torch-compile",
                 "--torch-compile-max-bs",