# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""MoE activation function enum and utilities."""

from enum import Enum

import torch
import torch.nn.functional as F

from vllm._custom_ops import silu_and_mul, gelu_and_mul, swigluoai_and_mul


class MoEActivation(Enum):
    """Activation functions for MoE layers."""

    # Gated activations (gate * activation(up)) expect input of shape
    # [..., 2*d] and produce output of shape [..., d].
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated activations (no mul with gate) expect input of shape [..., d]
    # and produce output of shape [..., d].
    # NOTE: Non-gated activations require the "_no_mul" suffix to be present.
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """Returns True if activation expects gate*activation(up) pattern.

        Gated activations expect an input tensor with 2x the output size,
        where the first half is the gate and the second half is the up
        projection.
        """
        return not self.value.endswith("_no_mul")

    @property
    def custom_op_name(self) -> str:
        """Maps to the CustomOp name of activations
        in vllm/model_executor/layers/activation.py."""
        return _CUSTOM_OP_NAMES[self]

    def without_mul(self) -> "MoEActivation":
        """Get the non-gated variant of this activation.

        For activations that have a _no_mul variant, returns that variant.
        For activations without a _no_mul variant (or already _no_mul),
        returns self.
        """
        return _WITHOUT_MUL.get(self, self)

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Parse from a string for backward compatibility."""
        for member in cls:
            if member.value == s:
                return member
        valid = [m.value for m in cls]
        raise ValueError(
            f"Unknown MoE activation: {s!r}. Valid activations: {valid}"
        )
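

# Usage sketch (illustrative comments only, not part of the module API; the
# values shown follow directly from the enum definitions above):
#
#   MoEActivation.from_str("silu")         -> MoEActivation.SILU
#   MoEActivation.SILU.is_gated            -> True   (input [..., 2*d])
#   MoEActivation.SILU.without_mul()       -> MoEActivation.SILU_NO_MUL
#   MoEActivation.SILU_NO_MUL.is_gated     -> False  (input [..., d])
#   MoEActivation.SWIGLUOAI.without_mul()  -> MoEActivation.SWIGLUOAI
#                                             (no _no_mul variant, so self)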


# Module-level lookup tables used by MoEActivation functions.
_CUSTOM_OP_NAMES: dict[MoEActivation, str] = {
    MoEActivation.SILU: "silu_and_mul",
    MoEActivation.GELU: "gelu_and_mul",
    MoEActivation.SWIGLUOAI: "swigluoai_and_mul",
    MoEActivation.SWIGLUSTEP: "swiglustep_and_mul",
    MoEActivation.RELU2: "relu2",
    MoEActivation.SILU_NO_MUL: "silu_and_mul",
    MoEActivation.GELU_NO_MUL: "gelu_and_mul",
    MoEActivation.RELU2_NO_MUL: "relu2",
}

_WITHOUT_MUL: dict[MoEActivation, MoEActivation] = {
    MoEActivation.SILU: MoEActivation.SILU_NO_MUL,
    MoEActivation.GELU: MoEActivation.GELU_NO_MUL,
    MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL,
}


def activation_without_mul(activation: str) -> str:
    """Get the non-gated variant of an activation function.

    Args:
        activation: The activation function name (e.g., "silu", "gelu")

    Returns:
        The non-gated activation name (e.g., "silu_no_mul", "gelu_no_mul")
    """
    return MoEActivation.from_str(activation).without_mul().value
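

# Illustrative examples (comments only; derived from _WITHOUT_MUL above):
#   activation_without_mul("silu")      -> "silu_no_mul"
#   activation_without_mul("swigluoai") -> "swigluoai"  (no _no_mul variant)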


def apply_moe_activation(
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply MoE activation function."""
    assert input.dim() == 2, "Input must be 2D"
    assert output.dim() == 2, "Output must be 2D"
    if activation.is_gated:
        assert output.size(-1) * 2 == input.size(-1), (
            f"{activation.value} expects 2x ratio: "
            f"{output.size(-1) * 2} vs {input.size(-1)}"
        )
    else:
        assert output.size(-1) == input.size(-1), (
            f"{activation.value} expects equal sizes: "
            f"{output.size(-1)} vs {input.size(-1)}"
        )

    # Activations with gated multiplication (gate * activation(up))
    if activation == MoEActivation.SILU:
        # torch.ops._C.silu_and_mul(output, input)
        silu_and_mul(output, input)
    elif activation == MoEActivation.GELU:
        # torch.ops._C.gelu_and_mul(output, input)
        gelu_and_mul(output, input)
    elif activation == MoEActivation.SWIGLUOAI:
        # torch.ops._C.swigluoai_and_mul(output, input)
        swigluoai_and_mul(output, input)
    elif activation == MoEActivation.SWIGLUSTEP:
        from vllm.model_executor.layers.activation import (
            swiglustep_and_mul_triton,
        )

        swiglustep_and_mul_triton(output, input)

    # Activations without gated multiplication
    elif activation == MoEActivation.SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == MoEActivation.GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == MoEActivation.RELU2_NO_MUL:
        # relu2(x) = relu(x) ** 2: rectify the input in place, then square
        # into the output buffer.
        F.relu(input, inplace=True)
        torch.square(input, out=output)
    else:
        raise ValueError(f"Unsupported FusedMoE activation: {activation}")

    return output
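

# Minimal smoke-test sketch (illustrative addition, not part of the original
# module). It exercises the CPU-friendly non-gated path; the gated branches
# dispatch to vLLM's compiled custom ops (e.g. silu_and_mul) and generally
# need the compiled extension / a GPU build to run.
if __name__ == "__main__":
    x = torch.randn(4, 8)
    out = torch.empty_like(x)
    apply_moe_activation(MoEActivation.SILU_NO_MUL, out, x)
    assert torch.allclose(out, F.silu(x))
    print("silu_no_mul:", tuple(x.shape), "->", tuple(out.shape))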