feat: replace get_act_fn for gpt_bigcode (#1231)
This commit is contained in:
@@ -13,10 +13,20 @@ limitations under the License.
|
|||||||
|
|
||||||
"""Fused operators for activation layers."""
|
"""Fused operators for activation layers."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
|
from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
|
||||||
|
from vllm.distributed import (
|
||||||
|
divide,
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
|
||||||
|
|
||||||
class SiluAndMul(CustomOp):
|
class SiluAndMul(CustomOp):
|
||||||
@@ -53,3 +63,76 @@ class GeluAndMul(CustomOp):
|
|||||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||||
gelu_tanh_and_mul(x, out)
|
gelu_tanh_and_mul(x, out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledActivation(nn.Module):
|
||||||
|
"""An activation function with post-scale parameters.
|
||||||
|
|
||||||
|
This is used for some quantization methods like AWQ.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
act_module: nn.Module,
|
||||||
|
intermediate_size: int,
|
||||||
|
input_is_parallel: bool = True,
|
||||||
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.act = act_module
|
||||||
|
self.input_is_parallel = input_is_parallel
|
||||||
|
if input_is_parallel:
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
intermediate_size_per_partition = divide(intermediate_size, tp_size)
|
||||||
|
else:
|
||||||
|
intermediate_size_per_partition = intermediate_size
|
||||||
|
if params_dtype is None:
|
||||||
|
params_dtype = torch.get_default_dtype()
|
||||||
|
self.scales = nn.Parameter(
|
||||||
|
torch.empty(intermediate_size_per_partition, dtype=params_dtype)
|
||||||
|
)
|
||||||
|
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.act(x) / self.scales
|
||||||
|
|
||||||
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||||
|
param_data = param.data
|
||||||
|
if self.input_is_parallel:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
shard_size = param_data.shape[0]
|
||||||
|
start_idx = tp_rank * shard_size
|
||||||
|
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
||||||
|
assert param_data.shape == loaded_weight.shape
|
||||||
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
_ACTIVATION_REGISTRY = {
|
||||||
|
"gelu": nn.GELU(),
|
||||||
|
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_act_fn(
|
||||||
|
act_fn_name: str,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
intermediate_size: Optional[int] = None,
|
||||||
|
input_is_parallel: bool = True,
|
||||||
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
|
) -> nn.Module:
|
||||||
|
"""Get an activation function by name."""
|
||||||
|
act_fn_name = act_fn_name.lower()
|
||||||
|
if act_fn_name not in _ACTIVATION_REGISTRY:
|
||||||
|
raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
|
||||||
|
|
||||||
|
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
|
||||||
|
if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names():
|
||||||
|
if intermediate_size is None:
|
||||||
|
raise ValueError(
|
||||||
|
"intermediate_size must be specified for scaled "
|
||||||
|
"activation functions."
|
||||||
|
)
|
||||||
|
return ScaledActivation(
|
||||||
|
act_fn, intermediate_size, input_is_parallel, params_dtype
|
||||||
|
)
|
||||||
|
return act_fn
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from torch import nn
|
|||||||
from transformers import GPTBigCodeConfig
|
from transformers import GPTBigCodeConfig
|
||||||
from vllm.config import CacheConfig, LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
ColumnParallelLinear,
|
ColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@@ -33,6 +32,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
|
|||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
|
from sglang.srt.layers.activation import get_act_fn
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
from sglang.srt.layers.sampler import Sampler
|
||||||
|
|||||||
Reference in New Issue
Block a user