forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
58
torch_mlu_ops-v1.3.2/csrc/torch_api/active.cpp
Normal file
58
torch_mlu_ops-v1.3.2/csrc/torch_api/active.cpp
Normal file
@@ -0,0 +1,58 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernels/moe/add_bias_activation.mluh"
|
||||
#include "torch_ops_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
// Applies an (optionally gated, optionally biased) activation to `input`,
// writing the result into `output` via the grouped MoE add-bias-activation
// kernel.
//
// Args:
//   input:  tensor with dim() >= 2; the last dimension is the channel dim.
//   output: destination tensor; its stride(-2) is forwarded to the kernel as
//           the row stride (so it is expected to have dim() >= 2 as well —
//           NOTE(review): not checked here, torch will throw on stride(-2)).
//   bias:   optional bias tensor forwarded to the kernel.
//   cusum_token_count: optional per-expert token counts; when present,
//           num_expert = size(0) - 1 (presumably a prefix-sum layout with
//           num_expert + 1 entries — TODO confirm against kernel contract).
//   act_mode: one of "silu", "gelu", "quick_gelu", "swish".
//   is_gated: when true the channel dim holds two equal halves and must be
//           even; the effective inner size becomes in_channel / 2.
//   start_expert_id, expert_size: expert-range selection, forwarded as-is.
//   active_coef: swish beta. Overridden to 1.702 for "quick_gelu" and 1.0
//           for "silu" (both are swish with a fixed beta); only honored
//           as passed for "swish".
void active(const torch::Tensor &input,
            const torch::Tensor &output,
            const c10::optional<torch::Tensor> &bias,
            const c10::optional<torch::Tensor> &cusum_token_count,
            const std::string &act_mode,
            bool is_gated,
            int64_t start_expert_id,
            int64_t expert_size,
            double active_coef) {
  TORCH_CHECK(
      act_mode == "silu" || act_mode == "gelu" || act_mode == "quick_gelu" || act_mode == "swish",
      "act_mode must be 'silu', 'gelu', 'quick_gelu' or 'swish'.");
  // Only GELU gets its dedicated CNNL mode; silu/quick_gelu/swish all map to
  // SWISH (x * sigmoid(beta * x)) and differ only in beta (active_coef).
  cnnlActivationMode_t act_type =
      act_mode == "gelu" ? CNNL_ACTIVATION_GELU : CNNL_ACTIVATION_SWISH;
  if (act_mode == "quick_gelu") {
    active_coef = 1.702f;  // quick_gelu(x) = x * sigmoid(1.702 * x)
  } else if (act_mode == "silu") {
    active_coef = 1.0f;  // silu(x) = x * sigmoid(x)
  }
  TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2");
  auto input_shape = input.sizes();
  int64_t in_channel = input_shape.back();
  TORCH_CHECK(in_channel > 0, "in_channel > 0");
  if (is_gated) {
    // A gated activation splits the channel dim into two equal halves.
    TORCH_CHECK(in_channel % 2 == 0, "in_channel % 2 == 0 if is_gated is true");
  }
  int64_t total_tokens = input.numel() / in_channel;
  int64_t inner_size = is_gated ? in_channel / 2 : in_channel;
  // size(0) - 1 implies cusum_token_count carries num_expert + 1 entries.
  int64_t num_expert =
      cusum_token_count.has_value() ? (cusum_token_count.value().size(0) - 1) : 0;
  const torch_mlu::mlu::MLUGuard device_guard(input.device());
  int64_t output_stride = output.stride(-2);
  auto data_dtype = getCnnlDataType(input.scalar_type());
  auto queue = torch_mlu::getCurMLUStream();
  tmo::invokeGroupAddBiasActivationKernel(
      queue, getAtTensorPtr(output), getAtTensorPtr(input), getAtTensorPtr(bias),
      // static_cast instead of C-style cast; assumes getAtTensorPtr yields
      // void* — NOTE(review): confirm against torch_ops_api.h.
      static_cast<int *>(getAtTensorPtr(cusum_token_count)), num_expert, total_tokens,
      inner_size, output_stride, data_dtype, is_gated, act_type, start_expert_id,
      expert_size, active_coef);
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
Reference in New Issue
Block a user