Optimize topk operation in llama4 (#5128)
@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module):
         topk: int,
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
+        router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
         router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
             hidden_states.dtype
         )
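For context on why the swap helps: torch.topk runs a selection kernel even when only a single winner is needed, while torch.max is a plain reduction, and the Llama4 router picks one expert per token, so k == 1 is the hot path at this call site. A minimal sketch of the kind of helper fast_topk is, assuming it lives in sglang/srt/utils.py and simply special-cases k == 1 (illustrative, not a verbatim copy of the sglang source):

import torch

def fast_topk(values: torch.Tensor, topk: int, dim: int):
    # Drop-in replacement for torch.topk at this call site: returns a
    # (values, indices) pair, matching torch.topk's return convention.
    if topk == 1:
        # torch.max with keepdim=True yields the same shapes torch.topk
        # would produce for k == 1, but avoids the full top-k kernel.
        return torch.max(values, dim=dim, keepdim=True)
    # Fall back to the general kernel for k > 1.
    return torch.topk(values, topk, dim=dim)

Because the helper mirrors torch.topk's return convention, the caller only needs the one-line change shown in the diff above.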