// Source file: enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/torch_api/ffn.cpp
// Snapshot: 2026-02-04 17:39:32 +08:00 (157 lines, 7.1 KiB, C++)
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common/utils.h"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
/*
torch_ffn_api API function in pytorch op is:
class FeedForward(torch.nn.Module):
def __init__(self, hidden_size: int, inner_size: int, act_mode: str,
bias = False, gated = False):
super(FeedForward, self).__init__()
self.up_linear = torch.nn.Linear(hidden_size, inner_size, bias)
self.gated = gated
if self.gated:
self.gated_linear = torch.nn.Linear(hidden_size, inner_size, bias)
self.down_linear = torch.nn.Linear(inner_size, hidden_size, bias)
self.act = act_mode_dict[act_mode]
def forward(self, x):
act_out = self.act(self.up_linear(x).float()).to(x.dtype)
return self.down_linear(act_out * self.gated_linear(x)) \
if self.gated else self.down_linear(act_out)
Demo of pytorch python module in TGI:
class LlamaBtMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.act = config.hidden_act
self.gate_proj_weight = weights.get_multi_weights_col(
f"{prefix}.gate_proj", quantize=config.quantize, dim=dim
)
self.gate_proj_bias = None
self.up_proj_weight = weights.get_multi_weights_col(
f"{prefix}.up_proj", quantize=config.quantize, dim=dim
)
self.up_proj_bias = None
self.down_proj_weight = weights.get_multi_weights_col(
f"{prefix}.down_proj", quantize=config.quantize, dim=dim
)
self.down_proj_bias = None
def forward(self, hidden_states):
return tmo.ffn(hidden_states, self.up_proj_weight, self.up_proj_bias,
self.down_proj_weight, self.down_proj_bias, self.gate_proj_weight,
self.gate_proj_bias, self.act)
*/
// act_mode now only support silu, gelu and relu.
// std::string pytorch func
// silu --> nn.SiLU
// gelu --> nn.functional.gelu
// relu --> nn.ReLU
// Maybe aligned with transformer act mode later.
// https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L200
// Dimensions of ffn.
// Dimension of input: [batch_num, seq_len, hidden_size]
// Dimension of up_fc_filters: [filter_size, hidden_size]
// Dimension of up_fc_bias: [filter_size]
// Dimension of down_fc_filters: [hidden_size, filter_size]
// Dimension of down_fc_bias: [hidden_size]
// Dimension of gated_fc_filters: [filter_size, hidden_size] only for gated ffn
// Dimension of gated_fc_bias: [filter_size] only for gated ffn
// Dimension of layer norm weight: [seq_len, hidden_size] only for layer norm fused
// Dimension of layer norm bias: [seq_len, hidden_size] only for layer norm fused
// Dimension of output is same as input.
// Dimension of \b output: [batch_num, seq_len, hidden_size]
// Fused layer norm is not supported now; only fused
// per-layer norm may be supported later.
// Feed-forward network (optionally gated, optionally layer-norm fused) executed
// through the CNNL TransformerFeedForward kernel.
//
// Parameters follow the dimension conventions documented above; optional
// tensors that are absent are passed to CNNL as undefined placeholders.
// `residual_is` selects the residual source and must be one of "input",
// "normed_input", or "none".  Returns a tensor shaped like `input`.
at::Tensor ffn(const at::Tensor &input,
               const at::Tensor &up_fc_weight,
               const c10::optional<at::Tensor> &up_fc_bias,
               const at::Tensor &down_proj_weight,
               const c10::optional<at::Tensor> &down_proj_bias,
               const c10::optional<at::Tensor> &gate_up_proj_weight,
               const c10::optional<at::Tensor> &gate_up_proj_bias,
               const c10::optional<at::Tensor> &layernorm_weight,
               const c10::optional<at::Tensor> &layernorm_bias,
               const std::string &act_mode,
               const std::string &residual_is,
               double eps,
               double alpha,
               double beta) {
  // All participating tensors must agree on dtype / device attributes.
  checkTensorSameAttr<TensorAttr::ALL>(input, up_fc_weight, up_fc_bias, down_proj_weight,
                                       down_proj_bias, gate_up_proj_weight, gate_up_proj_bias,
                                       layernorm_weight, layernorm_bias);
  // The kernel works on 3-D input; a 2-D input is viewed as [1, seq_len, hidden].
  const int64_t input_rank = input.dim();
  at::Tensor x = input;
  if (input_rank == 2) x = x.unsqueeze(0);
  CHECK_TENSOR_CONTIGUOUS(x)
  const torch_mlu::mlu::MLUGuard guard(x.device());
  at::Tensor result = at::empty_like(x);
  // Descriptor order is fixed and indexed below:
  // [0]=input [1]=up_w [2]=up_b [3]=down_w [4]=down_b
  // [5]=gate_w [6]=gate_b [7]=ln_w [8]=ln_b [9]=output
  auto td = createTensorDescs(
      {x, up_fc_weight, up_fc_bias.value_or(torch::Tensor()), down_proj_weight,
       down_proj_bias.value_or(torch::Tensor()), gate_up_proj_weight.value_or(torch::Tensor()),
       gate_up_proj_bias.value_or(torch::Tensor()), layernorm_weight.value_or(torch::Tensor()),
       layernorm_bias.value_or(torch::Tensor()), result});
  // Validate the residual mode before deriving the layernorm/residual enum.
  TORCH_CHECK(residual_is == "input" || residual_is == "normed_input" || residual_is == "none",
              "residual_is must be 'input' or 'normed_input' or 'none'.")
  const bool with_layernorm = layernorm_weight.has_value();
  const bool with_residual = residual_is != "none";
  auto lnres_mode = tmo::lnres::makeLnresEnum(with_layernorm, with_residual,
                                              residual_is == "input");
  auto cnnl_dtype = getCnnlDataType(x.scalar_type());
  tmo::op_desc::FeedForwardDesc desc(lnres_mode, act_mode, cnnl_dtype, eps, alpha, beta);
  // Gated variant: attach the gate projection weight/bias to the descriptor.
  if (gate_up_proj_weight.has_value() && gate_up_proj_weight->defined()) {
    CNNL_CHECK_FATAL(cnnlSetTransformerFeedForwardDescriptorGateFiltersBias(
        desc, td[5].get(), getAtTensorPtr(gate_up_proj_weight), td[6].get(),
        getAtTensorPtr(gate_up_proj_bias)));
  }
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  // Query and allocate the scratch workspace the kernel requires.
  // NOTE(review): `desc` is passed twice and one descriptor slot is nullptr
  // here — presumably matching the _v2 signature; confirm against CNNL docs.
  size_t ws_bytes = 0;
  CNNL_CHECK_FATAL(cnnlGetTransformerFeedForwardWorkspaceSize_v2(
      handle, desc, td[0].get(), td[1].get(), nullptr, desc, &ws_bytes));
  auto workspace =
      at::empty({static_cast<int64_t>(ws_bytes)}, input.options().dtype(at::kByte));
  // Launch the fused feed-forward kernel.
  CNNL_CHECK_FATAL(cnnlTransformerFeedForward(
      handle, desc, desc, nullptr, td[0].get(), getAtTensorPtr(x),
      td[1].get(), getAtTensorPtr(up_fc_weight), td[2].get(), getAtTensorPtr(up_fc_bias),
      td[3].get(), getAtTensorPtr(down_proj_weight), td[4].get(),
      getAtTensorPtr(down_proj_bias), td[7].get(), getAtTensorPtr(layernorm_weight),
      td[8].get(), getAtTensorPtr(layernorm_bias), getAtTensorPtr(workspace), ws_bytes,
      td[9].get(), getAtTensorPtr(result)));
  // Restore the caller's original rank for 2-D inputs.
  if (input_rank == 2) return result.squeeze_(0);
  return result;
}
} // namespace torch_api
} // namespace tmo