diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index b02319584..14450e9b1 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -49,6 +49,24 @@ SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
 
 TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
 
+# True when `tree_speculative_sampling_target_only` is bound in this module's
+# globals (presumably by a conditional kernel import earlier in the file).
+TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
+# Set once the greedy-fallback warning has been logged.
+_TREE_SPEC_FALLBACK_WARNED = False
+
+
+def _warn_tree_spec_kernel_unavailable():
+    """Log the greedy-fallback warning at most once per process."""
+    global _TREE_SPEC_FALLBACK_WARNED
+    if _TREE_SPEC_FALLBACK_WARNED:
+        return
+    _TREE_SPEC_FALLBACK_WARNED = True
+    logger.warning(
+        "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
+        "Falling back to greedy verification."
+    )
+
 
 @dataclass
 class EagleDraftInput:
@@ -423,7 +441,12 @@ class EagleVerifyInput:
                 logits=logits_output.next_token_logits, vocab_mask=vocab_mask
             )
 
-        # Sample tokens
-        if batch.sampling_info.is_all_greedy:
+        # Sample tokens. Without the tree speculative sampling kernel the
+        # non-greedy path cannot run, so verification is forced to greedy.
+        is_all_greedy = batch.sampling_info.is_all_greedy
+        if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
+            _warn_tree_spec_kernel_unavailable()
+
+        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)