# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch
import torch.nn.functional as F
from typing import Any

from vllm.distributed import (
    get_parallel_world_size_with_group,
    get_parallel_rank_with_group,
)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    ColumnParallelLinear,
    RowParallelLinear,
)

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import set_is_gated

logger = init_logger(__name__)


class FeedForward(torch.nn.Module):
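    """Feed-forward (MLP) block for vLLM-MLU models.

    Builds a column-parallel up projection (a merged gate+up projection when
    ``is_gated`` is True) followed by a row-parallel down projection, and
    dispatches between a fused ``mlu_ops`` matmul/activation path and the
    per-layer linear modules (used for quantized and LoRA layers).
    """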

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        up_proj_name: str,
        is_gated: bool,
        down_proj_name: str,
        bias: bool,
        quant_config: QuantizationConfig | None = None,
        skip_bias_add: bool = False,
        reduce_results: bool = True,
        prefix: str = "",
        tp_group: Any = None,
        keep_full_weights: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.is_gated = is_gated
        self.bias = bias
        self.up_proj_name = up_proj_name
        self.down_proj_name = down_proj_name
        self.quant_config = quant_config
        self.is_initialized = False
        self.skip_bias_add = skip_bias_add
        self.reduce_results = reduce_results
        self.use_bt_ffn = True
        set_is_gated(self.is_gated)

        # Modify tp_size, tp_rank and tp_group when data parallel is enabled.
        self.tp_size = get_parallel_world_size_with_group(tp_group)
        self.tp_rank = get_parallel_rank_with_group(tp_group)
        self.tp_group = tp_group
        self.keep_full_weights = keep_full_weights
        if self.keep_full_weights:
            self.tp_size = 1
            self.tp_rank = 0
            self.tp_group = None

        # Up projection, with or without a gate.
        if self.is_gated:
            up_proj = MergedColumnParallelLinear(hidden_size,
                                                 [intermediate_size] * 2,
                                                 bias=bias,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.{up_proj_name}",
                                                 tp_group=self.tp_group,
                                                 keep_full_weights=keep_full_weights)
        else:
            up_proj = ColumnParallelLinear(hidden_size,
                                           intermediate_size,
                                           bias=bias,
                                           skip_bias_add=skip_bias_add,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.{up_proj_name}",
                                           tp_group=self.tp_group,
                                           keep_full_weights=keep_full_weights)
        self.register_module(up_proj_name, up_proj)

        # Down projection.
        down_proj = RowParallelLinear(intermediate_size,
                                      hidden_size,
                                      bias=bias,
                                      skip_bias_add=skip_bias_add,
                                      reduce_results=reduce_results,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.{down_proj_name}",
                                      tp_group=self.tp_group,
                                      keep_full_weights=keep_full_weights)

        self.register_module(down_proj_name, down_proj)

    def prepare_weight(self):
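        """Lazily set the alpha/beta scalars consumed by the fused
        ``mlu_ops.matmul`` calls; runs only on the first forward pass."""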
        if not self.is_initialized:
            # alpha is 1.0 and beta is 0.0 because no residual needs to be
            # fused here for now.
            self.alpha = 1.0
            self.beta = 0.0
            # Set these here once to avoid the overhead of recomputing them
            # on every forward pass.
            self.is_initialized = True

    def _forward(self, hidden_states):
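        """Pure-PyTorch reference path: up projection, (optionally gated)
        activation, down projection, then tensor-parallel all-reduce. In the
        gated case the first half of fc1 goes through the activation and is
        multiplied elementwise by the second half."""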
        self.prepare_weight()
        up_proj = getattr(self, self.up_proj_name)
        down_proj = getattr(self, self.down_proj_name)
        act_dict = {
            "relu": F.relu,
            "gelu": F.gelu,
            "silu": F.silu,
        }
        fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias)
        if self.is_gated:
            d = fc1.shape[-1] // 2
            fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:]
        else:
            fc1 = act_dict[self.hidden_act](fc1)
        fc2 = F.linear(fc1, down_proj.weight, bias=None)
        fc2 = tensor_model_parallel_all_reduce(fc2)
        if not self.skip_bias_add:
            fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2
        return fc2

    def forward_naive(
        self,
        hidden_states,
        residual: torch.Tensor | None = None,
        smooth_quant_scale: torch.Tensor | None = None
    ):
        '''
        Used by quant_tools.
        '''
        assert self.quant_config is None, "ffn naive forward doesn't support quantization"
        assert smooth_quant_scale is None, "ffn naive forward doesn't support smooth_quant_scale"

        up_proj = getattr(self, self.up_proj_name)
        down_proj = getattr(self, self.down_proj_name)
        residual_ = None if self.tp_rank > 0 else residual
        fc1, bias = up_proj(hidden_states)
        if bias is not None:
            fc1 += bias
        fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
        out, bias = down_proj(fc1, residual=residual_)

        if self.skip_bias_add:
            return out, bias
        return out

    def forward(
        self,
        hidden_states,
        residual: torch.Tensor | None = None,
        smooth_quant_scale: torch.Tensor | None = None,
        use_tp_weight: bool = False,
        output: torch.Tensor | None = None,
    ):
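        """Main entry point. When the layer is unquantized and not wrapped by
        LoRA, uses fused ``mlu_ops`` matmul and activation kernels; otherwise
        goes through the parallel linear modules, with an optional fused
        per-token SmoothQuant step. Returns ``(out, bias)`` when
        ``skip_bias_add`` is set, otherwise just ``out``."""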
        self.prepare_weight()

        if self.use_bt_ffn is False:
            return self.forward_naive(hidden_states, residual, None)

        up_proj = getattr(self, self.up_proj_name)
        down_proj = getattr(self, self.down_proj_name)
        residual_ = None if self.tp_rank > 0 else residual
        if (self.quant_config is None and not isinstance(up_proj, BaseLayerWithLoRA)
                and not isinstance(down_proj, BaseLayerWithLoRA)):
            # The matmul computes:
            #   mul_out = alpha * (matmul(input, filter, transpose_b=True) + bias) + beta * residual
            #   output = active(mul_out)
            # Note: we cannot fuse the activation into the matmul because it
            # does not support the gated operation; we might support it in
            # the tmo matmul in the future.
            up_proj_weight = up_proj.weight
            down_proj_weight = down_proj.weight
            if self.keep_full_weights and use_tp_weight:
                up_proj_weight = up_proj.tp_weight
                down_proj_weight = down_proj.tp_weight
            fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj_weight, up_proj.bias,
                                 None, 'none', self.alpha, self.beta)
            act_out = mlu_ops.active(fc1.float(), self.hidden_act, self.is_gated).to(dtype=fc1.dtype)
            beta = 0.0
            if residual_ is not None:
                beta = 1.0
                residual_ = residual_.view(-1, residual_.shape[-1])
            out_ = mlu_ops.matmul(act_out, down_proj_weight, None, residual_, 'none', self.alpha, beta)
            # If a bias exists, it is added after the second matmul,
            # following the original design of vLLM.
            if self.reduce_results:
                out = tensor_model_parallel_all_reduce(out_, self.tp_group)
            else:
                out = out_
            # Do the bias add if needed.
            if not self.skip_bias_add:
                out = out + down_proj.bias if down_proj.bias is not None else out
            else:
                return out, down_proj.bias
        else:
            fc1, bias = up_proj(hidden_states, smooth_quant_scale=smooth_quant_scale, use_tp_weight=use_tp_weight)
            if bias is not None:
                fc1 += bias
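            # For per-token SmoothQuant (non-FP8), fuse the activation with
            # per-token quantization against down_proj's smooth scale, so the
            # down projection can skip re-quantizing its input.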
            input_scale = None
            if (self.quant_config is not None and self.quant_config.get_name() == "SmoothQuant" and
                    self.quant_config.input_quant_method == "per_token" and not self.quant_config.is_fp8):
                down_proj.quant_method.skip_quant_input = True
                down_proj_smooth = down_proj.smooth
                if self.keep_full_weights and use_tp_weight:
                    assert down_proj.tp_smooth is not None, "tp_smooth is not initialized"
                    down_proj_smooth = down_proj.tp_smooth
                fc1, input_scale = mlu_ops.per_token_smooth_quantize(
                    fc1, down_proj_smooth, None, None, act_mode=self.hidden_act, is_gated=self.is_gated)
            else:
                fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
            out, bias = down_proj(
                fc1, residual=residual_, smooth_quant_scale=input_scale,
                use_tp_weight=use_tp_weight, output=output)

        if self.skip_bias_add:
            return out, bias
        return out
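

# Illustrative usage sketch (an assumption, not part of the original module):
# a model's MLP layer might construct a gated SiLU FFN roughly as below. The
# submodule names ("gate_up_proj"/"down_proj"), sizes, and prefix are
# hypothetical placeholders chosen only for illustration.
#
#     ffn = FeedForward(hidden_size=4096,
#                       intermediate_size=11008,
#                       hidden_act="silu",
#                       up_proj_name="gate_up_proj",
#                       is_gated=True,
#                       down_proj_name="down_proj",
#                       bias=False,
#                       quant_config=None,
#                       prefix="model.layers.0.mlp")
#     hidden_states = ffn(hidden_states)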