From 110a65989b8dc5d08a5ec15b60a0042af57500cb Mon Sep 17 00:00:00 2001 From: datdo-msft <131494842+datdo-msft@users.noreply.github.com> Date: Fri, 22 Aug 2025 11:14:43 -0700 Subject: [PATCH] [MTP] Force greedy sampling on AMD (#9127) --- python/sglang/srt/speculative/eagle_utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index b02319584..14450e9b1 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -49,6 +49,8 @@ SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial") TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly +TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals() + @dataclass class EagleDraftInput: @@ -423,8 +425,15 @@ class EagleVerifyInput: logits=logits_output.next_token_logits, vocab_mask=vocab_mask ) - # Sample tokens - if batch.sampling_info.is_all_greedy: + # Sample tokens. Force greedy sampling on AMD + is_all_greedy = sampling_info.is_all_greedy + if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE): + logger.warning( + "Tree speculative sampling kernel unavailable (likely AMD/HIP build). " + "Falling back to greedy verification." + ) + + if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE: target_predict = torch.argmax(logits_output.next_token_logits, dim=-1) target_predict = target_predict.reshape(bs, self.draft_token_num)