From fad315cb8e6a52c60b60cfef74ee70ec9fb8c3ae Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 9 Feb 2025 07:28:34 +0800 Subject: [PATCH] fix EAGLE 2 non greedy case (#3407) Co-authored-by: Ying Sheng --- ...nchmark_vllm_vs_sglang_fused_moe_triton.py | 4 +- .../srt/model_executor/cuda_graph_runner.py | 3 + python/sglang/srt/speculative/eagle_utils.py | 85 ++++++++++++++----- python/sglang/srt/speculative/eagle_worker.py | 1 + 4 files changed, 71 insertions(+), 22 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py index 4edb2dff8..6a4605eb5 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py @@ -54,7 +54,9 @@ def get_model_config(model_name: str, tp_size: int): ): block_shape = config.quantization_config["weight_block_size"] assert len(block_shape) == 2 - assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1" + assert ( + vllm_version_num >= 66 + ), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1" shape_configs = { "num_experts": E, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 1f5e8e851..978a772f1 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -462,8 +462,11 @@ class CudaGraphRunner: ), positions=None, retrive_index=None, + retrive_next_token=None, + retrive_next_sibling=None, retrive_cum_len=None, draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + spec_steps=self.model_runner.server_args.speculative_num_steps, capture_hidden_mode=CaptureHiddenMode.FULL, ) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 099c71cfb..7ea1ea9b8 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -4,6 +4,7 @@ import dataclasses from typing import TYPE_CHECKING, List import torch +import torch.nn.functional as F import triton import triton.language as tl @@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import ( create_flashinfer_kv_indices_triton, ) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode -from sglang.srt.speculative.build_eagle_tree import build_tree_kernel +from sglang.srt.speculative.build_eagle_tree import ( + build_tree_kernel, + build_tree_kernel_efficient, +) +from sglang.srt.utils import is_cuda_available + +if is_cuda_available(): + from sgl_kernel import tree_speculative_sampling_target_only if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -160,8 +168,11 @@ class EagleVerifyInput: custom_mask: torch.Tensor positions: torch.Tensor retrive_index: torch.Tensor + retrive_next_token: torch.Tensor + retrive_next_sibling: torch.Tensor retrive_cum_len: torch.Tensor draft_token_num: int + spec_steps: int capture_hidden_mode: CaptureHiddenMode @classmethod @@ -175,10 +186,45 @@ class EagleVerifyInput: seq_lens_sum: int, topk: int, spec_steps: int, - num_verify_token: int, + num_verify_tokens: int, + is_all_greedy: bool, ): - tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = ( - build_tree_kernel( + if is_all_greedy: + tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = ( + build_tree_kernel( + verified_id, + score_list, # b, n, topk; n= 1 + (num_steps-1) * self.topk + token_list, + parents_list, + seq_lens, + seq_lens_sum, + topk, + spec_steps, + num_verify_tokens, + ) + ) + + return cls( + draft_tokens, + tree_mask, + position, + retrive_index, + None, + None, + retrive_cum_len, + num_verify_tokens, + spec_steps, + CaptureHiddenMode.FULL, + ) + else: + ( + tree_mask, + position, + retrive_index, + retrive_next_token, + retrive_next_sibling, + draft_tokens, + ) = build_tree_kernel_efficient( verified_id, score_list, token_list, @@ -187,18 +233,21 @@ class EagleVerifyInput: seq_lens_sum, topk, spec_steps, - num_verify_token, + num_verify_tokens, + ) + + return cls( + draft_tokens, + tree_mask, + position, + retrive_index, + retrive_next_token, + retrive_next_sibling, + None, + num_verify_tokens, + spec_steps, + CaptureHiddenMode.FULL, ) - ) - return cls( - draft_tokens, - tree_mask, - position, - retrive_index, - retrive_cum_len, - num_verify_token, - CaptureHiddenMode.FULL, - ) def prepare_for_verify(self, batch: ScheduleBatch): batch.input_ids = self.draft_token @@ -313,12 +362,6 @@ class EagleVerifyInput: uniform_samples=coins, target_probs=target_probs, draft_probs=draft_probs, - threshold_single=global_server_args_dict[ - "speculative_accept_threshold_single" - ], - threshold_acc=global_server_args_dict[ - "speculative_accept_threshold_acc" - ], deterministic=True, ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 6d84cc305..c640be8c6 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -185,6 +185,7 @@ class EAGLEWorker(TpModelWorker): self.topk, self.speculative_num_steps, self.server_args.speculative_num_draft_tokens, + batch.sampling_info.is_all_greedy, ) # Free cache locations