# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch
import torch.nn.functional as F
from typing import Any

from vllm.distributed import (
    get_parallel_world_size_with_group,
    get_parallel_rank_with_group,
)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    ColumnParallelLinear,
    RowParallelLinear
)

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import set_is_gated

logger = init_logger(__name__)


class FeedForward(torch.nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        up_proj_name: str,
        is_gated: bool,
        down_proj_name: str,
        bias: bool,
        quant_config: QuantizationConfig | None = None,
        skip_bias_add: bool = False,
        reduce_results: bool = True,
        prefix: str = "",
        tp_group: Any = None,
        keep_full_weights: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.is_gated = is_gated
        self.bias = bias
        self.up_proj_name = up_proj_name
        self.down_proj_name = down_proj_name
        self.quant_config = quant_config
        self.is_initialized = False
        self.skip_bias_add = skip_bias_add
        self.reduce_results = reduce_results
        # Use the fused MLU FFN path by default; forward_naive is the fallback.
        self.use_bt_ffn = True
        set_is_gated(self.is_gated)

        # Override tp_size, tp_rank and tp_group when data parallel is enabled.
        self.tp_size = get_parallel_world_size_with_group(tp_group)
        self.tp_rank = get_parallel_rank_with_group(tp_group)
        self.tp_group = tp_group
        self.keep_full_weights = keep_full_weights
        if self.keep_full_weights:
            self.tp_size = 1
            self.tp_rank = 0
            self.tp_group = None

        # up_proj, with or without a fused gate projection.
        if self.is_gated:
            up_proj = MergedColumnParallelLinear(
                hidden_size,
                [intermediate_size] * 2,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.{up_proj_name}",
                tp_group=self.tp_group,
                keep_full_weights=keep_full_weights)
        else:
            up_proj = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=bias,
                skip_bias_add=skip_bias_add,
                quant_config=quant_config,
                prefix=f"{prefix}.{up_proj_name}",
                tp_group=self.tp_group,
                keep_full_weights=keep_full_weights)
        self.register_module(up_proj_name, up_proj)

        # down_proj
        down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=bias,
            skip_bias_add=skip_bias_add,
            reduce_results=reduce_results,
            quant_config=quant_config,
            prefix=f"{prefix}.{down_proj_name}",
            tp_group=self.tp_group,
            keep_full_weights=keep_full_weights)
        self.register_module(down_proj_name, down_proj)

    def prepare_weight(self):
        if not self.is_initialized:
            # alpha and beta are 1.0 and 0.0 respectively because the residual
            # is not fused here for now.
            self.alpha = 1.0
            self.beta = 0.0
            # Set once here to avoid redoing this work on every forward pass.
            self.is_initialized = True

    def _forward(self, hidden_states):
        self.prepare_weight()
        up_proj = getattr(self, self.up_proj_name)
        down_proj = getattr(self, self.down_proj_name)
        act_dict = {
            "relu": F.relu,
            "gelu": F.gelu,
            "silu": F.silu,
        }
        fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias)
        if self.is_gated:
            d = fc1.shape[-1] // 2
            fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:]
        else:
            fc1 = act_dict[self.hidden_act](fc1)
        fc2 = F.linear(fc1, down_proj.weight, bias=None)
        fc2 = tensor_model_parallel_all_reduce(fc2)
        if not self.skip_bias_add:
            fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2
        return fc2

    def forward_naive(
        self,
        hidden_states,
        residual: torch.Tensor | None = None,
        smooth_quant_scale: torch.Tensor | None = None
    ):
        '''
        Used by quant_tools.
        '''
        assert self.quant_config is None, \
            "ffn naive forward doesn't support quantization"
        assert smooth_quant_scale is None, \
            "ffn naive forward doesn't support smooth_quant_scale"
        up_proj = getattr(self, self.up_proj_name)
        down_proj = getattr(self, self.down_proj_name)
        # Only rank 0 adds the residual so that it is counted once after the
        # all-reduce in down_proj.
        residual_ = None if self.tp_rank > 0 else residual
        fc1, bias = up_proj(hidden_states)
        if bias is not None:
            fc1 += bias
        fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
        out, bias = down_proj(fc1, residual=residual_)
        if self.skip_bias_add:
            return out, bias
        return out

    def forward(
        self,
        hidden_states,
        residual: torch.Tensor | None = None,
        smooth_quant_scale: torch.Tensor | None = None,
        use_tp_weight: bool = False,
        output: torch.Tensor | None = None,
    ):
        self.prepare_weight()
        if not self.use_bt_ffn:
            return self.forward_naive(hidden_states, residual, None)

        up_proj = getattr(self, self.up_proj_name)
        down_proj = getattr(self, self.down_proj_name)
        residual_ = None if self.tp_rank > 0 else residual
        if (self.quant_config is None
                and not isinstance(up_proj, BaseLayerWithLoRA)
                and not isinstance(down_proj, BaseLayerWithLoRA)):
            # The fused matmul computes:
            #   mul_out = alpha * (matmul(input, filter, transpose_b=True) + bias) + beta * residual
            #   output = active(mul_out)
            # Note: the activation fused into matmul cannot be used here
            # because it does not support the gated operation; we might
            # support it in tmo matmul in the future.
            up_proj_weight = up_proj.weight
            down_proj_weight = down_proj.weight
            if self.keep_full_weights and use_tp_weight:
                up_proj_weight = up_proj.tp_weight
                down_proj_weight = down_proj.tp_weight
            fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size),
                                 up_proj_weight, up_proj.bias, None, 'none',
                                 self.alpha, self.beta)
            act_out = mlu_ops.active(fc1.float(), self.hidden_act,
                                     self.is_gated).to(dtype=fc1.dtype)
            beta = 0.0
            if residual_ is not None:
                beta = 1.0
                residual_ = residual_.view(-1, residual_.shape[-1])
            out_ = mlu_ops.matmul(act_out, down_proj_weight, None, residual_,
                                  'none', self.alpha, beta)
            # The bias, if present, is added after the second matmul,
            # following vLLM's original design.
            if self.reduce_results:
                out = tensor_model_parallel_all_reduce(out_, self.tp_group)
            else:
                out = out_
            # do the bias add if needed
            if not self.skip_bias_add:
                out = out + down_proj.bias if down_proj.bias is not None else out
            else:
                return out, down_proj.bias
        else:
            fc1, bias = up_proj(hidden_states,
                                smooth_quant_scale=smooth_quant_scale,
                                use_tp_weight=use_tp_weight)
            if bias is not None:
                fc1 += bias
            input_scale = None
            if (self.quant_config is not None
                    and self.quant_config.get_name() == "SmoothQuant"
                    and self.quant_config.input_quant_method == "per_token"
                    and not self.quant_config.is_fp8):
                down_proj.quant_method.skip_quant_input = True
                down_proj_smooth = down_proj.smooth
                if self.keep_full_weights and use_tp_weight:
                    assert down_proj.tp_smooth is not None, \
                        "tp_smooth is not initialized"
                    down_proj_smooth = down_proj.tp_smooth
                fc1, input_scale = mlu_ops.per_token_smooth_quantize(
                    fc1, down_proj_smooth, None, None,
                    act_mode=self.hidden_act, is_gated=self.is_gated)
            else:
                fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
            out, bias = down_proj(
                fc1,
                residual=residual_,
                smooth_quant_scale=input_scale,
                use_tp_weight=use_tp_weight,
                output=output)
            if self.skip_bias_add:
                return out, bias
        return out
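

# ---------------------------------------------------------------------------
# Minimal illustrative sketch (not used by the layer above, and only run when
# this file is executed directly): the gated FFN math that
# FeedForward._forward implements, written with plain torch -- no tensor
# parallelism, quantization, LoRA, or MLU kernels. The sizes, tensor names,
# and the "silu" activation below are assumptions made only for this example.
if __name__ == "__main__":
    hidden_size, intermediate_size, num_tokens = 16, 32, 4
    x = torch.randn(num_tokens, hidden_size)

    # MergedColumnParallelLinear fuses the gate and up projections, so the
    # fused up-projection weight produces 2 * intermediate_size features:
    # the first half is activated and multiplies the second half.
    up_weight = torch.randn(2 * intermediate_size, hidden_size)
    down_weight = torch.randn(hidden_size, intermediate_size)

    fc1 = F.linear(x, up_weight)               # (num_tokens, 2 * intermediate_size)
    d = fc1.shape[-1] // 2
    act = F.silu(fc1[..., :d]) * fc1[..., d:]  # gated "silu" activation
    out = F.linear(act, down_weight)           # (num_tokens, hidden_size)

    # The MLU fast path fuses an optional residual into the second matmul as
    #   out = alpha * (act @ W_down.T + bias) + beta * residual
    # with alpha = 1.0 and beta = 1.0; without fusion this is simply:
    residual = torch.randn(num_tokens, hidden_size)
    out_with_residual = out + residual
    assert out_with_residual.shape == (num_tokens, hidden_size)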