Add activation parameters to fused_moe (#3170)
@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ):
         super().__init__()

@@ -140,6 +141,7 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.activation = activation

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +168,7 @@ class EPMoE(torch.nn.Module):

     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"

         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(

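The EP (expert-parallel) path only implements the SiLU-gated grouped GEMM, so EPMoE accepts the new argument but guards it with an assert in forward. A minimal sketch of that guard as a standalone helper, with a hypothetical name not taken from this diff:

# Hypothetical helper mirroring the guard added to EPMoE.forward: backends that
# only implement a subset of activations should fail fast with a clear message.
def check_moe_activation(activation: str, supported=("silu",)) -> None:
    if activation not in supported:
        raise NotImplementedError(
            f"activation={activation!r} is not supported by this MoE backend "
            f"(supported: {', '.join(supported)})"
        )
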
@@ -8,7 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

-from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts


@@ -23,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -41,7 +42,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

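For readers unfamiliar with GLU-style MoE MLPs, the branch above applies the activation to the gate projection (w1) only and then multiplies by the up projection (w3). A self-contained sketch of that computation, with illustrative names not taken from this diff:

import torch
import torch.nn.functional as F

def gated_act(x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    # x_gate ~ x @ w1.T, x_up ~ x @ w3.T
    if activation == "silu":
        return F.silu(x_gate) * x_up
    if activation == "gelu":
        return F.gelu(x_gate) * x_up
    raise ValueError(f"Unsupported activation: {activation=}")
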
@@ -58,6 +64,7 @@ def moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:

     topk_weights, topk_ids = select_experts(
@@ -84,6 +91,13 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()

+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
+
     outputs = []
     start_idx = 0
     for i, num_tokens in enumerate(tokens_per_expert):
@@ -96,7 +110,7 @@ def moe_forward_native(
         layer_w2_weight = layer.w2_weight[i]

         gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
-        gate_up = SiluAndMul()(gate_up)
+        gate_up = act(gate_up)
         expert_out = F.linear(gate_up, layer_w2_weight)
         outputs.append(expert_out)
         start_idx = end_idx

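SiluAndMul and GeluAndMul operate on the fused gate_up tensor produced by the w13 projection. A rough PyTorch reference, assuming the usual split-in-half layout (gate first, up second); the exact GELU variant used by GeluAndMul is not visible in this diff:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(gate_up: torch.Tensor) -> torch.Tensor:
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up

def gelu_and_mul_ref(gate_up: torch.Tensor) -> torch.Tensor:
    gate, up = gate_up.chunk(2, dim=-1)
    return F.gelu(gate) * up  # erf vs tanh approximation is an assumption here
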
@@ -711,6 +711,7 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -726,6 +727,7 @@ def inplace_fused_experts(
         topk_weights,
         topk_ids,
         True,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -767,6 +770,7 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -782,6 +786,7 @@ def outplace_fused_experts(
         topk_weights,
         topk_ids,
         False,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

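The *_fake variants must keep the same signature as the real ops, which is why activation is added to both: when these functions are registered as torch custom ops, the fake (meta) implementation is what shape inference and torch.compile trace. A generic sketch of that pattern using PyTorch's custom-op API; the names and registration style here are illustrative, not the code used in this repo:

import torch

@torch.library.custom_op("demo::fused_experts", mutates_args=())
def demo_fused_experts(hidden_states: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    return hidden_states.clone()  # stand-in for the real kernel

@demo_fused_experts.register_fake
def _(hidden_states: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    return torch.empty_like(hidden_states)  # shape/dtype only; never runs the kernel
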
@@ -824,6 +830,7 @@ def fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -839,6 +846,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,
@@ -855,6 +863,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,
@@ -872,6 +881,7 @@ def fused_experts_impl(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -986,7 +996,12 @@ def fused_experts_impl(
             block_shape=block_shape,
         )

-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported activation: {activation=}")

         invoke_fused_moe_kernel(
             intermediate_cache2,

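ops.silu_and_mul and ops.gelu_and_mul are fused kernels that write into the preallocated intermediate_cache2. A PyTorch stand-in for the contract assumed above (the input carries gate in the first half of the last dimension and up in the second half), useful for checking the branch without the custom kernels:

import torch
import torch.nn.functional as F

def silu_and_mul_out(out: torch.Tensor, x: torch.Tensor) -> None:
    d = x.shape[-1] // 2
    out.copy_(F.silu(x[..., :d]) * x[..., d:])

def gelu_and_mul_out(out: torch.Tensor, x: torch.Tensor) -> None:
    d = x.shape[-1] // 2
    out.copy_(F.gelu(x[..., :d]) * x[..., d:])  # GELU variant is an assumption
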
@@ -1042,6 +1057,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -1111,6 +1127,7 @@ def fused_moe(
         topk_weights,
         topk_ids,
         inplace=inplace,
+        activation=activation,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,

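End to end, callers pick the activation through the top-level fused_moe entry point. A hedged usage sketch; the tensor shapes and the positional arguments before topk follow the common fused_moe convention and are assumptions, not taken from this diff:

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

NUM_EXPERTS, HIDDEN, INTERMEDIATE, TOKENS, TOPK = 8, 1024, 4096, 16, 2
x = torch.randn(TOKENS, HIDDEN, dtype=torch.float16, device="cuda")
w13 = torch.randn(NUM_EXPERTS, 2 * INTERMEDIATE, HIDDEN, dtype=torch.float16, device="cuda")  # fused gate/up
w2 = torch.randn(NUM_EXPERTS, HIDDEN, INTERMEDIATE, dtype=torch.float16, device="cuda")
router_logits = torch.randn(TOKENS, NUM_EXPERTS, dtype=torch.float16, device="cuda")

out = fused_moe(
    x,
    w13,
    w2,
    router_logits,
    topk=TOPK,
    renormalize=True,
    activation="gelu",  # new parameter; defaults to "silu"
)
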
@@ -126,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -138,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            activation=activation,
         )

     def forward_cuda(
@@ -152,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -169,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             import ater
             from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -184,6 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 inplace=True,
+                activation=activation,
             )

     def forward_cpu(
@@ -256,6 +262,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -279,6 +286,7 @@ class FusedMoE(torch.nn.Module):
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
+        self.activation = activation

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -589,6 +597,7 @@ class FusedMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
+            activation=self.activation,
         )

         if self.reduce_results and self.tp_size > 1:

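The plumbing in this file is mechanical: FusedMoE stores the constructor argument, and its forward hands it to quant_method.apply, which forwards it to the kernel entry points above. A condensed sketch of that flow using simplified stand-in classes, not the real ones:

class _QuantMethodSketch:
    def apply(self, layer, x, router_logits, *, activation="silu", **kwargs):
        # real implementations end up calling fused_experts(..., activation=activation)
        return x

class _FusedMoESketch:
    def __init__(self, quant_method, activation="silu"):
        self.quant_method = quant_method
        self.activation = activation

    def forward(self, hidden_states, router_logits):
        return self.quant_method.apply(
            self, hidden_states, router_logits, activation=self.activation
        )
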
@@ -763,8 +763,8 @@ class Fp8MoEMethod:
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -785,6 +785,8 @@ class Fp8MoEMethod:
             import ater
             from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 x,
                 layer.w13_weight,
@@ -815,6 +817,7 @@ class Fp8MoEMethod:
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 inplace=True,
+                activation=activation,
                 use_fp8_w8a8=True,
                 w1_scale=(
                     layer.w13_weight_scale_inv

@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            activation="gelu",
             use_presharded_weights=use_presharded_weights,
         )

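The one call-site change: Grok-1's MoE MLP is GELU-gated, so Grok1MoE now passes activation="gelu" into the shared FusedMoE layer. With that setting, the native reference path computes, per selected expert, y = (gelu(x @ W1.T) * (x @ W3.T)) @ W2.T. A toy check of that formula with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(4, 32)      # tokens x hidden
w1 = torch.randn(64, 32)    # gate projection
w3 = torch.randn(64, 32)    # up projection
w2 = torch.randn(32, 64)    # down projection
y = F.linear(F.gelu(F.linear(x, w1)) * F.linear(x, w3), w2)
print(y.shape)  # torch.Size([4, 32])
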