Add activation parameters to fused_moe (#3170)
This commit is contained in:
@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -140,6 +141,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
self.num_expert_group = num_expert_group
|
self.num_expert_group = num_expert_group
|
||||||
self.topk_group = topk_group
|
self.topk_group = topk_group
|
||||||
self.correction_bias = correction_bias
|
self.correction_bias = correction_bias
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
|
||||||
@@ -166,6 +168,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
assert self.activation == "silu"
|
||||||
|
|
||||||
if self.grouped_gemm_runner is None:
|
if self.grouped_gemm_runner is None:
|
||||||
self.grouped_gemm_runner = GroupedGemmRunner(
|
self.grouped_gemm_runner = GroupedGemmRunner(
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Callable, Optional
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
|
||||||
|
|
||||||
@@ -23,6 +23,7 @@ def fused_moe_forward_native(
|
|||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -41,7 +42,12 @@ def fused_moe_forward_native(
|
|||||||
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
||||||
w2_weights = layer.w2_weight[topk_ids]
|
w2_weights = layer.w2_weight[topk_ids]
|
||||||
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
||||||
x1 = F.silu(x1)
|
if activation == "silu":
|
||||||
|
x1 = F.silu(x1)
|
||||||
|
elif activation == "gelu":
|
||||||
|
x1 = F.gelu(x1)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported activation: {activation=}")
|
||||||
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||||
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||||
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
||||||
@@ -58,6 +64,7 @@ def moe_forward_native(
|
|||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
@@ -84,6 +91,13 @@ def moe_forward_native(
|
|||||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||||
|
|
||||||
|
if activation == "silu":
|
||||||
|
act = SiluAndMul()
|
||||||
|
elif activation == "gelu":
|
||||||
|
act = GeluAndMul()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported activation: {activation=}")
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
for i, num_tokens in enumerate(tokens_per_expert):
|
for i, num_tokens in enumerate(tokens_per_expert):
|
||||||
@@ -96,7 +110,7 @@ def moe_forward_native(
|
|||||||
layer_w2_weight = layer.w2_weight[i]
|
layer_w2_weight = layer.w2_weight[i]
|
||||||
|
|
||||||
gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
|
gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
|
||||||
gate_up = SiluAndMul()(gate_up)
|
gate_up = act(gate_up)
|
||||||
expert_out = F.linear(gate_up, layer_w2_weight)
|
expert_out = F.linear(gate_up, layer_w2_weight)
|
||||||
outputs.append(expert_out)
|
outputs.append(expert_out)
|
||||||
start_idx = end_idx
|
start_idx = end_idx
|
||||||
|
|||||||
@@ -711,6 +711,7 @@ def inplace_fused_experts(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str = "silu",
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
@@ -726,6 +727,7 @@ def inplace_fused_experts(
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
True,
|
True,
|
||||||
|
activation,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
@@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str = "silu",
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
@@ -767,6 +770,7 @@ def outplace_fused_experts(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str = "silu",
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
@@ -782,6 +786,7 @@ def outplace_fused_experts(
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
False,
|
False,
|
||||||
|
activation,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
@@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str = "silu",
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
@@ -824,6 +830,7 @@ def fused_experts(
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
@@ -839,6 +846,7 @@ def fused_experts(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
activation,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
@@ -855,6 +863,7 @@ def fused_experts(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
activation,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
@@ -872,6 +881,7 @@ def fused_experts_impl(
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
@@ -986,7 +996,12 @@ def fused_experts_impl(
|
|||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
if activation == "silu":
|
||||||
|
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||||
|
elif activation == "gelu":
|
||||||
|
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported activation: {activation=}")
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
invoke_fused_moe_kernel(
|
||||||
intermediate_cache2,
|
intermediate_cache2,
|
||||||
@@ -1042,6 +1057,7 @@ def fused_moe(
|
|||||||
topk: int,
|
topk: int,
|
||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
use_grouped_topk: bool = False,
|
use_grouped_topk: bool = False,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
@@ -1111,6 +1127,7 @@ def fused_moe(
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
inplace=inplace,
|
inplace=inplace,
|
||||||
|
activation=activation,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
w1_scale=w1_scale,
|
w1_scale=w1_scale,
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.forward(
|
return self.forward(
|
||||||
x=x,
|
x=x,
|
||||||
@@ -138,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
|
activation=activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
@@ -152,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -169,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
import ater
|
import ater
|
||||||
from ater.fused_moe import fused_experts_ck
|
from ater.fused_moe import fused_experts_ck
|
||||||
|
|
||||||
|
assert activation == "silu", f"{activation=} is not supported."
|
||||||
|
|
||||||
return fused_experts_ck(
|
return fused_experts_ck(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
@@ -184,6 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
|
activation=activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_cpu(
|
def forward_cpu(
|
||||||
@@ -256,6 +262,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
use_presharded_weights: bool = False,
|
use_presharded_weights: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -279,6 +286,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.topk_group = topk_group
|
self.topk_group = topk_group
|
||||||
self.custom_routing_function = custom_routing_function
|
self.custom_routing_function = custom_routing_function
|
||||||
self.correction_bias = correction_bias
|
self.correction_bias = correction_bias
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||||
@@ -589,6 +597,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
num_expert_group=self.num_expert_group,
|
num_expert_group=self.num_expert_group,
|
||||||
custom_routing_function=self.custom_routing_function,
|
custom_routing_function=self.custom_routing_function,
|
||||||
correction_bias=self.correction_bias,
|
correction_bias=self.correction_bias,
|
||||||
|
activation=self.activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.reduce_results and self.tp_size > 1:
|
if self.reduce_results and self.tp_size > 1:
|
||||||
|
|||||||
@@ -763,8 +763,8 @@ class Fp8MoEMethod:
|
|||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
|
||||||
@@ -785,6 +785,8 @@ class Fp8MoEMethod:
|
|||||||
import ater
|
import ater
|
||||||
from ater.fused_moe import fused_experts_ck
|
from ater.fused_moe import fused_experts_ck
|
||||||
|
|
||||||
|
assert activation == "silu", f"{activation=} is not supported."
|
||||||
|
|
||||||
return fused_experts_ck(
|
return fused_experts_ck(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@@ -815,6 +817,7 @@ class Fp8MoEMethod:
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
|
activation=activation,
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
w1_scale=(
|
w1_scale=(
|
||||||
layer.w13_weight_scale_inv
|
layer.w13_weight_scale_inv
|
||||||
|
|||||||
@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
|
|||||||
renormalize=False,
|
renormalize=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
|
activation="gelu",
|
||||||
use_presharded_weights=use_presharded_weights,
|
use_presharded_weights=use_presharded_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
w8a8_block_fp8_matmul,
|
w8a8_block_fp8_matmul,
|
||||||
|
|||||||
Reference in New Issue
Block a user