/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "kernels/moe/add_bias_activation.mluh"
#include "torch_ops_api.h"

namespace tmo {
namespace torch_api {

// Applies an (optionally gated) add-bias + activation to `input`, writing the
// result into `output` via the MoE group kernel.
//
// Parameters:
//   input             : activations of shape [..., in_channel]. When
//                       `is_gated` is true, the last dim holds the gate and
//                       value halves concatenated, so it must be even.
//   output            : destination tensor; must be at least 2-D because its
//                       stride(-2) is forwarded to the kernel (supports
//                       padded/strided row layouts).
//   bias              : optional bias added before the activation.
//   cusum_token_count : optional cumulative per-expert token counts; assumed
//                       to have num_expert + 1 entries (boundary offsets) —
//                       absent means a single ungrouped run (num_expert = 0).
//   act_mode          : one of "silu", "gelu", "quick_gelu", "swish".
//   is_gated          : true for gated activations (e.g. SwiGLU-style).
//   start_expert_id   : first expert id processed by this call.
//   expert_size       : number of experts processed by this call.
//   active_coef       : swish coefficient; overridden below for "quick_gelu"
//                       (1.702) and "silu" (1.0).
void active(const torch::Tensor &input, const torch::Tensor &output,
            const c10::optional<torch::Tensor> &bias,
            const c10::optional<torch::Tensor> &cusum_token_count,
            const std::string &act_mode, bool is_gated,
            int64_t start_expert_id, int64_t expert_size, double active_coef) {
  TORCH_CHECK(
      act_mode == "silu" || act_mode == "gelu" || act_mode == "quick_gelu" ||
          act_mode == "swish",
      "act_mode must be 'silu', 'gelu', 'quick_gelu' or 'swish'.");

  // "silu" and "quick_gelu" are both lowered to CNNL_ACTIVATION_SWISH with a
  // fixed coefficient: silu(x) = x*sigmoid(x) (coef 1.0) and
  // quick_gelu(x) = x*sigmoid(1.702*x) (coef 1.702).
  cnnlActivationMode_t act_type =
      act_mode == "gelu" ? CNNL_ACTIVATION_GELU : CNNL_ACTIVATION_SWISH;
  if (act_mode == "quick_gelu") {
    active_coef = 1.702f;
  } else if (act_mode == "silu") {
    active_coef = 1.0f;
  }

  TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2");
  // output.stride(-2) is read below, so output must be >= 2-D as well.
  TORCH_CHECK(output.dim() >= 2, "output.dim() >= 2");
  auto input_shape = input.sizes();
  int64_t in_channel = input_shape.back();
  TORCH_CHECK(in_channel > 0, "in_channel > 0");
  if (is_gated) {
    TORCH_CHECK(in_channel % 2 == 0, "in_channel % 2 == 0 if is_gated is true");
  }

  int64_t total_tokens = input.numel() / in_channel;
  // A gated activation consumes two input channels (gate, value) per output
  // element, halving the effective inner size.
  int64_t inner_size = is_gated ? in_channel / 2 : in_channel;
  // cusum_token_count stores cumulative boundaries, hence num_expert + 1
  // entries when present.
  int64_t num_expert = cusum_token_count.has_value()
                           ? (cusum_token_count.value().size(0) - 1)
                           : 0;

  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  int64_t output_stride = output.stride(-2);
  auto data_dtype = getCnnlDataType(input.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();
  tmo::invokeGroupAddBiasActivationKernel(
      queue, getAtTensorPtr(output), getAtTensorPtr(input),
      getAtTensorPtr(bias), (int *)getAtTensorPtr(cusum_token_count),
      num_expert, total_tokens, inner_size, output_stride, data_dtype,
      is_gated, act_type, start_expert_id, expert_size, active_coef);
}

}  // namespace torch_api
}  // namespace tmo