/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "common/utils.h"
#include "torch_ops_api.h"

namespace tmo {
namespace torch_api {

/*
torch_ffn_api API function in pytorch op is:

    class FeedForward(torch.nn.Module):
        def __init__(self, hidden_size: int, inner_size: int, act_mode: str,
                     bias = False, gated = False):
            super(FeedForward, self).__init__()
            self.up_linear = torch.nn.Linear(hidden_size, inner_size, bias)
            self.gated = gated
            if self.gated:
                self.gated_linear = torch.nn.Linear(hidden_size, inner_size, bias)
            self.down_linear = torch.nn.Linear(inner_size, hidden_size, bias)
            self.act = act_mode_dict[act_mode]

        def forward(self, x):
            act_out = self.act(self.up_linear(x).float()).to(x.dtype)
            return self.down_linear(act_out * self.gated_linear(x)) \
                if self.gated else self.down_linear(act_out)

Demo of pytorch python module in TGI:

    class LlamaBtMLP(nn.Module):
        def __init__(self, prefix, config, weights):
            super().__init__()
            self.act = config.hidden_act
            self.gate_proj_weight = weights.get_multi_weights_col(
                f"{prefix}.gate_proj", quantize=config.quantize, dim=dim
            )
            self.gate_proj_bias = None
            self.up_proj_weight = weights.get_multi_weights_col(
                f"{prefix}.up_proj", quantize=config.quantize, dim=dim
            )
            self.up_proj_bias = None
            self.down_proj_weight = weights.get_multi_weights_col(
                f"{prefix}.down_proj", quantize=config.quantize, dim=dim
            )
self.down_proj_bias = None def forward(self, hidden_states): return tmo.ffn(hidden_states, self.up_proj_weight, self.up_proj_bias, self.down_proj_weight, self.down_proj_bias, self.gate_proj_weight, self.gate_proj_bias, self.act) */ // act_mode now only support silu, gelu and relu. // std::string pytorch func // silu --> nn.SiLU // gelu --> nn.functional.gelu // relu --> nn.ReLU // Maybe aligned with transformer act mode later. // https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L200 // Dimensions of ffn . // Dimension of input: [batch_num, seq_len, hidden_size] // Dimension of up_fc_filters: [filter_size, hidden_size] // Dimension of up_fc_bias: [filter_size] // Dimension of down_fc_filters: [hidden_size, filter_size] // Dimension of down_fc_bias: [hidden_size] // Dimension of gated_fc_filters: [filter_size, hidden_size] only for gated fnn // Dimension of gated_fc_bias: [filter_size] only for gated fnn // Dimension of layer norm weight: [seq_len, hidden_size] only for layer norm fused // Dimension of layer norm bias: [seq_len, hidden_size] only for layer norm fused // Dimension of output is same as input. // Dimension of \b output: [batch_num, seq_len, hidden_size] // Fused layer norm is not support now. And only support fused // per layer norm later. at::Tensor ffn(const at::Tensor &input, const at::Tensor &up_fc_weight, const c10::optional &up_fc_bias, const at::Tensor &down_proj_weight, const c10::optional &down_proj_bias, const c10::optional &gate_up_proj_weight, const c10::optional &gate_up_proj_bias, const c10::optional &layernorm_weight, const c10::optional &layernorm_bias, const std::string &act_mode, const std::string &residual_is, double eps, double alpha, double beta) { // Check tensor type and tensor device. checkTensorSameAttr(input, up_fc_weight, up_fc_bias, down_proj_weight, down_proj_bias, gate_up_proj_weight, gate_up_proj_bias, layernorm_weight, layernorm_bias); // Check dims. 
const int64_t nDim = input.dim(); at::Tensor input_view = input; if (nDim == 2) input_view = input_view.unsqueeze(0); // Check contiguous. CHECK_TENSOR_CONTIGUOUS(input_view) const torch_mlu::mlu::MLUGuard device_guard(input_view.device()); at::Tensor output = at::empty_like(input_view); // Convert torch tensor to tensor desc auto descs = createTensorDescs( {input_view, up_fc_weight, up_fc_bias.value_or(torch::Tensor()), down_proj_weight, down_proj_bias.value_or(torch::Tensor()), gate_up_proj_weight.value_or(torch::Tensor()), gate_up_proj_bias.value_or(torch::Tensor()), layernorm_weight.value_or(torch::Tensor()), layernorm_bias.value_or(torch::Tensor()), output}); bool has_ln = layernorm_weight.has_value(); TORCH_CHECK(residual_is == "input" || residual_is == "normed_input" || residual_is == "none", "residual_is must be 'input' or 'normed_input' or 'none'.") bool has_residual = residual_is != "none"; auto ln_res_mode = tmo::lnres::makeLnresEnum(has_ln, has_residual, residual_is == "input"); auto compute_type = getCnnlDataType(input_view.scalar_type()); tmo::op_desc::FeedForwardDesc ffn_desc(ln_res_mode, act_mode, compute_type, eps, alpha, beta); if (gate_up_proj_weight.has_value() && gate_up_proj_weight->defined()) { CNNL_CHECK_FATAL(cnnlSetTransformerFeedForwardDescriptorGateFiltersBias( ffn_desc, descs[5].get(), getAtTensorPtr(gate_up_proj_weight), descs[6].get(), getAtTensorPtr(gate_up_proj_bias))); } // Get current handle. cnnlHandle_t handle = torch_mlu::getCurrentHandle(); // Get workspace size and malloc workspace. 
size_t workspace_size = 0; CNNL_CHECK_FATAL(cnnlGetTransformerFeedForwardWorkspaceSize_v2( handle, ffn_desc, descs[0].get(), descs[1].get(), nullptr, ffn_desc, &workspace_size)); auto workspace = at::empty({static_cast(workspace_size)}, input.options().dtype(at::kByte)); // forward CNNL_CHECK_FATAL(cnnlTransformerFeedForward( handle, ffn_desc, ffn_desc, nullptr, descs[0].get(), getAtTensorPtr(input_view), descs[1].get(), getAtTensorPtr(up_fc_weight), descs[2].get(), getAtTensorPtr(up_fc_bias), descs[3].get(), getAtTensorPtr(down_proj_weight), descs[4].get(), getAtTensorPtr(down_proj_bias), descs[7].get(), getAtTensorPtr(layernorm_weight), descs[8].get(), getAtTensorPtr(layernorm_bias), getAtTensorPtr(workspace), workspace_size, descs[9].get(), getAtTensorPtr(output))); return nDim == 2 ? output.squeeze_(0) : output; } } // namespace torch_api } // namespace tmo