Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/torch_api/active.cpp
2026-02-04 17:39:32 +08:00

59 lines
2.6 KiB
C++

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/moe/add_bias_activation.mluh"
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
/// Applies an (optionally gated, optionally biased) activation to `input`,
/// writing results into the pre-allocated `output` tensor via the grouped
/// add-bias-activation MLU kernel.
///
/// @param input   At least 2-D; the last dimension is the channel dimension.
///                When `is_gated` is true the channel dimension must be even
///                (it is split into two halves of size in_channel / 2).
/// @param output  Pre-allocated result tensor, at least 2-D; its stride along
///                dim -2 is forwarded to the kernel as the row stride.
///                NOTE(review): the kernel receives a single dtype taken from
///                `input`; output is presumably the same dtype — confirm.
/// @param bias    Optional bias tensor handed to the kernel (nullptr if absent).
/// @param cusum_token_count  Optional 1-D cumulative token-count tensor; when
///                present, num_expert is its length minus one. Absent means
///                num_expert = 0 (single flat batch).
/// @param act_mode  One of "silu", "gelu", "quick_gelu", "swish".
/// @param is_gated  If true, the activation is applied in gated form and the
///                inner size is in_channel / 2.
/// @param start_expert_id  Index of the first expert processed by this call.
/// @param expert_size      Number of experts processed by this call.
/// @param active_coef      Swish beta coefficient; overridden to 1.702 for
///                "quick_gelu" and 1.0 for "silu", used as-is for "swish".
void active(const torch::Tensor &input,
            const torch::Tensor &output,
            const c10::optional<torch::Tensor> &bias,
            const c10::optional<torch::Tensor> &cusum_token_count,
            const std::string &act_mode,
            bool is_gated,
            int64_t start_expert_id,
            int64_t expert_size,
            double active_coef) {
  TORCH_CHECK(
      act_mode == "silu" || act_mode == "gelu" || act_mode == "quick_gelu" || act_mode == "swish",
      "act_mode must be 'silu', 'gelu', 'quick_gelu' or 'swish', got '", act_mode, "'.")
  // CNNL exposes GELU directly; silu / quick_gelu / swish are all expressed
  // as SWISH with an appropriate beta coefficient.
  cnnlActivationMode_t act_type = act_mode == "gelu" ? CNNL_ACTIVATION_GELU : CNNL_ACTIVATION_SWISH;
  if (act_mode == "quick_gelu") {
    active_coef = 1.702f;  // quick_gelu(x) = x * sigmoid(1.702 * x)
  } else if (act_mode == "silu") {
    active_coef = 1.0f;  // silu(x) = x * sigmoid(x)
  }
  TORCH_CHECK(input.dim() >= 2, "input must be at least 2-D, but got dim = ", input.dim())
  // output.stride(-2) below requires output to be at least 2-D as well.
  TORCH_CHECK(output.dim() >= 2, "output must be at least 2-D, but got dim = ", output.dim())
  auto input_shape = input.sizes();
  int64_t in_channel = input_shape.back();
  TORCH_CHECK(in_channel > 0, "input's last dimension must be positive, but got ", in_channel)
  if (is_gated) {
    TORCH_CHECK(in_channel % 2 == 0,
                "input's last dimension must be even when is_gated is true, but got ", in_channel)
  }
  int64_t total_tokens = input.numel() / in_channel;
  // Gated activations consume the channel dim in two halves.
  int64_t inner_size = is_gated ? in_channel / 2 : in_channel;
  if (cusum_token_count.has_value()) {
    // The cumulative-count layout implies a 1-D tensor of length num_expert + 1.
    TORCH_CHECK(cusum_token_count.value().dim() == 1 && cusum_token_count.value().size(0) >= 1,
                "cusum_token_count must be a non-empty 1-D tensor, but got dim = ",
                cusum_token_count.value().dim())
  }
  int64_t num_expert = cusum_token_count.has_value() ? (cusum_token_count.value().size(0) - 1) : 0;
  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  int64_t output_stride = output.stride(-2);
  auto data_dtype = getCnnlDataType(input.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();
  tmo::invokeGroupAddBiasActivationKernel(
      queue, getAtTensorPtr(output), getAtTensorPtr(input), getAtTensorPtr(bias),
      (int *)getAtTensorPtr(cusum_token_count), num_expert, total_tokens, inner_size, output_stride,
      data_dtype, is_gated, act_type, start_expert_id, expert_size, active_coef);
}
}  // namespace torch_api
}  // namespace tmo