/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
|
|
|
|
#include "common/utils.h"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
/*
The equivalent PyTorch op for this ffn API is:

class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size: int, inner_size: int, act_mode: str,
                 bias = False, gated = False):
        super(FeedForward, self).__init__()
        self.up_linear = torch.nn.Linear(hidden_size, inner_size, bias)
        self.gated = gated
        if self.gated:
            self.gated_linear = torch.nn.Linear(hidden_size, inner_size, bias)
        self.down_linear = torch.nn.Linear(inner_size, hidden_size, bias)
        self.act = act_mode_dict[act_mode]

    def forward(self, x):
        act_out = self.act(self.up_linear(x).float()).to(x.dtype)
        return self.down_linear(act_out * self.gated_linear(x)) \
            if self.gated else self.down_linear(act_out)

Demo of a PyTorch python module in TGI:

class LlamaBtMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.act = config.hidden_act
        self.gate_proj_weight = weights.get_multi_weights_col(
            f"{prefix}.gate_proj", quantize=config.quantize, dim=dim
        )
        self.gate_proj_bias = None
        self.up_proj_weight = weights.get_multi_weights_col(
            f"{prefix}.up_proj", quantize=config.quantize, dim=dim
        )
        self.up_proj_bias = None
        self.down_proj_weight = weights.get_multi_weights_col(
            f"{prefix}.down_proj", quantize=config.quantize, dim=dim
        )
        self.down_proj_bias = None

    def forward(self, hidden_states):
        return tmo.ffn(hidden_states, self.up_proj_weight, self.up_proj_bias,
            self.down_proj_weight, self.down_proj_bias, self.gate_proj_weight,
            self.gate_proj_bias, self.act)
*/
// act_mode currently only supports silu, gelu and relu.
// std::string      pytorch func
// silu         --> nn.SiLU
// gelu         --> nn.functional.gelu
// relu         --> nn.ReLU
// May be aligned with the transformers activation table later:
// https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L200

// Dimensions of ffn.
// Dimension of input: [batch_num, seq_len, hidden_size]
// Dimension of up_fc_filters: [filter_size, hidden_size]
// Dimension of up_fc_bias: [filter_size]
// Dimension of down_fc_filters: [hidden_size, filter_size]
// Dimension of down_fc_bias: [hidden_size]
// Dimension of gated_fc_filters: [filter_size, hidden_size] only for gated ffn
// Dimension of gated_fc_bias: [filter_size] only for gated ffn
// Dimension of layer norm weight: [seq_len, hidden_size] only for fused layer norm
// Dimension of layer norm bias: [seq_len, hidden_size] only for fused layer norm
// Dimension of output is the same as input.
// Dimension of \b output: [batch_num, seq_len, hidden_size]

