/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/

#include "kernels/moe/add_bias_activation.mluh"

#include "torch_ops_api.h"

namespace tmo {
|
|
namespace torch_api {
|
|
|
|
void active(const torch::Tensor &input,
|
|
const torch::Tensor &output,
|
|
const c10::optional<torch::Tensor> &bias,
|
|
const c10::optional<torch::Tensor> &cusum_token_count,
|
|
const std::string &act_mode,
|
|
bool is_gated,
|
|
int64_t start_expert_id,
|
|
int64_t expert_size,
|
|
double active_coef) {
|
|
TORCH_CHECK(
|
|
act_mode == "silu" || act_mode == "gelu" || act_mode == "quick_gelu" || act_mode == "swish",
|
|
"act_mode must be 'silu', 'gelu', 'quick_gelu' or 'swish'.")
|
|
cnnlActivationMode_t act_type = act_mode == "gelu" ? CNNL_ACTIVATION_GELU : CNNL_ACTIVATION_SWISH;
|
|
if (act_mode == "quick_gelu") {
|
|
active_coef = 1.702f;
|
|
} else if (act_mode == "silu") {
|
|
active_coef = 1.0f;
|
|
}
|
|
TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2")
|
|
auto input_shape = input.sizes();
|
|
int64_t in_channel = input_shape.back();
|
|
TORCH_CHECK(in_channel > 0, "in_channel > 0")
|
|
if (is_gated) {
|
|
TORCH_CHECK(in_channel % 2 == 0, "in_channel % 2 == 0 if is_gated is true")
|
|
}
|
|
int64_t total_tokens = input.numel() / in_channel;
|
|
int64_t inner_size = is_gated ? in_channel / 2 : in_channel;
|
|
int64_t num_expert = cusum_token_count.has_value() ? (cusum_token_count.value().size(0) - 1) : 0;
|
|
const torch_mlu::mlu::MLUGuard device_guard(input.device());
|
|
int64_t output_stride = output.stride(-2);
|
|
auto data_dtype = getCnnlDataType(input.scalar_type());
|
|
auto queue = torch_mlu::getCurMLUStream();
|
|
tmo::invokeGroupAddBiasActivationKernel(
|
|
queue, getAtTensorPtr(output), getAtTensorPtr(input), getAtTensorPtr(bias),
|
|
(int *)getAtTensorPtr(cusum_token_count), num_expert, total_tokens, inner_size, output_stride,
|
|
data_dtype, is_gated, act_type, start_expert_id, expert_size, active_coef);
|
|
}
|
|
|
|
} // namespace torch_api
|
|
} // namespace tmo
|