diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py
index 4e4ba9a1e..8015c18a0 100644
--- a/python/sglang/srt/models/llama4.py
+++ b/python/sglang/srt/models/llama4.py
@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers

 logger = logging.getLogger(__name__)

@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module):
         topk: int,
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
+        router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
         router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
             hidden_states.dtype
         )
diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index 19fa1807c..10c9e54c2 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
+from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2

 if is_cuda_available():
     from sgl_kernel import (
@@ -772,16 +772,6 @@ def select_top_k_tokens(
     return input_ids, hidden_states, scores, tree_info


-def fast_topk(values, topk, dim):
-    if topk == 1:
-        # Use max along the specified dimension to get both value and index
-        max_value, max_index = torch.max(values, dim=dim)
-        return max_value.unsqueeze(1), max_index.unsqueeze(1)
-    else:
-        # Use topk for efficiency with larger k values
-        return torch.topk(values, topk, dim=dim)
-
-
 def _generate_simulated_accept_index(
     accept_index,
     predict,
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 514603424..9967cf6ac 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
-    fast_topk,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+from sglang.srt.utils import (
+    empty_context,
+    fast_topk,
+    get_available_gpu_memory,
+    is_cuda_available,
+)

 if is_cuda_available():
     from sgl_kernel import segment_packbits
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index b43fe4273..d68fa489b 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1819,3 +1819,12 @@ class DeepEPMode(Enum):
             return DeepEPMode.low_latency
         else:
             return DeepEPMode.normal
+
+
+def fast_topk(values, topk, dim):
+    if topk == 1:
+        # Use max along the specified dimension to get both value and index
+        return torch.max(values, dim=dim, keepdim=True)
+    else:
+        # Use topk for efficiency with larger k values
+        return torch.topk(values, topk, dim=dim)