Support triton kernels v3.4.0 for fused_moe (#8258)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com> Co-authored-by: Cheng Wan <cwan@x.ai> Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -1,21 +1,25 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
|
||||
from typing import Optional
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||
from triton_kernels.matmul_ogs import matmul_ogs
|
||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
||||
|
||||
from sglang.srt.utils import direct_register_custom_op
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
|
||||
def triton_kernel_moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
topk_output: TopKOutput,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -30,9 +34,8 @@ def triton_kernel_moe_forward(
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if not renormalize:
|
||||
gating_output = torch.softmax(gating_output, dim=-1)
|
||||
routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
|
||||
assert topk_output.format.is_triton_kernel()
|
||||
routing_data, gather_idx, scatter_idx = topk_output
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
hidden_states,
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -27,6 +28,7 @@ from sglang.srt.eplb.expert_location_dispatch import (
|
||||
ExpertLocationDispatchInfo,
|
||||
topk_ids_logical_to_physical,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
@@ -37,6 +39,12 @@ from sglang.srt.utils import (
|
||||
is_npu,
|
||||
)
|
||||
|
||||
try:
|
||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
_is_cpu = is_cpu()
|
||||
@@ -58,16 +66,59 @@ if _is_npu:
|
||||
import torch_npu
|
||||
|
||||
|
||||
class TopKOutput(NamedTuple):
|
||||
# -------------------------------- TopKOutput ---------------------------------------
|
||||
|
||||
|
||||
class TopKOutputFormat(Enum):
|
||||
STANDARD = auto()
|
||||
TRITON_KERNEL = auto()
|
||||
|
||||
def is_standard(self) -> bool:
|
||||
return self == TopKOutputFormat.STANDARD
|
||||
|
||||
def is_triton_kernel(self) -> bool:
|
||||
return self == TopKOutputFormat.TRITON_KERNEL
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TopKOutput(Protocol):
|
||||
"""Protocol for top-k outputs in different formats."""
|
||||
|
||||
@property
|
||||
def format(self) -> TopKOutputFormat:
|
||||
"""The format of the output."""
|
||||
...
|
||||
|
||||
|
||||
class StandardTopKOutput(NamedTuple):
|
||||
"""Standard top-k output format."""
|
||||
|
||||
topk_weights: torch.Tensor
|
||||
topk_ids: torch.Tensor
|
||||
router_logits: torch.Tensor
|
||||
|
||||
@property
|
||||
def format(self) -> TopKOutputFormat:
|
||||
return TopKOutputFormat.STANDARD
|
||||
|
||||
|
||||
class TritonKernelTopKOutput(NamedTuple):
|
||||
"""Triton kernel top-k output format."""
|
||||
|
||||
routing_data: RoutingData
|
||||
gather_indx: GatherIndx
|
||||
scatter_indx: ScatterIndx
|
||||
|
||||
@property
|
||||
def format(self) -> TopKOutputFormat:
|
||||
return TopKOutputFormat.TRITON_KERNEL
|
||||
|
||||
|
||||
# -------------------------------- TopK ---------------------------------------
|
||||
|
||||
|
||||
class TopK(CustomOp):
|
||||
|
||||
# TODO(ch-wan): support triton_kernels
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int,
|
||||
@@ -97,6 +148,8 @@ class TopK(CustomOp):
|
||||
self.correction_bias = correction_bias
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
|
||||
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -131,23 +184,29 @@ class TopK(CustomOp):
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
) -> TopKOutput:
|
||||
torch_native = False
|
||||
return select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
torch_native=torch_native,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
if self.use_triton_kernels:
|
||||
routing_data, gather_idx, scatter_idx = routing(
|
||||
router_logits, self.top_k, self.renormalize
|
||||
)
|
||||
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
||||
else:
|
||||
torch_native = False
|
||||
return select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
torch_native=torch_native,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
@@ -217,6 +276,9 @@ class TopK(CustomOp):
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------- TopK implementation -------------------------------------
|
||||
|
||||
|
||||
def fused_topk_torch_native(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
@@ -680,4 +742,4 @@ def select_experts(
|
||||
|
||||
get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
|
||||
|
||||
return TopKOutput(topk_weights, topk_ids, router_logits)
|
||||
return StandardTopKOutput(topk_weights, topk_ids, router_logits)
|
||||
|
||||
@@ -130,6 +130,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
super().__init__()
|
||||
self.use_triton_kernels = use_triton_kernels
|
||||
|
||||
self.triton_kernel_moe_forward = None
|
||||
if torch.cuda.is_available() and has_triton_kernels:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_forward as _tk_forward,
|
||||
)
|
||||
|
||||
self.triton_kernel_moe_forward = _tk_forward
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -229,16 +237,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
) -> torch.Tensor:
|
||||
|
||||
if self.use_triton_kernels:
|
||||
# TODO(ch-wan): re-enable the Triton kernel
|
||||
raise NotImplementedError("The Triton kernel is temporarily disabled.")
|
||||
# return triton_kernel_moe_forward(
|
||||
# hidden_states=x,
|
||||
# w1=layer.w13_weight,
|
||||
# w2=layer.w2_weight,
|
||||
# gating_output=router_logits,
|
||||
# topk=top_k,
|
||||
# renormalize=renormalize,
|
||||
# )
|
||||
return self.triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
)
|
||||
else:
|
||||
if _use_aiter:
|
||||
assert not no_combine, "unsupported"
|
||||
|
||||
Reference in New Issue
Block a user