Add activation parameters to fused_moe (#3170)
@@ -763,8 +763,8 @@ class Fp8MoEMethod:
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts

@@ -785,6 +785,8 @@ class Fp8MoEMethod:
            import ater
            from ater.fused_moe import fused_experts_ck

            assert activation == "silu", f"{activation=} is not supported."

            return fused_experts_ck(
                x,
                layer.w13_weight,
@@ -815,6 +817,7 @@ class Fp8MoEMethod:
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            w1_scale=(
                layer.w13_weight_scale_inv
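The diff threads a new activation string from Fp8MoEMethod.apply down to the expert kernels: the Triton path forwards it via activation=activation, while the ROCm ater path (fused_experts_ck) asserts that only "silu" is supported. Below is a minimal, illustrative sketch of the same pattern, not the sglang or ater implementation: a plain PyTorch reference MoE-expert function that dispatches on an activation string, plus a stand-in backend that mirrors the silu-only assert. The names moe_experts_ref and ck_like_experts, and all shapes, are hypothetical.

# Illustrative sketch only (assumption): shows how an `activation` string can be
# threaded through a fused-MoE-style call chain. Not the sglang/ater implementation.
from typing import Callable

import torch
import torch.nn.functional as F

_ACTIVATIONS: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
    "silu": F.silu,
    "gelu": F.gelu,
}


def moe_experts_ref(
    x: torch.Tensor,             # [tokens, hidden]
    w13: torch.Tensor,           # [experts, 2 * inter, hidden], fused gate/up projection
    w2: torch.Tensor,            # [experts, hidden, inter], down projection
    topk_ids: torch.Tensor,      # [tokens, top_k]
    topk_weights: torch.Tensor,  # [tokens, top_k]
    activation: str = "silu",
) -> torch.Tensor:
    """Reference (unfused) gated-MLP MoE that dispatches on an activation string."""
    act = _ACTIVATIONS[activation]  # raises KeyError for unsupported names
    out = torch.zeros_like(x)
    inter = w2.shape[-1]
    for t in range(x.shape[0]):
        for k in range(topk_ids.shape[1]):
            e = int(topk_ids[t, k])
            gate_up = x[t] @ w13[e].T                    # [2 * inter]
            h = act(gate_up[:inter]) * gate_up[inter:]   # gated activation
            out[t] += topk_weights[t, k] * (h @ w2[e].T)
    return out


def ck_like_experts(x, w13, w2, topk_ids, topk_weights, activation="silu"):
    """Stand-in for a backend that only implements SiLU, mirroring the assert above."""
    assert activation == "silu", f"{activation=} is not supported."
    return moe_experts_ref(x, w13, w2, topk_ids, topk_weights, activation)

As a quick smoke test with assumed shapes: x = torch.randn(4, 8), w13 = torch.randn(2, 32, 8), w2 = torch.randn(2, 8, 16), topk_ids = torch.zeros(4, 2, dtype=torch.long), topk_weights = torch.full((4, 2), 0.5); then moe_experts_ref(x, w13, w2, topk_ids, topk_weights, activation="gelu") runs, while ck_like_experts with activation="gelu" fails the assert, just as the ater branch in the diff would.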