diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index b1032c83b..099c71cfb 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -177,29 +177,21 @@ class EagleVerifyInput: spec_steps: int, num_verify_token: int, ): - score_list = torch.cat(score_list, dim=1).flatten( - 1 - ) # b, n, topk; n= 1 + (num_steps-1) * self.topk - ss_token_list = torch.cat( - token_list, dim=1 - ) # b, (self.topk + (num_steps-1) * self.topk) - top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1) - top_scores_index = top_scores.indices - top_scores_index = torch.sort(top_scores_index).values - draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) - draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1) - parent_list = torch.cat(parents_list[:-1], dim=1) - tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( - parent_list, - top_scores_index, - seq_lens, - seq_lens_sum, - topk, - spec_steps, - num_verify_token, + tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = ( + build_tree_kernel( + verified_id, + score_list, + token_list, + parents_list, + seq_lens, + seq_lens_sum, + topk, + spec_steps, + num_verify_token, + ) ) return cls( - draft_tokens.flatten(), + draft_tokens, tree_mask, position, retrive_index,