enginex-mlu590-vllm/vllm_mlu/model_executor/layers/feed_forward.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import torch.nn.functional as F
from typing import Any

from vllm.distributed import (
    get_parallel_world_size_with_group,
    get_parallel_rank_with_group,
)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    ColumnParallelLinear,
    RowParallelLinear,
)

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import set_is_gated

logger = init_logger(__name__)

class FeedForward(torch.nn.Module):
    """Feed-forward (MLP) layer built on vLLM's column-/row-parallel linear
    layers, with optional gating and fused MLU kernels in the forward pass."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        up_proj_name: str,
        is_gated: bool,
        down_proj_name: str,
        bias: bool,
        quant_config: QuantizationConfig | None = None,
        skip_bias_add: bool = False,
        reduce_results: bool = True,
        prefix: str = "",
        tp_group: Any = None,
        keep_full_weights: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.is_gated = is_gated
        self.bias = bias
        self.up_proj_name = up_proj_name
        self.down_proj_name = down_proj_name
        self.quant_config = quant_config
        self.is_initialized = False
        self.skip_bias_add = skip_bias_add
        self.reduce_results = reduce_results
        self.use_bt_ffn = True
        set_is_gated(self.is_gated)
        # Override tp_size, tp_rank and tp_group when data parallelism is enabled.
        self.tp_size = get_parallel_world_size_with_group(tp_group)
        self.tp_rank = get_parallel_rank_with_group(tp_group)
        self.tp_group = tp_group
        self.keep_full_weights = keep_full_weights
        if self.keep_full_weights:
            self.tp_size = 1
            self.tp_rank = 0
            self.tp_group = None
        # up_proj, with or without a gate
        if self.is_gated:
            up_proj = MergedColumnParallelLinear(hidden_size,
                                                 [intermediate_size] * 2,
                                                 bias=bias,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.{up_proj_name}",
                                                 tp_group=self.tp_group,
                                                 keep_full_weights=keep_full_weights)
        else:
            up_proj = ColumnParallelLinear(hidden_size,
                                           intermediate_size,
                                           bias=bias,
                                           skip_bias_add=skip_bias_add,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.{up_proj_name}",
                                           tp_group=self.tp_group,
                                           keep_full_weights=keep_full_weights)
        self.register_module(up_proj_name, up_proj)
        # down_proj
        down_proj = RowParallelLinear(intermediate_size,
                                      hidden_size,
                                      bias=bias,
                                      skip_bias_add=skip_bias_add,
                                      reduce_results=reduce_results,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.{down_proj_name}",
                                      tp_group=self.tp_group,
                                      keep_full_weights=keep_full_weights)
        self.register_module(down_proj_name, down_proj)
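
    # Illustrative construction (a sketch; the hypothetical `mlp` name, the
    # sizes, and the projection names below are assumptions in the style of a
    # Llama-like gated MLP, not values taken from this repository):
    #
    #   mlp = FeedForward(hidden_size=4096,
    #                     intermediate_size=11008,
    #                     hidden_act="silu",
    #                     up_proj_name="gate_up_proj",
    #                     is_gated=True,
    #                     down_proj_name="down_proj",
    #                     bias=False,
    #                     prefix="model.layers.0.mlp")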

    def prepare_weight(self):
        if not self.is_initialized:
            # alpha and beta are 1.0 and 0.0 respectively because the residual
            # is not needed for now.
            self.alpha = 1.0
            self.beta = 0.0
            # Run this setup only once to avoid paying the cost on every
            # forward pass.
            self.is_initialized = True

    def _forward(self, hidden_states):
        """Plain PyTorch implementation of the FFN (no fused MLU ops)."""
        self.prepare_weight()
        up_proj = getattr(self, self.up_proj_name)
        down_proj = getattr(self, self.down_proj_name)
        act_dict = {
            "relu": F.relu,
            "gelu": F.gelu,
            "silu": F.silu,
        }
        fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias)
        if self.is_gated:
            d = fc1.shape[-1] // 2
            fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:]
        else:
            fc1 = act_dict[self.hidden_act](fc1)
        fc2 = F.linear(fc1, down_proj.weight, bias=None)
        fc2 = tensor_model_parallel_all_reduce(fc2)
        if not self.skip_bias_add:
            fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2
        return fc2
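
    # Shape sketch for the gated branch above, with illustrative sizes that are
    # assumptions rather than values from a real config (I is the per-rank
    # intermediate size):
    #
    #   x   = torch.randn(8, 4096)                  # [num_tokens, hidden_size]
    #   fc1 = torch.randn(8, 2 * 11008)              # up_proj output, [num_tokens, 2 * I]
    #   d   = fc1.shape[-1] // 2                     # I
    #   g   = F.silu(fc1[..., :d]) * fc1[..., d:]    # gated activation, [num_tokens, I]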

    def forward_naive(
        self,
        hidden_states,
        residual: torch.Tensor | None = None,
        smooth_quant_scale: torch.Tensor | None = None
    ):
        """Used by quant_tools."""
        assert self.quant_config is None, "ffn naive forward doesn't support quantization"
        assert smooth_quant_scale is None, "ffn naive forward doesn't support smooth_quant_scale"
        up_proj = getattr(self, self.up_proj_name)
        down_proj = getattr(self, self.down_proj_name)
        residual_ = None if self.tp_rank > 0 else residual
        fc1, bias = up_proj(hidden_states)
        if bias is not None:
            fc1 += bias
        fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
        out, bias = down_proj(fc1, residual=residual_)
        if self.skip_bias_add:
            return out, bias
        return out

    def forward(
        self,
        hidden_states,
        residual: torch.Tensor | None = None,
        smooth_quant_scale: torch.Tensor | None = None,
        use_tp_weight: bool = False,
        output: torch.Tensor | None = None,
    ):
        self.prepare_weight()
        if self.use_bt_ffn is False:
            return self.forward_naive(hidden_states, residual, None)
        up_proj = getattr(self, self.up_proj_name)
        down_proj = getattr(self, self.down_proj_name)
        residual_ = None if self.tp_rank > 0 else residual
        if (self.quant_config is None and not isinstance(up_proj, BaseLayerWithLoRA)
                and not isinstance(down_proj, BaseLayerWithLoRA)):
            # The matmul computes:
            #   mul_out = alpha * (matmul(input, filter, transpose_b=True) + bias) + beta * residual
            #   output = active(mul_out)
            # Note: we cannot fuse the activation into the matmul because it does
            # not support the gated operation; this may be supported in tmo matmul
            # in the future.
            up_proj_weight = up_proj.weight
            down_proj_weight = down_proj.weight
            if self.keep_full_weights and use_tp_weight:
                up_proj_weight = up_proj.tp_weight
                down_proj_weight = down_proj.tp_weight
            fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj_weight, up_proj.bias,
                                 None, 'none', self.alpha, self.beta)
            act_out = mlu_ops.active(fc1.float(), self.hidden_act, self.is_gated).to(dtype=fc1.dtype)
            beta = 0.0
            if residual_ is not None:
                beta = 1.0
                residual_ = residual_.view(-1, residual_.shape[-1])
            out_ = mlu_ops.matmul(act_out, down_proj_weight, None, residual_, 'none', self.alpha, beta)
            # The bias, if present, is added after the second matmul, following
            # the original vLLM design.
            if self.reduce_results:
                out = tensor_model_parallel_all_reduce(out_, self.tp_group)
            else:
                out = out_
            # Add the bias if needed.
            if not self.skip_bias_add:
                out = out + down_proj.bias if down_proj.bias is not None else out
            else:
                return out, down_proj.bias
        else:
            fc1, bias = up_proj(hidden_states, smooth_quant_scale=smooth_quant_scale, use_tp_weight=use_tp_weight)
            if bias is not None:
                fc1 += bias
            input_scale = None
            if (self.quant_config is not None and self.quant_config.get_name() == "SmoothQuant" and
                    self.quant_config.input_quant_method == "per_token" and not self.quant_config.is_fp8):
                down_proj.quant_method.skip_quant_input = True
                down_proj_smooth = down_proj.smooth
                if self.keep_full_weights and use_tp_weight:
                    assert down_proj.tp_smooth is not None, "tp_smooth is not initialized"
                    down_proj_smooth = down_proj.tp_smooth
                fc1, input_scale = mlu_ops.per_token_smooth_quantize(
                    fc1, down_proj_smooth, None, None, act_mode=self.hidden_act, is_gated=self.is_gated)
            else:
                fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
            out, bias = down_proj(
                fc1, residual=residual_, smooth_quant_scale=input_scale,
                use_tp_weight=use_tp_weight, output=output)
        if self.skip_bias_add:
            return out, bias
        return out
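

# Illustrative call pattern (a minimal sketch; the `mlp` instance from the
# construction example above is hypothetical, and tensor shapes / residual
# usage are assumptions, not taken from a particular model):
#
#   out = mlp(hidden_states)                     # default: returns the output tensor
#   out, bias = mlp(hidden_states)               # with skip_bias_add=True: (output, bias)
#   out = mlp(hidden_states, residual=residual)  # residual is folded into the second
#                                                # matmul via beta on tp_rank 0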