2026-04-18 10:56:22 +08:00
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
"""MoE activation function enum and utilities."""
|
|
|
|
|
|
|
|
|
|
|
|
from enum import Enum
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
import torch.nn.functional as F
|
2026-04-29 19:38:22 +08:00
|
|
|
|
from vllm import _custom_ops as ops
|
2026-04-18 10:56:22 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MoEActivation(Enum):
    """Activation functions for MoE layers."""

    # Gated activations (gate * activation(up)) expect input of shape [..., 2*d]
    # and produce output of shape [..., d]
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated activations (no mul with gate) expect input of shape [..., d]
    # and produce output of shape [..., d].
    # NOTE: Non-gated activations require the "_no_mul" suffix to be present.
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """Returns True if activation expects gate*activation(up) pattern.

        Gated activations expect input tensor with 2x the output size,
        where the first half is the gate and second half is the up projection.
        """
        # The naming convention is the single source of truth: every
        # non-gated member's value ends with "_no_mul".
        return not self.value.endswith("_no_mul")

    @property
    def custom_op_name(self) -> str:
        """Maps to the CustomOp name of activations
        in vllm/model_executor/layers/activation.py."""
        return _CUSTOM_OP_NAMES[self]

    def without_mul(self) -> "MoEActivation":
        """Get the non-gated variant of this activation.

        For activations that have a _no_mul variant, returns that variant.
        For activations without a _no_mul variant (or already _no_mul),
        returns self.
        """
        return _WITHOUT_MUL.get(self, self)

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Parse from string for backward compatibility.

        Args:
            s: An activation value such as "silu" or "gelu_no_mul".

        Returns:
            The matching MoEActivation member.

        Raises:
            ValueError: If ``s`` does not match any member's value.
        """
        try:
            # Enum supports O(1) by-value lookup; no need for a manual
            # linear scan over the members.
            return cls(s)
        except ValueError:
            valid = [m.value for m in cls]
            raise ValueError(
                f"Unknown MoE activation: {s!r}. Valid activations: {valid}"
            ) from None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level lookup tables used by MoEActivation functions.

# Maps each activation to the CustomOp name registered in
# vllm/model_executor/layers/activation.py; consumed by
# MoEActivation.custom_op_name. The string values are registry keys —
# do not edit them cosmetically.
# NOTE(review): the *_NO_MUL members map to the same "*_and_mul" /
# "relu2" op names as their gated counterparts — presumably the consumer
# selects the no-mul path separately; confirm this is intentional.
_CUSTOM_OP_NAMES: dict[MoEActivation, str] = {
    MoEActivation.SILU: "silu_and_mul",
    MoEActivation.GELU: "gelu_and_mul",
    MoEActivation.SWIGLUOAI: "swigluoai_and_mul",
    MoEActivation.SWIGLUSTEP: "swiglustep_and_mul",
    MoEActivation.RELU2: "relu2",
    MoEActivation.SILU_NO_MUL: "silu_and_mul",
    MoEActivation.GELU_NO_MUL: "gelu_and_mul",
    MoEActivation.RELU2_NO_MUL: "relu2",
}
|
|
|
|
|
|
|
|
|
|
|
|
# Maps each gated activation that has a non-gated twin to that twin.
# Members absent from this table (SWIGLUOAI, SWIGLUSTEP, and the _NO_MUL
# variants themselves) map to themselves via the dict.get(self, self)
# fallback in MoEActivation.without_mul().
_WITHOUT_MUL: dict[MoEActivation, MoEActivation] = {
    MoEActivation.SILU: MoEActivation.SILU_NO_MUL,
    MoEActivation.GELU: MoEActivation.GELU_NO_MUL,
    MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def activation_without_mul(activation: str) -> str:
    """Translate an activation name to its non-gated ("_no_mul") form.

    Args:
        activation: The activation function name (e.g., "silu", "gelu")

    Returns:
        The non-gated activation name (e.g., "silu_no_mul", "gelu_no_mul").
        Names with no non-gated twin are returned unchanged.
    """
    member = MoEActivation.from_str(activation)
    return member.without_mul().value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def apply_moe_activation(
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply MoE activation function.

    Writes the activation of ``input`` into ``output`` in place and also
    returns ``output`` for convenience.

    Args:
        activation: Which activation to apply.
        output: 2D destination tensor; last dim is d.
        input: 2D source tensor; last dim is 2*d for gated activations
            (first half gate, second half up projection), d otherwise.

    Returns:
        ``output``, filled with the activation result.

    Raises:
        ValueError: For activations with no branch below. NOTE(review):
            this includes gated RELU2, even though it has a custom_op_name
            entry — presumably unsupported on this path; confirm.
    """
    assert input.dim() == 2, "Input must be 2D"
    assert output.dim() == 2, "Output must be 2D"
    # Shape contract depends on gating: gated variants consume [..., 2*d],
    # non-gated consume [..., d].
    if activation.is_gated:
        assert output.size(-1) * 2 == input.size(-1), (
            f"{activation.value} expects 2x ratio: "
            f"{output.size(-1) * 2} vs {input.size(-1)}"
        )
    else:
        assert output.size(-1) == input.size(-1), (
            f"{activation.value} expects equal sizes: "
            f"{output.size(-1)} vs {input.size(-1)}"
        )

    # Activations with gated multiplication (gate × activation(up)).
    # The ops.* kernels write their result directly into `output`.
    if activation == MoEActivation.SILU:
        ops.silu_and_mul(output, input)
    elif activation == MoEActivation.GELU:
        ops.gelu_and_mul(output, input)
    elif activation == MoEActivation.SWIGLUOAI:
        ops.swigluoai_and_mul(output, input)
    elif activation == MoEActivation.SWIGLUSTEP:
        # Imported lazily here; presumably to avoid a circular import with
        # the activation layer module — confirm before hoisting to the top.
        from vllm.model_executor.layers.activation import swiglustep_and_mul_triton

        swiglustep_and_mul_triton(output, input)

    # Activations without gated multiplication
    elif activation == MoEActivation.SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == MoEActivation.GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == MoEActivation.RELU2_NO_MUL:
        # relu2(x) = relu(x)**2. NOTE(review): unlike the other no-mul
        # branches, this mutates the caller's `input` in place (inplace
        # ReLU) — presumably to avoid a temporary; callers must not reuse
        # `input` afterwards expecting the original values.
        F.relu(input, inplace=True)
        torch.square(input, out=output)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output
|