// Fused layer norm is not supported yet; only fused per-layer norm
// will be supported later.
// Runs a (optionally gated, optionally layernorm-fused) transformer
// feed-forward pass on MLU via the CNNL TransformerFeedForward kernel.
//
// input:               activations, [batch, seq, hidden] or [seq, hidden]
// up_fc_weight/bias:   up projection parameters
// down_proj_*:         down projection parameters
// gate_up_proj_*:      optional gate projection (gated FFN variants)
// layernorm_*:         optional fused layer-norm parameters
// act_mode:            "silu" / "gelu" / "relu"
// residual_is:         "input", "normed_input" or "none"
// eps/alpha/beta:      forwarded into the CNNL feed-forward descriptor
//
// Returns a tensor with the same shape as `input`.
at::Tensor ffn(const at::Tensor &input,
               const at::Tensor &up_fc_weight,
               const c10::optional<at::Tensor> &up_fc_bias,
               const at::Tensor &down_proj_weight,
               const c10::optional<at::Tensor> &down_proj_bias,
               const c10::optional<at::Tensor> &gate_up_proj_weight,
               const c10::optional<at::Tensor> &gate_up_proj_bias,
               const c10::optional<at::Tensor> &layernorm_weight,
               const c10::optional<at::Tensor> &layernorm_bias,
               const std::string &act_mode,
               const std::string &residual_is,
               double eps,
               double alpha,
               double beta) {
  // All tensors handed to the kernel must agree on dtype/device.
  checkTensorSameAttr<TensorAttr::ALL>(input, up_fc_weight, up_fc_bias, down_proj_weight,
                                       down_proj_bias, gate_up_proj_weight, gate_up_proj_bias,
                                       layernorm_weight, layernorm_bias);

  // Promote a 2-D [seq, hidden] input to 3-D [1, seq, hidden]; the
  // leading batch axis is squeezed back off before returning.
  const int64_t input_rank = input.dim();
  at::Tensor x = input;
  if (input_rank == 2) {
    x = x.unsqueeze(0);
  }

  // The kernel requires dense storage.
  CHECK_TENSOR_CONTIGUOUS(x)

  const torch_mlu::mlu::MLUGuard device_guard(x.device());
  at::Tensor output = at::empty_like(x);

  // Wrap every torch tensor in a CNNL tensor descriptor. The order of
  // this list is fixed: the numeric indices used below refer into it.
  auto tensor_descs = createTensorDescs(
      {x, up_fc_weight, up_fc_bias.value_or(torch::Tensor()), down_proj_weight,
       down_proj_bias.value_or(torch::Tensor()), gate_up_proj_weight.value_or(torch::Tensor()),
       gate_up_proj_bias.value_or(torch::Tensor()), layernorm_weight.value_or(torch::Tensor()),
       layernorm_bias.value_or(torch::Tensor()), output});

  const bool fuse_layernorm = layernorm_weight.has_value();
  TORCH_CHECK(residual_is == "input" || residual_is == "normed_input" || residual_is == "none",
              "residual_is must be 'input' or 'normed_input' or 'none'.");
  const bool add_residual = residual_is != "none";
  auto ln_res_mode =
      tmo::lnres::makeLnresEnum(fuse_layernorm, add_residual, residual_is == "input");
  auto compute_type = getCnnlDataType(x.scalar_type());

  tmo::op_desc::FeedForwardDesc ffn_desc(ln_res_mode, act_mode, compute_type, eps, alpha, beta);
  // Attach gate projection parameters only when a defined gate weight
  // was actually supplied.
  if (gate_up_proj_weight.has_value() && gate_up_proj_weight->defined()) {
    CNNL_CHECK_FATAL(cnnlSetTransformerFeedForwardDescriptorGateFiltersBias(
        ffn_desc, tensor_descs[5].get(), getAtTensorPtr(gate_up_proj_weight),
        tensor_descs[6].get(), getAtTensorPtr(gate_up_proj_bias)));
  }

  cnnlHandle_t handle = torch_mlu::getCurrentHandle();

  // Query the scratch-space requirement and allocate it as a byte tensor
  // on the same device as the input.
  size_t workspace_bytes = 0;
  CNNL_CHECK_FATAL(cnnlGetTransformerFeedForwardWorkspaceSize_v2(
      handle, ffn_desc, tensor_descs[0].get(), tensor_descs[1].get(), nullptr, ffn_desc,
      &workspace_bytes));
  auto workspace =
      at::empty({static_cast<int64_t>(workspace_bytes)}, input.options().dtype(at::kByte));

  // Launch the fused feed-forward kernel. NOTE(review): ffn_desc is
  // passed twice and one descriptor slot is nullptr — presumably the v2
  // API reuses the op descriptor for two roles; verify against the CNNL
  // reference if this ever changes.
  CNNL_CHECK_FATAL(cnnlTransformerFeedForward(
      handle, ffn_desc, ffn_desc, nullptr, tensor_descs[0].get(), getAtTensorPtr(x),
      tensor_descs[1].get(), getAtTensorPtr(up_fc_weight), tensor_descs[2].get(),
      getAtTensorPtr(up_fc_bias), tensor_descs[3].get(), getAtTensorPtr(down_proj_weight),
      tensor_descs[4].get(), getAtTensorPtr(down_proj_bias), tensor_descs[7].get(),
      getAtTensorPtr(layernorm_weight), tensor_descs[8].get(), getAtTensorPtr(layernorm_bias),
      getAtTensorPtr(workspace), workspace_bytes, tensor_descs[9].get(), getAtTensorPtr(output)));

  // Undo the batch-axis promotion for 2-D callers.
  return input_rank == 2 ? output.squeeze_(0) : output;
}
} // namespace torch_api
} // namespace tmo