init
This commit is contained in:
173
vllm/model_executor/layers/activation.py
Normal file
173
vllm/model_executor/layers/activation.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Custom activation functions."""
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
"""An activation function for SwiGLU.
|
||||
|
||||
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
ops.silu_and_mul(out, x)
|
||||
return out
|
||||
|
||||
|
||||
class GeluAndMul(nn.Module):
|
||||
"""An activation function for GeGLU.
|
||||
|
||||
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
|
||||
return: (batch_size, seq_len, d) or (num_tokens, d)
|
||||
"""
|
||||
|
||||
def __init__(self, approximate: str = "none"):
|
||||
super().__init__()
|
||||
self.approximate = approximate
|
||||
if approximate not in ("none", "tanh"):
|
||||
raise ValueError(f"Unknown approximate mode: {approximate}")
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
if self.approximate == "none":
|
||||
ops.gelu_and_mul(out, x)
|
||||
elif self.approximate == "tanh":
|
||||
ops.gelu_tanh_and_mul(out, x)
|
||||
return out
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'approximate={repr(self.approximate)}'
|
||||
|
||||
|
||||
class NewGELU(nn.Module):
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
c = math.sqrt(2.0 / math.pi)
|
||||
return 0.5 * x * (1.0 + torch.tanh(c *
|
||||
(x + 0.044715 * torch.pow(x, 3.0))))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = torch.empty_like(x)
|
||||
ops.gelu_new(out, x)
|
||||
return out
|
||||
|
||||
|
||||
class FastGELU(nn.Module):
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
|
||||
(1.0 + 0.044715 * x * x)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = torch.empty_like(x)
|
||||
ops.gelu_fast(out, x)
|
||||
return out
|
||||
|
||||
|
||||
class ScaledActivation(nn.Module):
|
||||
"""An activation function with post-scale parameters.
|
||||
|
||||
This is used for some quantization methods like AWQ.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
act_module: nn.Module,
|
||||
intermediate_size: int,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = act_module
|
||||
self.input_is_parallel = input_is_parallel
|
||||
if input_is_parallel:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size_per_partition = divide(intermediate_size,
|
||||
tp_size)
|
||||
else:
|
||||
intermediate_size_per_partition = intermediate_size
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.scales = nn.Parameter(
|
||||
torch.empty(intermediate_size_per_partition, dtype=params_dtype))
|
||||
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.act(x) / self.scales
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
param_data = param.data
|
||||
if self.input_is_parallel:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = param_data.shape[0]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
_ACTIVATION_REGISTRY = {
|
||||
"gelu": nn.GELU(),
|
||||
"gelu_fast": FastGELU(),
|
||||
"gelu_new": NewGELU(),
|
||||
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
|
||||
"relu": nn.ReLU(),
|
||||
}
|
||||
|
||||
|
||||
def get_act_fn(
|
||||
act_fn_name: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
) -> nn.Module:
|
||||
"""Get an activation function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
if act_fn_name not in _ACTIVATION_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Activation function {act_fn_name!r} is not supported.")
|
||||
|
||||
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
|
||||
if (quant_config is not None
|
||||
and act_fn_name in quant_config.get_scaled_act_names()):
|
||||
if intermediate_size is None:
|
||||
raise ValueError("intermediate_size must be specified for scaled "
|
||||
"activation functions.")
|
||||
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
|
||||
params_dtype)
|
||||
return act_fn
|
||||
Reference in New Issue
Block a user