Optimize topk operation in llama4 (#5128)

fzyzcjy
2025-04-09 17:50:22 +08:00
committed by GitHub
parent 92823069c4
commit 86a876d883
4 changed files with 18 additions and 15 deletions

@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
 logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module):
         topk: int,
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
+        router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
         router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
             hidden_states.dtype
         )
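
For context, a minimal sketch of what a fast_topk helper along these lines could look like, assuming the optimization is a special case for k=1 (Llama 4's router selects a single expert per token), so a max reduction replaces the general top-k selection. The real helper lives in sglang.srt.utils and may differ:

import torch
from typing import Tuple

def fast_topk(values: torch.Tensor, topk: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical sketch of the helper added in this commit.
    if topk == 1:
        # torch.max returns (values, indices), the same contract as
        # torch.topk; keepdim=True preserves the k dimension so callers
        # see the same output shape as torch.topk(..., k=1).
        return torch.max(values, dim=dim, keepdim=True)
    # Fall back to the general top-k kernel for k > 1.
    return torch.topk(values, topk, dim=dim)

With topk=1 on the router logits, torch.max avoids the partial sort that torch.topk performs, which would account for the speedup this commit targets.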