Support triton kernels v3.4.0 for fused_moe (#8258)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Cheng Wan <cwan@x.ai>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
@@ -1,21 +1,25 @@
 # Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
-from typing import Optional
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
 
 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
 from triton_kernels.matmul_ogs import matmul_ogs
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
 
 from sglang.srt.utils import direct_register_custom_op
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -30,9 +34,8 @@ def triton_kernel_moe_forward(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
 
-    if not renormalize:
-        gating_output = torch.softmax(gating_output, dim=-1)
-    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output
 
     return triton_kernel_fused_experts(
         hidden_states,
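The two hunks above modify the kernel wrapper (sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py, path inferred from the imports in the layer.py hunks further down): routing is no longer recomputed from gating_output on every call; the wrapper now receives a precomputed TopKOutput and only validates and unpacks it. A minimal sketch of the new call path, assuming the triton_kernels wheel and a CUDA device, with illustrative shapes and dtype; TritonKernelTopKOutput comes from the topk.py hunks that follow:

    import torch
    from triton_kernels.routing import routing

    from sglang.srt.layers.moe.topk import TritonKernelTopKOutput

    num_tokens, num_experts, top_k = 16, 8, 2
    renormalize = True
    router_logits = torch.randn(
        num_tokens, num_experts, device="cuda", dtype=torch.bfloat16
    )

    # Mirrors TopK.forward_native below: route once, then hand the result around.
    routing_data, gather_idx, scatter_idx = routing(router_logits, top_k, renormalize)
    topk_output = TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)

    assert topk_output.format.is_triton_kernel()  # the wrapper's new precondition
    # triton_kernel_moe_forward(hidden_states, w1, w2, topk_output, ...) unpacks it
    # and calls triton_kernel_fused_experts.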
@@ -15,7 +15,8 @@
 from __future__ import annotations
 
 import math
-from typing import Callable, NamedTuple, Optional
+from enum import Enum, auto
+from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
 
 import torch
 import torch.nn.functional as F
@@ -27,6 +28,7 @@ from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -37,6 +39,12 @@ from sglang.srt.utils import (
     is_npu,
 )
 
+try:
+    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+except ImportError:
+    pass
+
+
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
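This hunk (in sglang/srt/layers/moe/topk.py, path inferred from the TYPE_CHECKING import above) guards the triton_kernels import so it stays an optional dependency: when the wheel is absent the names are simply undefined and the standard path never touches them. One common way to derive an availability flag from such a guard is sketched below; the has_triton_kernels name appears in the layer.py hunks further down, and how sglang actually defines it may differ:

    import importlib.util

    # Assumption for illustration: detect the optional wheel without importing it.
    has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

    if has_triton_kernels:
        from triton_kernels.routing import routing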
@@ -58,16 +66,59 @@ if _is_npu:
     import torch_npu
 
 
-class TopKOutput(NamedTuple):
+# -------------------------------- TopKOutput ---------------------------------------
+
+
+class TopKOutputFormat(Enum):
+    STANDARD = auto()
+    TRITON_KERNEL = auto()
+
+    def is_standard(self) -> bool:
+        return self == TopKOutputFormat.STANDARD
+
+    def is_triton_kernel(self) -> bool:
+        return self == TopKOutputFormat.TRITON_KERNEL
+
+
+@runtime_checkable
+class TopKOutput(Protocol):
+    """Protocol for top-k outputs in different formats."""
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        """The format of the output."""
+        ...
+
+
+class StandardTopKOutput(NamedTuple):
+    """Standard top-k output format."""
+
     topk_weights: torch.Tensor
     topk_ids: torch.Tensor
     router_logits: torch.Tensor
 
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.STANDARD
+
+
+class TritonKernelTopKOutput(NamedTuple):
+    """Triton kernel top-k output format."""
+
+    routing_data: RoutingData
+    gather_indx: GatherIndx
+    scatter_indx: ScatterIndx
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.TRITON_KERNEL
+
+
+# -------------------------------- TopK ---------------------------------------
+
 
 class TopK(CustomOp):
 
-    # TODO(ch-wan): support triton_kernels
-
     def __init__(
         self,
         top_k: int,
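The hunk above replaces the single TopKOutput NamedTuple with a small format taxonomy: an Enum, a runtime_checkable Protocol, and two concrete NamedTuple formats. Downstream code can branch on .format without importing any triton_kernels types. A short, self-contained usage sketch:

    import torch

    from sglang.srt.layers.moe.topk import (
        StandardTopKOutput,
        TopKOutput,
        TopKOutputFormat,
    )

    out = StandardTopKOutput(
        topk_weights=torch.rand(4, 2),
        topk_ids=torch.randint(0, 8, (4, 2)),
        router_logits=torch.randn(4, 8),
    )

    assert isinstance(out, TopKOutput)  # satisfies the runtime_checkable Protocol
    assert out.format is TopKOutputFormat.STANDARD and out.format.is_standard()

    topk_weights, topk_ids, router_logits = out  # NamedTuple unpacking still works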
@@ -97,6 +148,8 @@ class TopK(CustomOp):
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor
 
+        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -131,23 +184,29 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        torch_native = False
-        return select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            custom_routing_function=self.custom_routing_function,
-            correction_bias=self.correction_bias,
-            torch_native=torch_native,
-            routed_scaling_factor=self.routed_scaling_factor,
-            num_token_non_padded=num_token_non_padded,
-            expert_location_dispatch_info=expert_location_dispatch_info,
-        )
+        if self.use_triton_kernels:
+            routing_data, gather_idx, scatter_idx = routing(
+                router_logits, self.top_k, self.renormalize
+            )
+            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
+        else:
+            torch_native = False
+            return select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=self.use_grouped_topk,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                custom_routing_function=self.custom_routing_function,
+                correction_bias=self.correction_bias,
+                torch_native=torch_native,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
 
     def forward_cpu(
         self,
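forward_native now dispatches on the server flag read in __init__: with triton kernels enabled it routes once via triton_kernels.routing and returns a TritonKernelTopKOutput; otherwise it falls through to select_experts as before. A sketch of what a caller sees; the constructor and call signatures are abbreviated, and the --enable-triton-kernel-moe spelling of the server flag is inferred from the enable_triton_kernel_moe key:

    import torch

    from sglang.srt.layers.moe.topk import TopK

    # Assumes the server was launched with the triton-kernel MoE flag enabled, so
    # global_server_args_dict["enable_triton_kernel_moe"] is True.
    topk = TopK(top_k=2, renormalize=True)

    hidden_states = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
    router_logits = torch.randn(16, 8, device="cuda", dtype=torch.bfloat16)

    topk_output = topk(hidden_states, router_logits)
    if topk_output.format.is_triton_kernel():
        routing_data, gather_indx, scatter_indx = topk_output  # consumed by matmul_ogs
    else:
        topk_weights, topk_ids, _ = topk_output  # classic select_experts result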
@@ -217,6 +276,9 @@ class TopK(CustomOp):
         )
 
 
+# ------------------------------- TopK implementation -------------------------------------
+
+
 def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -680,4 +742,4 @@ def select_experts(
 
     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
 
-    return TopKOutput(topk_weights, topk_ids, router_logits)
+    return StandardTopKOutput(topk_weights, topk_ids, router_logits)
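select_experts keeps its behavior; only the return type name changes. Because StandardTopKOutput is still a NamedTuple with the same three fields, callers that unpack positionally are unaffected by the rename. A quick compat check as a sketch:

    import torch

    from sglang.srt.layers.moe.topk import StandardTopKOutput

    out = StandardTopKOutput(
        topk_weights=torch.rand(4, 2),
        topk_ids=torch.randint(0, 8, (4, 2)),
        router_logits=torch.randn(4, 8),
    )
    w, ids, logits = out  # unpacks exactly like the old TopKOutput NamedTuple
    assert out.format.is_standard()

The remaining hunks are in the unquantized fused-MoE method (sglang/srt/layers/moe/fused_moe_triton/layer.py, inferred from the class name in the hunk headers).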
@@ -130,6 +130,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels
 
+        self.triton_kernel_moe_forward = None
+        if torch.cuda.is_available() and has_triton_kernels:
+            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                triton_kernel_moe_forward as _tk_forward,
+            )
+
+            self.triton_kernel_moe_forward = _tk_forward
+
     def create_weights(
         self,
         layer: torch.nn.Module,
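Resolving the optional kernel entry point once in __init__, rather than importing inside apply, keeps the hot path free of import machinery; apply only checks an attribute. The same pattern in isolation, as a sketch rather than the real class (the gating condition is simplified from the hunk above):

    class _MoEMethodSketch:
        """Illustration of init-time binding; not the real UnquantizedFusedMoEMethod."""

        def __init__(self, use_triton_kernels: bool, has_triton_kernels: bool) -> None:
            self.use_triton_kernels = use_triton_kernels
            self.triton_kernel_moe_forward = None
            if use_triton_kernels and has_triton_kernels:
                from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                    triton_kernel_moe_forward as _tk_forward,
                )

                # Bound once; later calls pay no import cost.
                self.triton_kernel_moe_forward = _tk_forward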
@@ -229,16 +237,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
 
         if self.use_triton_kernels:
-            # TODO(ch-wan): re-enable the Triton kernel
-            raise NotImplementedError("The Triton kernel is temporarily disabled.")
-            # return triton_kernel_moe_forward(
-            #     hidden_states=x,
-            #     w1=layer.w13_weight,
-            #     w2=layer.w2_weight,
-            #     gating_output=router_logits,
-            #     topk=top_k,
-            #     renormalize=renormalize,
-            # )
+            return self.triton_kernel_moe_forward(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_output=topk_output,
+            )
         else:
             if _use_aiter:
                 assert not no_combine, "unsupported"
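End to end, the restored fast path now reads as follows (a flow sketch with placeholder names; apply's full signature and the surrounding FusedMoE plumbing are abbreviated):

    # 1) TopK (topk.py) routes once with triton_kernels.routing and wraps the result.
    topk_output = topk(hidden_states, router_logits)  # TritonKernelTopKOutput

    # 2) UnquantizedFusedMoEMethod.apply (layer.py) passes it through untouched.
    out = method.triton_kernel_moe_forward(
        hidden_states=hidden_states,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_output=topk_output,
    )

    # 3) triton_kernel_moe_forward (triton_kernels_moe.py) asserts the triton-kernel
    #    format, unpacks RoutingData / GatherIndx / ScatterIndx, and runs the
    #    matmul_ogs-based triton_kernel_fused_experts.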