# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fused operators for activation layers."""

import logging
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import (
    cpu_has_amx_support,
    is_cpu,
    is_cuda,
    is_hip,
    is_npu,
    is_xpu,
    set_weight_attrs,
)
from sglang.utils import resolve_obj_by_qualname

_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_hip = is_hip()
_is_xpu = is_xpu()

if _is_cuda or _is_xpu:
    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
elif _is_hip:
    from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul

if _is_npu:
    import torch_npu

logger = logging.getLogger(__name__)
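
# Note (illustrative, not part of this module's logic): the activations below
# subclass CustomOp and implement per-backend methods (forward_native,
# forward_cuda, forward_cpu, ...); CustomOp is expected to dispatch to the
# method matching the current platform, so callers invoke these modules like
# any other nn.Module.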

class SiluAndMul(CustomOp):
    """SwiGLU gating: splits the last dim in half and returns silu(x1) * x2."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        silu_and_mul(x, out)
        return out

    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
        if _is_cpu_amx_available:
            out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
            return out
        else:
            return self.forward_native(x)

    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
        out = torch_npu.npu_swiglu(x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        silu_and_mul(x, out)
        return out
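
# Usage sketch (illustrative; assumes a CUDA device and sgl_kernel are
# available): the input is the concatenated [gate, up] projection, and the
# output has half the last dimension.
#
#     act = SiluAndMul()
#     x = torch.randn(8, 2 * 1024, dtype=torch.float16, device="cuda")
#     y = act(x)  # y.shape == (8, 1024); y == silu(x[..., :1024]) * x[..., 1024:]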

class GeluAndMul(CustomOp):
    """GeGLU gating: splits the last dim in half and returns gelu(x1) * x2."""

    def __init__(self, approximate="tanh"):
        super().__init__()
        self.approximate = approximate

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "tanh":
            gelu_tanh_and_mul(x, out)
        elif self.approximate == "none":
            gelu_and_mul(x, out)
        else:
            raise RuntimeError("GeluAndMul only supports approximate='tanh' or 'none'")
        return out

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
        if _is_cpu_amx_available and self.approximate == "tanh":
            return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
        elif _is_cpu_amx_available and self.approximate == "none":
            return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
        else:
            return self.forward_native(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_impl(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_impl(x)

    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
        y_npu, gelu_npu = torch_npu.npu_geglu(
            x,
            dim=-1,
            approximate=1 if self.approximate == "tanh" else 0,
            activate_left=True,
        )
        return y_npu
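
# Usage sketch (illustrative): `approximate` selects between the exact GELU
# ("none") and the tanh approximation ("tanh", the default), mirroring the
# `approximate` keyword of torch.nn.functional.gelu.
#
#     act = GeluAndMul(approximate="none")
#     x = torch.randn(8, 2 * 1024)
#     y = act.forward_native(x)  # y.shape == (8, 1024)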

class NewGELU(CustomOp):
    """The "new GELU" (gelu_new) used by GPT-2-style models: the tanh approximation of GELU."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: Implement the CUDA kernel for NewGELU in sgl-kernel
        return self.forward_native(x)
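
# Note (illustrative): NewGELU computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x +
# 0.044715 * x**3))), the same tanh approximation exposed as
# F.gelu(x, approximate="tanh"), so the two agree to floating-point tolerance:
#
#     x = torch.randn(16)
#     torch.testing.assert_close(NewGELU().forward_native(x), F.gelu(x, approximate="tanh"))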

class ReLU2(nn.Module):
    """
    Applies the squared Rectified Linear Unit function.
    y = max(0, x)^2
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(x)
        return x * x
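
# Example (illustrative): ReLU2 squares the positive part, so relu2(-2.0) == 0.0
# and relu2(3.0) == 9.0.
#
#     torch.testing.assert_close(ReLU2()(torch.tensor([-2.0, 3.0])), torch.tensor([0.0, 9.0]))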

class QuickGELU(CustomOp):
    """Sigmoid approximation of GELU: x * sigmoid(1.702 * x)."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)

    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
        gelu_quick(x, out)
        return out

    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
        return torch_npu.npu_fast_gelu(x)
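
# Note (illustrative): QuickGELU is the sigmoid-based GELU approximation commonly
# used by CLIP-style vision encoders; where no fused kernel is wired up it simply
# evaluates the eager formula above.
#
#     y = QuickGELU().forward_native(torch.randn(4, 1024))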

class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010.

    If the user has installed nickjbrowning/XIELU, the xIELU CUDA extension is used;
    otherwise, a single warning is emitted and the Python implementation is used.
    """

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
                0
            )
        )
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
            ).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            pass
            # logger.warning_once(
            #     "CUDA-fused xIELU not available (%s) –"
            #     " falling back to a Python version.\n"
            #     "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
            #     str(err),
            # )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).\n"
                "Note: For SGLang this may be expected if sending "
                "[B*S,D] instead of [B,S,D].",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.view(original_shape)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)
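
# Usage sketch (illustrative): unlike the stateless activations above, XIELU holds
# learnable parameters (alpha_p, alpha_n), so it is constructed per layer and its
# parameters are trained/loaded with the rest of the model.
#
#     act = XIELU(dtype=torch.float32)
#     y = act(torch.randn(2, 16, 64))  # uses the Python path when the CUDA extension is absent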

class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
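
# Sharding sketch (illustrative): with input_is_parallel=True and tensor-parallel
# size 4, an intermediate_size of 8192 gives each rank a 2048-element `scales`
# parameter, and weight_loader copies that rank's contiguous slice
# loaded_weight[rank * 2048 : (rank + 1) * 2048] from the full checkpoint tensor.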

_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "gelu_new": NewGELU(),
    "relu2": ReLU2(),
    "xielu": XIELU(),
}

def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names():
        if intermediate_size is None:
            raise ValueError(
                "intermediate_size must be specified for scaled "
                "activation functions."
            )
        return ScaledActivation(
            act_fn, intermediate_size, input_is_parallel, params_dtype
        )
    return act_fn
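
# Usage sketch (illustrative): a plain lookup returns the shared module instance
# from the registry; passing a quant_config whose get_scaled_act_names() includes
# the name wraps it in ScaledActivation, which then requires intermediate_size.
#
#     act = get_act_fn("gelu_new")            # -> NewGELU()
#     act = get_act_fn("gelu_pytorch_tanh")   # -> nn.GELU(approximate="tanh")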

def get_cross_encoder_activation_function(config: PretrainedConfig):
    if (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function
        assert function_name.startswith("torch.nn.modules."), (
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )
        return resolve_obj_by_qualname(function_name)()
    else:
        # adapt bge-reranker
        return nn.Identity()
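
# Example (illustrative): a sentence-transformers cross-encoder config with
# sbert_ce_default_activation_function = "torch.nn.modules.activation.Sigmoid"
# resolves to nn.Sigmoid(); configs without that attribute (e.g. bge-reranker)
# get nn.Identity().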

if not (
    _is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu
):
    logger.info(
        "sgl-kernel is not available on non-NVIDIA/AMD platforms or on CPUs without "
        "AMX support. Falling back to other kernel libraries."
    )
    from vllm.model_executor.layers.activation import (  # noqa: F401
        GeluAndMul,
        SiluAndMul,
    )