From 56222658ecf7828f0cbacebbfbe1764142270858 Mon Sep 17 00:00:00 2001 From: yinghui <32845984+cicirori@users.noreply.github.com> Date: Tue, 14 Oct 2025 16:50:53 +0200 Subject: [PATCH] move eagle draft post process to cuda graph (#11434) Co-authored-by: Lianmin Zheng --- .../srt/speculative/build_eagle_tree.py | 427 ------------------ .../eagle_draft_cuda_graph_runner.py | 8 +- .../sglang/srt/speculative/eagle_info_v2.py | 108 +---- python/sglang/srt/speculative/eagle_utils.py | 138 ++++++ python/sglang/srt/speculative/eagle_worker.py | 23 +- .../sglang/srt/speculative/eagle_worker_v2.py | 5 +- test/srt/run_suite.py | 1 + test/srt/test_build_eagle_tree.py | 308 +++++++++++++ 8 files changed, 469 insertions(+), 549 deletions(-) delete mode 100644 python/sglang/srt/speculative/build_eagle_tree.py create mode 100644 python/sglang/srt/speculative/eagle_utils.py create mode 100644 test/srt/test_build_eagle_tree.py diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py deleted file mode 100644 index fd27f414c..000000000 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ /dev/null @@ -1,427 +0,0 @@ -# NOTE: Please run this file to make sure the test cases are correct. - -import math -from enum import IntEnum -from typing import List, Optional - -import torch - -from sglang.srt.utils import is_cuda, is_hip - -if is_cuda() or is_hip(): - from sgl_kernel import ( - build_tree_kernel_efficient as sgl_build_tree_kernel_efficient, - ) - - -def build_tree_kernel_efficient_preprocess( - verified_id: torch.Tensor, - score_list: List[torch.Tensor], - token_list: List[torch.Tensor], - parents_list: List[torch.Tensor], - num_verify_tokens: int, -): - score_list = torch.cat(score_list, dim=1).flatten( - 1 - ) # b, n, topk; n= 1 + (num_steps-1) * self.topk - ss_token_list = torch.cat( - token_list, dim=1 - ) # b, (self.topk + (num_steps-1) * self.topk) - top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1) - top_scores_index = top_scores.indices - top_scores_index = torch.sort(top_scores_index).values - draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) - draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten() - - if len(parents_list) > 1: - parent_list = torch.cat(parents_list[:-1], dim=1) - else: - batch_size = parents_list[0].shape[0] - parent_list = torch.empty(batch_size, 0, device=parents_list[0].device) - - return parent_list, top_scores_index, draft_tokens - - -class TreeMaskMode(IntEnum): - FULL_MASK = 0 - QLEN_ONLY = 1 - QLEN_ONLY_BITPACKING = 2 - - -def build_tree_kernel_efficient( - verified_id: torch.Tensor, - score_list: List[torch.Tensor], - token_list: List[torch.Tensor], - parents_list: List[torch.Tensor], - seq_lens: torch.Tensor, - seq_lens_sum: int, - topk: int, - spec_steps: int, - num_verify_tokens: int, - tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK, - tree_mask_buf: Optional[torch.Tensor] = None, - position_buf: Optional[torch.Tensor] = None, -): - parent_list, top_scores_index, draft_tokens = ( - build_tree_kernel_efficient_preprocess( - verified_id, - score_list, - token_list, - parents_list, - num_verify_tokens, - ) - ) - - # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens - bs = seq_lens.numel() - device = seq_lens.device - # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened) - # where each row indicates the attending pattern of each draft token - # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed) - if tree_mask_buf is not None: - tree_mask = tree_mask_buf - elif tree_mask_mode == TreeMaskMode.QLEN_ONLY: - tree_mask = torch.full( - (num_verify_tokens * bs * num_verify_tokens,), - True, - dtype=torch.bool, - device=device, - ) - elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING: - packed_dtypes = [torch.uint8, torch.uint16, torch.uint32] - packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8))) - tree_mask = torch.zeros( - (num_verify_tokens * bs,), - dtype=packed_dtypes[packed_dtype_idx], - device=device, - ) - elif tree_mask_mode == TreeMaskMode.FULL_MASK: - tree_mask = torch.full( - ( - seq_lens_sum * num_verify_tokens - + num_verify_tokens * num_verify_tokens * bs, - ), - True, - device=device, - ) - else: - raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}") - - # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel` - retrive_index = torch.full( - (bs, num_verify_tokens), -1, device=device, dtype=torch.long - ) - retrive_next_token = torch.full( - (bs, num_verify_tokens), -1, device=device, dtype=torch.long - ) - retrive_next_sibling = torch.full( - (bs, num_verify_tokens), -1, device=device, dtype=torch.long - ) - # position: where each token belongs to - # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7 - # then, positions = [7, 8, 8, 9] - if position_buf is not None: - positions = position_buf - else: - positions = torch.empty( - (bs * num_verify_tokens,), device=device, dtype=torch.long - ) - - sgl_build_tree_kernel_efficient( - parent_list, - top_scores_index, - seq_lens, - tree_mask, - positions, - retrive_index, - retrive_next_token, - retrive_next_sibling, - topk, - spec_steps, - num_verify_tokens, - tree_mask_mode, - ) - return ( - tree_mask, - positions, - retrive_index, - retrive_next_token, - retrive_next_sibling, - draft_tokens, - ) - - -def test_build_tree_kernel_efficient(): - verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32) - score_list = [ - torch.tensor( - [ - [[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]], - [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]], - ], - dtype=torch.float32, - device="cuda", - ), - torch.tensor( - [ - [ - [6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03], - [2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04], - [2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08], - [1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05], - ], - [ - [8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02], - [2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04], - [2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05], - [1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08], - ], - ], - dtype=torch.float32, - device="cuda", - ), - torch.tensor( - [ - [ - [6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06], - [2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04], - [1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05], - [1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05], - ], - [ - [2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03], - [6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05], - [4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04], - [7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04], - ], - ], - dtype=torch.float32, - device="cuda", - ), - torch.tensor( - [ - [ - [6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06], - [1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03], - [2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05], - [1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05], - ], - [ - [5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04], - [1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03], - [6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03], - [1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04], - ], - ], - dtype=torch.float32, - device="cuda", - ), - ] - token_list = [ - torch.tensor( - [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]], - dtype=torch.int64, - device="cuda", - ), - torch.tensor( - [ - [ - 29889, - 29974, - 29945, - 29900, - 29974, - 29922, - 29930, - 29958, - 29889, - 29974, - 29930, - 29945, - 29974, - 29922, - 29930, - 29958, - ], - [ - 22550, - 4136, - 16492, - 8439, - 29871, - 2, - 3001, - 13, - 2, - 13, - 29906, - 29946, - 2, - 13, - 29871, - 259, - ], - ], - device="cuda", - ), - torch.tensor( - [ - [ - 29946, - 29945, - 29953, - 29906, - 29896, - 29945, - 29900, - 29906, - 29896, - 29945, - 29906, - 29953, - 29896, - 29945, - 29906, - 29946, - ], - [ - 29871, - 2, - 29901, - 29889, - 29871, - 2, - 395, - 259, - 29901, - 29871, - 2, - 29889, - 3001, - 1234, - 7146, - 2186, - ], - ], - device="cuda", - ), - torch.tensor( - [ - [ - 29946, - 29974, - 29945, - 29930, - 29889, - 29922, - 29974, - 29930, - 29974, - 29946, - 29930, - 29922, - 29889, - 29974, - 29945, - 29922, - ], - [ - 29941, - 29906, - 2, - 29946, - 29871, - 450, - 319, - 14990, - 29946, - 29941, - 2, - 29906, - 29871, - 2, - 3001, - 13, - ], - ], - device="cuda", - ), - ] - parents_list = [ - torch.tensor( - [[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]], dtype=torch.int64, device="cuda" - ), - torch.tensor([[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda"), - torch.tensor( - [[20, 24, 21, 28], [24, 28, 20, 21]], dtype=torch.int64, device="cuda" - ), - torch.tensor( - [[36, 40, 41, 44], [36, 40, 44, 45]], dtype=torch.int64, device="cuda" - ), - ] - seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda") - topk = 4 - depth = 4 - num_draft_token = 8 - - ( - tree_mask, - position, - retrive_index, - retrive_next_token, - retrive_next_sibling, - draft_tokens, - ) = build_tree_kernel_efficient( - verified_id=verified_id, - score_list=score_list, - token_list=token_list, - parents_list=parents_list, - seq_lens=seq_lens, - seq_lens_sum=torch.sum(seq_lens).item(), - topk=topk, - spec_steps=depth, - num_verify_tokens=num_draft_token, - ) - - print("=========== build tree kernel efficient ==========") - print(f"{tree_mask=}") - print(f"{position=}") - print(f"{retrive_index=}") - print(f"{retrive_next_token=}") - print(f"{retrive_next_sibling=}") - print(f"{draft_tokens=}") - assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14] - assert retrive_index.tolist() == [ - [0, 1, 2, 3, 4, 5, 6, 7], - [8, 9, 10, 11, 12, 13, 14, 15], - ] - assert retrive_next_token.tolist() == [ - [1, 3, 4, 5, 6, 7, -1, -1], - [1, 2, -1, 6, -1, -1, 7, -1], - ] - assert retrive_next_sibling.tolist() == [ - [-1, 2, -1, -1, -1, -1, -1, -1], - [-1, -1, 3, 4, 5, -1, -1, -1], - ] - assert draft_tokens.tolist() == [ - 29974, - 29896, - 29906, - 29889, - 29974, - 29946, - 29896, - 29946, - 13, - 13, - 22550, - 4136, - 16492, - 8439, - 29871, - 29941, - ] - - -if __name__ == "__main__": - test_build_tree_kernel_efficient() diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index a6d5582c3..b538a4bf8 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -276,11 +276,9 @@ class EAGLEDraftCudaGraphRunner: return graph, out def _postprocess_output_to_raw_bs(self, out, raw_bs): - score_list, token_list, parents_list = out - score_list = [x[:raw_bs] for x in score_list] - token_list = [x[:raw_bs] for x in token_list] - parents_list = [x[:raw_bs] for x in parents_list] - return (score_list, token_list, parents_list) + # Keep the variables name for readability + parent_list, top_scores_index, draft_tokens = (t[:raw_bs] for t in out) + return parent_list, top_scores_index, draft_tokens def replay(self, forward_batch: ForwardBatch): assert forward_batch.out_cache_loc is not None diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index 982343ce9..b068abd4e 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -1,8 +1,7 @@ from __future__ import annotations -import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any import torch import torch.nn.functional as F @@ -19,7 +18,6 @@ from sglang.srt.model_executor.forward_batch_info import ( ) from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import get_global_server_args -from sglang.srt.speculative.build_eagle_tree import TreeMaskMode from sglang.srt.speculative.spec_utils import ( SIMULATE_ACC_LEN, generate_simulated_accept_index, @@ -286,110 +284,6 @@ class EagleVerifyInputV2Mixin: return predict, accept_length, accept_index -def build_tree_kernel_efficient_tmp( - verified_id: torch.Tensor, - parent_list: List[torch.Tensor], - top_scores_index: torch.Tensor, - draft_tokens: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - topk: int, - spec_steps: int, - num_verify_tokens: int, - tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK, - tree_mask_buf: Optional[torch.Tensor] = None, - position_buf: Optional[torch.Tensor] = None, -): - # TODO(lsyin): make it compatible with default code path - # TODO(lsyin): support cuda graph graph padding for eagle - draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten() - - # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens - bs = seq_lens.numel() - device = seq_lens.device - # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened) - # where each row indicates the attending pattern of each draft token - # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed) - if tree_mask_buf is not None: - tree_mask = tree_mask_buf - if tree_mask_mode == TreeMaskMode.QLEN_ONLY: - tree_mask.fill_(True) - elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING: - tree_mask.fill_(0) - elif tree_mask_mode == TreeMaskMode.FULL_MASK: - tree_mask.fill_(True) - else: - raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}") - elif tree_mask_mode == TreeMaskMode.QLEN_ONLY: - tree_mask = torch.full( - (num_verify_tokens * bs * num_verify_tokens,), - True, - dtype=torch.bool, - device=device, - ) - elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING: - packed_dtypes = [torch.uint8, torch.uint16, torch.uint32] - packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8))) - tree_mask = torch.zeros( - (num_verify_tokens * bs,), - dtype=packed_dtypes[packed_dtype_idx], - device=device, - ) - elif tree_mask_mode == TreeMaskMode.FULL_MASK: - tree_mask = torch.full( - ( - seq_lens_sum * num_verify_tokens - + num_verify_tokens * num_verify_tokens * bs, - ), - True, - device=device, - ) - else: - raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}") - - # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel` - retrive_buf = torch.full( - (3, bs, num_verify_tokens), -1, device=device, dtype=torch.long - ) - retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf - # position: where each token belongs to - # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7 - # then, positions = [7, 8, 8, 9] - if position_buf is not None: - positions = position_buf - else: - positions = torch.empty( - (bs * num_verify_tokens,), device=device, dtype=torch.long - ) - - from sgl_kernel import ( - build_tree_kernel_efficient as sgl_build_tree_kernel_efficient, - ) - - sgl_build_tree_kernel_efficient( - parent_list, - top_scores_index, - seq_lens, - tree_mask, - positions, - retrive_index, - retrive_next_token, - retrive_next_sibling, - topk, - spec_steps, - num_verify_tokens, - tree_mask_mode, - ) - return ( - tree_mask, - positions, - retrive_index, - retrive_next_token, - retrive_next_sibling, - draft_tokens, - ) - - @torch.compile(dynamic=True) def select_top_k_tokens_tmp( i: int, diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py new file mode 100644 index 000000000..f3b40c875 --- /dev/null +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -0,0 +1,138 @@ +import math +from enum import IntEnum +from typing import List, Optional + +import torch + +from sglang.srt.utils import is_cuda, is_hip + +if is_cuda() or is_hip(): + from sgl_kernel import ( + build_tree_kernel_efficient as sgl_build_tree_kernel_efficient, + ) + + +def organize_draft_results( + score_list: List[torch.Tensor], + token_list: List[torch.Tensor], + parents_list: List[torch.Tensor], + num_draft_token: int, +): + score_list = torch.cat(score_list, dim=1).flatten(1) + ss_token_list = torch.cat(token_list, dim=1) + top_scores = torch.topk(score_list, num_draft_token - 1, dim=-1) + top_scores_index = top_scores.indices + top_scores_index = torch.sort(top_scores_index).values + draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) + + if len(parents_list) > 1: + parent_list = torch.cat(parents_list[:-1], dim=1) + else: + batch_size = parents_list[0].shape[0] + parent_list = torch.empty(batch_size, 0, device=parents_list[0].device) + + return parent_list, top_scores_index, draft_tokens + + +class TreeMaskMode(IntEnum): + FULL_MASK = 0 + QLEN_ONLY = 1 + QLEN_ONLY_BITPACKING = 2 + + +def build_tree_kernel_efficient( + verified_id: torch.Tensor, + parent_list: List[torch.Tensor], + top_scores_index: torch.Tensor, + draft_tokens: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + topk: int, + spec_steps: int, + num_verify_tokens: int, + tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK, + tree_mask_buf: Optional[torch.Tensor] = None, + position_buf: Optional[torch.Tensor] = None, +): + draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten() + + # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens + bs = seq_lens.numel() + device = seq_lens.device + # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened) + # where each row indicates the attending pattern of each draft token + # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed) + if tree_mask_buf is not None: + tree_mask = tree_mask_buf + if tree_mask_mode == TreeMaskMode.QLEN_ONLY: + tree_mask.fill_(True) + elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING: + tree_mask.fill_(0) + elif tree_mask_mode == TreeMaskMode.FULL_MASK: + tree_mask.fill_(True) + else: + raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}") + elif tree_mask_mode == TreeMaskMode.QLEN_ONLY: + tree_mask = torch.full( + (num_verify_tokens * bs * num_verify_tokens,), + True, + dtype=torch.bool, + device=device, + ) + elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING: + packed_dtypes = [torch.uint8, torch.uint16, torch.uint32] + packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8))) + tree_mask = torch.zeros( + (num_verify_tokens * bs,), + dtype=packed_dtypes[packed_dtype_idx], + device=device, + ) + elif tree_mask_mode == TreeMaskMode.FULL_MASK: + tree_mask = torch.full( + ( + seq_lens_sum * num_verify_tokens + + num_verify_tokens * num_verify_tokens * bs, + ), + True, + device=device, + ) + else: + raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}") + + # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel` + retrive_buf = torch.full( + (3, bs, num_verify_tokens), -1, device=device, dtype=torch.long + ) + retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf + # position: where each token belongs to + # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7 + # then, positions = [7, 8, 8, 9] + if position_buf is not None: + positions = position_buf + else: + positions = torch.empty( + (bs * num_verify_tokens,), device=device, dtype=torch.long + ) + + sgl_build_tree_kernel_efficient( + parent_list, + top_scores_index, + seq_lens, + tree_mask, + positions, + retrive_index, + retrive_next_token, + retrive_next_sibling, + topk, + spec_steps, + num_verify_tokens, + tree_mask_mode, + ) + return ( + tree_mask, + positions, + retrive_index, + retrive_next_token, + retrive_next_sibling, + draft_tokens, + ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index f501f9d8b..a8461c999 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -28,7 +28,6 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, ) from sglang.srt.server_args import ServerArgs, get_global_server_args -from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( EAGLEDraftCudaGraphRunner, ) @@ -40,6 +39,10 @@ from sglang.srt.speculative.eagle_info import ( EagleVerifyInput, EagleVerifyOutput, ) +from sglang.srt.speculative.eagle_utils import ( + build_tree_kernel_efficient, + organize_draft_results, +) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_utils import ( assign_draft_cache_locs, @@ -677,7 +680,7 @@ class EAGLEWorker(TpModelWorker): forward_batch ) if can_cuda_graph: - score_list, token_list, parents_list = self.cuda_graph_runner.replay( + parent_list, top_scores_index, draft_tokens = self.cuda_graph_runner.replay( forward_batch ) else: @@ -686,7 +689,9 @@ class EAGLEWorker(TpModelWorker): # Initialize attention backend self.draft_attn_backend.init_forward_metadata(forward_batch) # Run forward steps - score_list, token_list, parents_list = self.draft_forward(forward_batch) + parent_list, top_scores_index, draft_tokens = self.draft_forward( + forward_batch + ) if batch.forward_mode.is_idle(): return EagleVerifyInput.create_idle_input( @@ -704,9 +709,9 @@ class EAGLEWorker(TpModelWorker): draft_tokens, ) = build_tree_kernel_efficient( spec_info.verified_id, - score_list, - token_list, - parents_list, + parent_list, + top_scores_index, + draft_tokens, batch.seq_lens, batch.seq_lens_sum, self.topk, @@ -795,7 +800,11 @@ class EAGLEWorker(TpModelWorker): topk_index = self.hot_token_id[topk_index] hidden_states = logits_output.hidden_states - return score_list, token_list, parents_list + parent_list, top_scores_index, draft_tokens = organize_draft_results( + score_list, token_list, parents_list, self.speculative_num_draft_tokens + ) + + return parent_list, top_scores_index, draft_tokens def clear_cache_pool(self): self.model_runner.req_to_token_pool.clear() diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index fb01eba53..3ab0784d6 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -12,15 +12,14 @@ from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch from sglang.srt.server_args import ServerArgs -from sglang.srt.speculative.build_eagle_tree import TreeMaskMode from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.eagle_info_v2 import ( assign_extend_cache_locs, - build_tree_kernel_efficient_tmp, fill_accepted_out_cache_loc, fill_new_verified_id, select_top_k_tokens_tmp, ) +from sglang.srt.speculative.eagle_utils import TreeMaskMode, build_tree_kernel_efficient from sglang.srt.speculative.eagle_worker import EAGLEWorker from sglang.srt.utils.common import fast_topk, next_power_of_2 @@ -116,7 +115,7 @@ class EAGLEWorkerV2(EAGLEWorker): retrive_next_token, retrive_next_sibling, draft_tokens, - ) = build_tree_kernel_efficient_tmp( + ) = build_tree_kernel_efficient( draft_input.verified_id, parent_list, top_scores_index, diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 682cf45b8..96289a3df 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -69,6 +69,7 @@ suites = { TestFile("test_chunked_prefill.py", 313), TestFile("test_create_kvindices.py", 2), TestFile("test_deterministic.py", 300), + TestFile("test_build_eagle_tree.py", 8), TestFile("test_eagle_infer_a.py", 370), TestFile("test_eagle_infer_b.py", 700), TestFile("test_eagle_infer_beta.py", 300), diff --git a/test/srt/test_build_eagle_tree.py b/test/srt/test_build_eagle_tree.py new file mode 100644 index 000000000..5372393da --- /dev/null +++ b/test/srt/test_build_eagle_tree.py @@ -0,0 +1,308 @@ +import unittest + +import torch + +from sglang.srt.speculative.eagle_utils import ( + build_tree_kernel_efficient, + organize_draft_results, +) + + +class TestBuildEagleTree(unittest.TestCase): + """Unit tests for build_eagle_tree functionality.""" + + def test_build_tree_kernel_efficient(self): + """Test the build_tree_kernel_efficient function with known inputs and expected outputs.""" + verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32) + score_list = [ + torch.tensor( + [ + [[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]], + [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]], + ], + dtype=torch.float32, + device="cuda", + ), + torch.tensor( + [ + [ + [6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03], + [2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04], + [2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08], + [1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05], + ], + [ + [8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02], + [2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04], + [2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05], + [1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08], + ], + ], + dtype=torch.float32, + device="cuda", + ), + torch.tensor( + [ + [ + [6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06], + [2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04], + [1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05], + [1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05], + ], + [ + [2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03], + [6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05], + [4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04], + [7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04], + ], + ], + dtype=torch.float32, + device="cuda", + ), + torch.tensor( + [ + [ + [6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06], + [1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03], + [2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05], + [1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05], + ], + [ + [5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04], + [1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03], + [6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03], + [1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04], + ], + ], + dtype=torch.float32, + device="cuda", + ), + ] + token_list = [ + torch.tensor( + [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]], + dtype=torch.int64, + device="cuda", + ), + torch.tensor( + [ + [ + 29889, + 29974, + 29945, + 29900, + 29974, + 29922, + 29930, + 29958, + 29889, + 29974, + 29930, + 29945, + 29974, + 29922, + 29930, + 29958, + ], + [ + 22550, + 4136, + 16492, + 8439, + 29871, + 2, + 3001, + 13, + 2, + 13, + 29906, + 29946, + 2, + 13, + 29871, + 259, + ], + ], + device="cuda", + ), + torch.tensor( + [ + [ + 29946, + 29945, + 29953, + 29906, + 29896, + 29945, + 29900, + 29906, + 29896, + 29945, + 29906, + 29953, + 29896, + 29945, + 29906, + 29946, + ], + [ + 29871, + 2, + 29901, + 29889, + 29871, + 2, + 395, + 259, + 29901, + 29871, + 2, + 29889, + 3001, + 1234, + 7146, + 2186, + ], + ], + device="cuda", + ), + torch.tensor( + [ + [ + 29946, + 29974, + 29945, + 29930, + 29889, + 29922, + 29974, + 29930, + 29974, + 29946, + 29930, + 29922, + 29889, + 29974, + 29945, + 29922, + ], + [ + 29941, + 29906, + 2, + 29946, + 29871, + 450, + 319, + 14990, + 29946, + 29941, + 2, + 29906, + 29871, + 2, + 3001, + 13, + ], + ], + device="cuda", + ), + ] + parents_list = [ + torch.tensor( + [[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]], dtype=torch.int64, device="cuda" + ), + torch.tensor( + [[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda" + ), + torch.tensor( + [[20, 24, 21, 28], [24, 28, 20, 21]], dtype=torch.int64, device="cuda" + ), + torch.tensor( + [[36, 40, 41, 44], [36, 40, 44, 45]], dtype=torch.int64, device="cuda" + ), + ] + seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda") + topk = 4 + depth = 4 + num_draft_token = 8 + + parent_list, top_scores_index, draft_tokens = organize_draft_results( + score_list, token_list, parents_list, num_draft_token + ) + + ( + tree_mask, + position, + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + draft_tokens, + ) = build_tree_kernel_efficient( + verified_id=verified_id, + parent_list=parent_list, + top_scores_index=top_scores_index, + draft_tokens=draft_tokens, + seq_lens=seq_lens, + seq_lens_sum=torch.sum(seq_lens).item(), + topk=topk, + spec_steps=depth, + num_verify_tokens=num_draft_token, + ) + + # Verify expected outputs + self.assertEqual( + position.tolist(), + [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14], + "Position tensor does not match expected values", + ) + self.assertEqual( + retrieve_index.tolist(), + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14, 15], + ], + "Retrieve index tensor does not match expected values", + ) + self.assertEqual( + retrieve_next_token.tolist(), + [ + [1, 3, 4, 5, 6, 7, -1, -1], + [1, 2, -1, 6, -1, -1, 7, -1], + ], + "Retrieve next token tensor does not match expected values", + ) + self.assertEqual( + retrieve_next_sibling.tolist(), + [ + [-1, 2, -1, -1, -1, -1, -1, -1], + [-1, -1, 3, 4, 5, -1, -1, -1], + ], + "Retrieve next sibling tensor does not match expected values", + ) + self.assertEqual( + draft_tokens.tolist(), + [ + 29974, + 29896, + 29906, + 29889, + 29974, + 29946, + 29896, + 29946, + 13, + 13, + 22550, + 4136, + 16492, + 8439, + 29871, + 29941, + ], + "Draft tokens tensor does not match expected values", + ) + + +if __name__ == "__main__": + unittest.main()