forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
50
torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mluh
Normal file
50
torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mluh
Normal file
@@ -0,0 +1,50 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_CAST_GATING_MLUH_
|
||||
#define CSRC_KERNELS_CAST_GATING_MLUH_
|
||||
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Convert input to float32 and do gating operation.
|
||||
* @param queue: The queue for mlu.
|
||||
* @param input: Input. Pointer to the MLU memory that stores the input,
|
||||
* the shape must be [input_row, hidden_size].
|
||||
* @param filter: Input. Pointer to the MLU memory that stores the weight,
|
||||
* the shape must be [expert_num, hidden_size].
|
||||
* @param output: Output. Pointer to the MLU memory that stores the output,
|
||||
* the shape must be [input_row, expert_num].
|
||||
* @param input_row: Input. Number of rows of input and of output (first dimension of both shapes).
|
||||
* @param expert_num: Input. Number of rows of filter and number of columns of output.
|
||||
* @param hidden_size: Input. Number of columns of both input and filter.
|
||||
* @param a_dtype: Input. The data-type of input.
|
||||
* @param workspace: Input. Pointer to the MLU workspace. The note below implies NULL may be
* accepted; confirm against the implementation.
|
||||
* @param workspace_size_bytes: Input. The size of workspace in bytes.
|
||||
* @note: a_dtype must be CNNL_DTYPE_BFLOAT16 or CNNL_DTYPE_HALF.
|
||||
* expert_num must be in range [1, 128].
|
||||
* If workspace is NOT NULL, workspace_size_bytes must NOT be smaller than 16 * 1024 * 1024.
|
||||
* The data-type of filter and output must be float.
|
||||
* cast_gating only supports MLU500 device or higher.
|
||||
* @return KernelStatus. The status values are not declared in this header; presumably they
* come from ../kernel_utils.h (TODO confirm).
*/
|
||||
KernelStatus invokeCastGating(cnrtQueue_t queue,
|
||||
void *input,
|
||||
void *filter,
|
||||
void *output,
|
||||
int input_row,
|
||||
int expert_num,
|
||||
int hidden_size,
|
||||
cnnlDataType_t a_dtype,
|
||||
void *workspace,
|
||||
size_t workspace_size_bytes);
|
||||
} // namespace tmo
|
||||
#endif // CSRC_KERNELS_CAST_GATING_MLUH_
|
||||
Reference in New Issue
Block a user