From 53a525bf33564fc164365eb7eab5d5e3a8b061df Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 16 Jun 2025 07:25:59 -0700 Subject: [PATCH] [Eagle] Fix kernel call after updating speculative sampling kernels (#7231) --- docker/Dockerfile.blackwell | 2 +- python/pyproject.toml | 2 +- python/sglang/srt/entrypoints/engine.py | 2 +- .../srt/speculative/build_eagle_tree.py | 2 +- python/sglang/srt/speculative/eagle_utils.py | 33 +++++++------------ python/sglang/srt/speculative/eagle_worker.py | 2 +- test/srt/test_fa3.py | 14 ++++---- 7 files changed, 24 insertions(+), 33 deletions(-) diff --git a/docker/Dockerfile.blackwell b/docker/Dockerfile.blackwell index 159a814d1..889f85da1 100644 --- a/docker/Dockerfile.blackwell +++ b/docker/Dockerfile.blackwell @@ -20,7 +20,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ RUN pip3 install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128 --break-system-packages -RUN pip3 install https://github.com/sgl-project/whl/releases/download/v0.1.8.post2/sgl_kernel-0.1.8.post2+cu128-cp39-abi3-manylinux2014_x86_64.whl --break-system-packages \ +RUN pip3 install https://github.com/sgl-project/whl/releases/download/v0.1.9/sgl_kernel-0.1.9+cu128-cp39-abi3-manylinux2014_x86_64.whl --break-system-packages \ && pip3 install setuptools==75.0.0 wheel scikit-build-core --break-system-packages RUN git clone --depth=1 https://github.com/sgl-project/sglang.git \ diff --git a/python/pyproject.toml b/python/pyproject.toml index c8a8ffc4a..fee84f015 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -49,7 +49,7 @@ runtime_common = [ srt = [ "sglang[runtime_common]", - "sgl-kernel==0.1.8.post2", + "sgl-kernel==0.1.9", "flashinfer_python==0.2.6.post1", "torch==2.7.1", "torchaudio==2.7.1", diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index e53ad1a3b..45e159d63 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -605,7 +605,7 @@ def _set_envs_and_config(server_args: ServerArgs): if _is_cuda: assert_pkg_version( "sgl-kernel", - "0.1.8.post2", + "0.1.9", "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", ) diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index 1f0e0fcb7..c6b853cc6 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -92,7 +92,7 @@ def build_tree_kernel_efficient( sgl_build_tree_kernel_efficient( parent_list, top_scores_index, - seq_lens.to(torch.int32), + seq_lens, tree_mask, positions, retrive_index, diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index cedc2ee88..2657d8351 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -23,7 +23,7 @@ from sglang.srt.managers.schedule_batch import ( ) from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode -from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2 +from sglang.srt.utils import is_cuda, is_hip, next_power_of_2 if is_cuda(): from sgl_kernel import ( @@ -32,6 +32,7 @@ if is_cuda(): tree_speculative_sampling_target_only, verify_tree_greedy, ) + from sgl_kernel.top_k import fast_topk elif is_hip(): from sgl_kernel import verify_tree_greedy @@ -327,11 +328,11 @@ class EagleVerifyInput: predicts=predict, # mutable accept_index=accept_index, # mutable accept_token_num=accept_length, # mutable - candidates=candidates.to(torch.int32), - retrive_index=self.retrive_index.to(torch.int32), - retrive_next_token=self.retrive_next_token.to(torch.int32), - retrive_next_sibling=self.retrive_next_sibling.to(torch.int32), - target_predict=target_predict.to(torch.int32), + candidates=candidates, + retrive_index=self.retrive_index, + retrive_next_token=self.retrive_next_token, + retrive_next_sibling=self.retrive_next_sibling, + target_predict=target_predict, ) else: # apply temperature and get target probs @@ -370,12 +371,12 @@ class EagleVerifyInput: predicts=predict, # mutable accept_index=accept_index, # mutable accept_token_num=accept_length, # mutable - candidates=candidates.to(torch.int32), - retrive_index=self.retrive_index.to(torch.int32), - retrive_next_token=self.retrive_next_token.to(torch.int32), - retrive_next_sibling=self.retrive_next_sibling.to(torch.int32), + candidates=candidates, + retrive_index=self.retrive_index, + retrive_next_token=self.retrive_next_token, + retrive_next_sibling=self.retrive_next_sibling, uniform_samples=coins, - # uniform_samples_for_final_sampling=coins_for_final_sampling, + uniform_samples_for_final_sampling=coins_for_final_sampling, target_probs=target_probs, draft_probs=draft_probs, threshold_single=global_server_args_dict[ @@ -1005,16 +1006,6 @@ def select_top_k_tokens( return input_ids, hidden_states, scores, tree_info -def fast_topk_torch(values, topk, dim): - if topk == 1: - # Use max along the specified dimension to get both value and index - max_value, max_index = torch.max(values, dim=dim) - return max_value.unsqueeze(1), max_index.unsqueeze(1) - else: - # Use topk for efficiency with larger k values - return torch.topk(values, topk, dim=dim) - - def _generate_simulated_accept_index( accept_index, predict, diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index e42515dce..83bea359c 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -828,7 +828,7 @@ def load_token_map(token_map_path: str) -> List[int]: ) token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path)) hot_token_id = torch.load(token_map_path, weights_only=True) - return torch.tensor(hot_token_id, dtype=torch.int32) + return torch.tensor(hot_token_id, dtype=torch.int64) @torch.compile(dynamic=True) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index c43196571..45ad87e7d 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -143,7 +143,7 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): args.extend( [ "--cuda-graph-max-bs", - "2", + "4", "--speculative-algorithm", "EAGLE3", "--speculative-draft", @@ -169,7 +169,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest): model = DEFAULT_MODEL_NAME_FOR_TEST accuracy_threshold = 0.65 speculative_decode = True - spec_decode_threshold = 1.5 + spec_decode_threshold = 1.6 @classmethod def get_server_args(cls): @@ -177,7 +177,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest): args.extend( [ "--cuda-graph-max-bs", - "2", + "4", "--speculative-algorithm", "EAGLE3", "--speculative-draft", @@ -201,7 +201,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest): model = DEFAULT_MODEL_NAME_FOR_TEST_MLA accuracy_threshold = 0.60 speculative_decode = True - spec_decode_threshold = 1.5 + spec_decode_threshold = 2.5 @classmethod def get_server_args(cls): @@ -209,7 +209,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest): args.extend( [ "--cuda-graph-max-bs", - "2", + "4", "--speculative-algorithm", "EAGLE", "--speculative-draft", @@ -233,7 +233,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest): model = DEFAULT_MODEL_NAME_FOR_TEST_MLA accuracy_threshold = 0.60 speculative_decode = True - spec_decode_threshold = 1.5 + spec_decode_threshold = 2.95 @classmethod def get_server_args(cls): @@ -241,7 +241,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest): args.extend( [ "--cuda-graph-max-bs", - "2", + "4", "--speculative-algorithm", "EAGLE", "--speculative-draft",