Integrate triton moe kernel (#7689)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -1737,6 +1737,7 @@ def fused_moe(
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
@@ -1822,6 +1823,7 @@ def fused_moe(
|
||||
topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
||||
|
||||
import importlib
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
@@ -19,6 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
@@ -29,8 +31,15 @@ from sglang.srt.utils import (
|
||||
use_intel_amx_backend,
|
||||
)
|
||||
|
||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
if has_triton_kernels:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
else:
|
||||
fused_experts = None # type: ignore
|
||||
|
||||
@@ -87,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def __init__(self, use_triton_kernels: bool = False):
|
||||
super().__init__()
|
||||
self.use_triton_kernels = use_triton_kernels
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -97,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
|
||||
if self.use_triton_kernels:
|
||||
w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
|
||||
),
|
||||
torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight_n, w2_weight_k = (
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
)
|
||||
if self.use_triton_kernels:
|
||||
w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, hidden_size, intermediate_size, dtype=params_dtype
|
||||
),
|
||||
torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
@@ -192,59 +210,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if _use_aiter:
|
||||
assert not no_combine, "unsupported"
|
||||
if apply_router_weight_on_input:
|
||||
assert (
|
||||
topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
x = x * topk_weights.to(x.dtype)
|
||||
topk_weights = torch.ones_like(
|
||||
topk_weights, dtype=torch.float32
|
||||
) # topk_weights must be FP32 (float32)
|
||||
|
||||
return fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=(
|
||||
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
||||
),
|
||||
)
|
||||
else:
|
||||
return fused_experts(
|
||||
if self.use_triton_kernels:
|
||||
return triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
no_combine=no_combine,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if _use_aiter:
|
||||
assert not no_combine, "unsupported"
|
||||
if apply_router_weight_on_input:
|
||||
assert (
|
||||
topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
x = x * topk_weights.to(x.dtype)
|
||||
topk_weights = torch.ones_like(
|
||||
topk_weights, dtype=torch.float32
|
||||
) # topk_weights must be FP32 (float32)
|
||||
|
||||
return fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=(
|
||||
ActivationType.Silu
|
||||
if activation == "silu"
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
)
|
||||
else:
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -475,9 +506,13 @@ class FusedMoE(torch.nn.Module):
|
||||
self.inplace = inplace
|
||||
self.no_combine = no_combine
|
||||
|
||||
self.use_triton_kernels = (
|
||||
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
|
||||
)
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
UnquantizedFusedMoEMethod()
|
||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||
self.use_triton_kernels
|
||||
)
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
@@ -597,6 +632,8 @@ class FusedMoE(torch.nn.Module):
|
||||
)
|
||||
else:
|
||||
if not self.use_presharded_weights:
|
||||
if self.use_triton_kernels:
|
||||
loaded_weight = loaded_weight.transpose(-2, -1)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
@@ -630,6 +667,8 @@ class FusedMoE(torch.nn.Module):
|
||||
)
|
||||
else:
|
||||
if not self.use_presharded_weights:
|
||||
if self.use_triton_kernels:
|
||||
loaded_weight = loaded_weight.transpose(-2, -1)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
@@ -716,6 +755,8 @@ class FusedMoE(torch.nn.Module):
|
||||
# should be whatever dimension intermediate_size is
|
||||
is_transposed = getattr(param, "is_transposed", False)
|
||||
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||
if self.use_triton_kernels:
|
||||
is_transposed = True
|
||||
if is_transposed:
|
||||
shard_dim = int(not shard_dim)
|
||||
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||
from triton_kernels.matmul_ogs import matmul_ogs
|
||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
||||
|
||||
from sglang.srt.utils import direct_register_custom_op
|
||||
|
||||
|
||||
def triton_kernel_moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if not renormalize:
|
||||
gating_output = torch.softmax(gating_output, dim=-1)
|
||||
routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
routing_data,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
per_channel_quant=per_channel_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
|
||||
# This is a triton implementation of the fused_experts function
|
||||
def triton_kernel_fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
routing_data: RoutingData,
|
||||
gather_indx: GatherIndx,
|
||||
scatter_indx: ScatterIndx,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
|
||||
assert per_channel_quant == False, "per_channel_quant is not supported"
|
||||
assert expert_map == None, "expert_map is not supported"
|
||||
assert w1_scale == None, "w1_scale is not supported"
|
||||
assert w2_scale == None, "w2_scale is not supported"
|
||||
assert a1_scale == None, "a1_scale is not supported"
|
||||
assert a2_scale == None, "a2_scale is not supported"
|
||||
assert block_shape == None, "block_shape is not supported"
|
||||
|
||||
# type check
|
||||
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
|
||||
assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
|
||||
assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
|
||||
|
||||
# Shape check
|
||||
assert hidden_states.ndim == 2, "hidden_states must be 2D"
|
||||
assert (
|
||||
hidden_states.shape[-1] == w1.shape[-2]
|
||||
), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
|
||||
assert (
|
||||
w2.shape[-1] == w1.shape[1]
|
||||
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
|
||||
|
||||
# feature check
|
||||
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
|
||||
|
||||
M, K = hidden_states.shape
|
||||
E, _, N = w1.shape
|
||||
n_expts_act = routing_data.n_expts_act
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
# consistent with default implementation
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * n_expts_act, N // 2), device="cuda", dtype=dtype
|
||||
)
|
||||
|
||||
intermediate_cache1 = matmul_ogs(
|
||||
hidden_states,
|
||||
w1,
|
||||
None,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
|
||||
)
|
||||
|
||||
if activation == "silu":
|
||||
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||
elif activation == "gelu":
|
||||
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
intermediate_cache3 = matmul_ogs(
|
||||
intermediate_cache2,
|
||||
w2,
|
||||
None,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
|
||||
)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
def triton_kernel_moe_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="forward_cuda_triton",
|
||||
op_func=triton_kernel_moe_forward,
|
||||
mutates_args=[],
|
||||
fake_impl=triton_kernel_moe_forward_fake,
|
||||
)
|
||||
@@ -101,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"triton_attention_reduce_in_fp32",
|
||||
"num_reserved_decode_tokens",
|
||||
"weight_loader_disable_mmap",
|
||||
"enable_triton_kernel_moe",
|
||||
]
|
||||
|
||||
# Put some global args for easy access
|
||||
|
||||
@@ -222,6 +222,7 @@ class ServerArgs:
|
||||
disable_chunked_prefix_cache: bool = False
|
||||
disable_fast_image_processor: bool = False
|
||||
enable_return_hidden_states: bool = False
|
||||
enable_triton_kernel_moe: bool = False
|
||||
warmups: Optional[str] = None
|
||||
|
||||
# Debug tensor dumps
|
||||
@@ -1554,6 +1555,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable returning hidden states with responses.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-triton-kernel-moe",
|
||||
action="store_true",
|
||||
help="Use triton moe grouped gemm kernel.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmups",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user