diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
index 57b7f20f0..d0f90f2d8 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
@@ -1,21 +1,25 @@
 # Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
-from typing import Optional
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
 
 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
 from triton_kernels.matmul_ogs import matmul_ogs
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
 
 from sglang.srt.utils import direct_register_custom_op
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -30,9 +34,8 @@ def triton_kernel_moe_forward(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
 
-    if not renormalize:
-        gating_output = torch.softmax(gating_output, dim=-1)
-    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output
 
     return triton_kernel_fused_experts(
         hidden_states,
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 253c269b6..475066a1c 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -15,7 +15,8 @@
 from __future__ import annotations
 
 import math
-from typing import Callable, NamedTuple, Optional
+from enum import Enum, auto
+from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
 
 import torch
 import torch.nn.functional as F
@@ -27,6 +28,7 @@ from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -37,6 +39,12 @@ from sglang.srt.utils import (
     is_npu,
 )
 
+try:
+    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+except ImportError:
+    pass
+
+
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
@@ -58,16 +66,59 @@
 if _is_npu:
     import torch_npu
 
 
-class TopKOutput(NamedTuple):
+# -------------------------------- TopKOutput ---------------------------------------
+
+
+class TopKOutputFormat(Enum):
+    STANDARD = auto()
+    TRITON_KERNEL = auto()
+
+    def is_standard(self) -> bool:
+        return self == TopKOutputFormat.STANDARD
+
+    def is_triton_kernel(self) -> bool:
+        return self == TopKOutputFormat.TRITON_KERNEL
+
+
+@runtime_checkable
+class TopKOutput(Protocol):
+    """Protocol for top-k outputs in different formats."""
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        """The format of the output."""
+        ...
+
+
+class StandardTopKOutput(NamedTuple):
+    """Standard top-k output format."""
+
     topk_weights: torch.Tensor
     topk_ids: torch.Tensor
     router_logits: torch.Tensor
 
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.STANDARD
+
+
+class TritonKernelTopKOutput(NamedTuple):
+    """Triton kernel top-k output format."""
+
+    routing_data: RoutingData
+    gather_indx: GatherIndx
+    scatter_indx: ScatterIndx
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.TRITON_KERNEL
+
+
+# -------------------------------- TopK ---------------------------------------
+
 
 class TopK(CustomOp):
-    # TODO(ch-wan): support triton_kernels
-
     def __init__(
         self,
         top_k: int,
@@ -97,6 +148,8 @@ class TopK(CustomOp):
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor
 
+        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -131,23 +184,29 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        torch_native = False
-        return select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            custom_routing_function=self.custom_routing_function,
-            correction_bias=self.correction_bias,
-            torch_native=torch_native,
-            routed_scaling_factor=self.routed_scaling_factor,
-            num_token_non_padded=num_token_non_padded,
-            expert_location_dispatch_info=expert_location_dispatch_info,
-        )
+        if self.use_triton_kernels:
+            routing_data, gather_idx, scatter_idx = routing(
+                router_logits, self.top_k, self.renormalize
+            )
+            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
+        else:
+            torch_native = False
+            return select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=self.use_grouped_topk,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                custom_routing_function=self.custom_routing_function,
+                correction_bias=self.correction_bias,
+                torch_native=torch_native,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
 
     def forward_cpu(
         self,
@@ -217,6 +276,9 @@ class TopK(CustomOp):
         )
 
 
+# ------------------------------- TopK implementation -------------------------------------
+
+
 def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -680,4 +742,4 @@ def select_experts(
 
     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
 
-    return TopKOutput(topk_weights, topk_ids, router_logits)
+    return StandardTopKOutput(topk_weights, topk_ids, router_logits)
diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py
index 121d5b714..a307fcc11 100644
--- a/python/sglang/srt/layers/quantization/unquant.py
+++ b/python/sglang/srt/layers/quantization/unquant.py
@@ -130,6 +130,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels
 
+        self.triton_kernel_moe_forward = None
+        if torch.cuda.is_available() and has_triton_kernels:
+            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                triton_kernel_moe_forward as _tk_forward,
+            )
+
+            self.triton_kernel_moe_forward = _tk_forward
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -229,16 +237,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
 
         if self.use_triton_kernels:
-            # TODO(ch-wan): re-enable the Triton kernel
-            raise NotImplementedError("The Triton kernel is temporarily disabled.")
-            # return triton_kernel_moe_forward(
-            #     hidden_states=x,
-            #     w1=layer.w13_weight,
-            #     w2=layer.w2_weight,
-            #     gating_output=router_logits,
-            #     topk=top_k,
-            #     renormalize=renormalize,
-            # )
+            return self.triton_kernel_moe_forward(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_output=topk_output,
+            )
         else:
             if _use_aiter:
                 assert not no_combine, "unsupported"
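
Note: below is a minimal, self-contained sketch of the format-dispatch pattern this patch introduces. It is illustrative only: the class and method names mirror the diff, but the tensors, routing math, and MoE kernels are replaced with trivial placeholders (plain lists and a string-returning stub).

# Illustrative sketch (not part of the patch): callers branch on the TopKOutput
# protocol's `format` property instead of passing raw router logits around.
from enum import Enum, auto
from typing import NamedTuple, Protocol, runtime_checkable


class TopKOutputFormat(Enum):
    STANDARD = auto()
    TRITON_KERNEL = auto()

    def is_triton_kernel(self) -> bool:
        return self == TopKOutputFormat.TRITON_KERNEL


@runtime_checkable
class TopKOutput(Protocol):
    @property
    def format(self) -> TopKOutputFormat: ...


class StandardTopKOutput(NamedTuple):
    # Placeholder fields; the real class carries torch.Tensor weights/ids/logits.
    topk_weights: list
    topk_ids: list

    @property
    def format(self) -> TopKOutputFormat:
        return TopKOutputFormat.STANDARD


def fused_experts(topk_output: TopKOutput) -> str:
    # Each backend asserts or branches on the format it supports.
    if topk_output.format.is_triton_kernel():
        return "triton kernel path"
    return "standard path"


print(fused_experts(StandardTopKOutput([0.7, 0.3], [1, 4])))  # -> standard path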