130 lines
7.3 KiB
Plaintext
130 lines
7.3 KiB
Plaintext
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
|
|
#ifndef CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_
|
|
#define CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_
|
|
|
|
#include "cnnl.h"
|
|
#include "kernel_utils.h"
|
|
namespace tmo {
|
|
/**
|
|
* @brief Apply rotary embedding.
|
|
* @param queue: The queue for mlu.
|
|
* @param output: Output. Pointer to the MLU memory that stores the output,
|
|
* the shape must be [total_seq_len, head_num, head_size]
|
|
* @param input: Input. Pointer to the MLU memory that stores the input
|
|
* the shape must be [total_seq_len, head_num, head_size].
|
|
* @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be
|
|
* continous. If dynamic_ntk is true, the shape must be [batch, rotary_seq_len, rotary_dim]. If
|
|
* dynamic_ntk is false, the shape must be [rotary_seq_len, rotary_dim].
|
|
* @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be
|
|
* continous. If dynamic_ntk is true, the shape must be [batch, rotary_seq_len, rotary_dim]. If
|
|
* dynamic_ntk is false, the shape must be [rotary_seq_len, rotary_dim].
|
|
* @param seq_offsets: Input. Pointer to the MLU memory that stores the sequene offsets of each
|
|
* batch. If discrete is true, the shape must be [total_seq_len]. If discrete is false, the shape
|
|
* must be [batch]. Seq_offsets could be nullptr if discrete is false, which means no offset for
|
|
* each batch.
|
|
* @param cu_seq_lens: Input. Pointer to the MLU memory that stores the cumulative sequence length
|
|
* of each batch. The shape must be [batch + 1]. If cu_seq_lens is nullptr, Sequence length of all
|
|
* batches is max_seq_Len.
|
|
* @param batch: Batch size.
|
|
* @param max_seq_len: The maximum sequence length of input.
|
|
* @param head_num: Head number.
|
|
* @param head_size: Head size.
|
|
* @param rotary_seq_len: The rotary seq_len of sin_table and cos_table.
|
|
* @param rotary_dim: The rotary dimension of sin_table and cos_table.
|
|
* @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table.
|
|
* @param input_seq_stride: The stride of total_seq_len in input.
|
|
* @param input_head_stride: The stride of head_num in input.
|
|
* @param output_seq_stride: The stride of total_seq_len in output.
|
|
* @param output_head_stride: The stride of head_num in output.
|
|
* @param interleaved: A boolean value indicates compute mode of rotary embedding.
|
|
* @param discrete: A boolean value indicates whether all input tokens have offsets.
|
|
* @param dynamic_ntk: A boolean value indicates whether all batches have different sin_table and
|
|
* cos_table.
|
|
* @param data_type: Data type of all inputs and outputs.
|
|
*/
|
|
KernelStatus invokeRotaryEmbedding(cnrtQueue_t queue,
|
|
void *output,
|
|
const void *input,
|
|
const void *sin_table,
|
|
const void *cos_table,
|
|
const int *seq_offsets,
|
|
const int *cu_seq_lens,
|
|
int batch,
|
|
int max_seq_len,
|
|
int head_num,
|
|
int head_size,
|
|
int rotary_seq_len,
|
|
int rotary_dim,
|
|
int rotary_stride,
|
|
size_t input_seq_stride,
|
|
size_t input_head_stride,
|
|
size_t output_seq_stride,
|
|
size_t output_head_stride,
|
|
bool interleaved,
|
|
bool discrete,
|
|
bool dynamic_ntk,
|
|
cnnlDataType_t data_type);
|
|
|
|
/**
|
|
* @brief Apply rotary embedding.
|
|
* @param queue: The queue for mlu.
|
|
* @param output: Output. Pointer to the MLU memory that stores the output,
|
|
* the shape must be [total_seq_len, head_num, head_size]
|
|
* @param input: Input. Pointer to the MLU memory that stores the input
|
|
* the shape must be [total_seq_len, head_num, head_size].
|
|
* @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be
|
|
* continous. The shape must be [rotary_seq_len, head_size / 2].
|
|
* @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be
|
|
* continous. The shape must be [rotary_seq_len, head_size / 2].
|
|
* @param seq_offsets: Input. Pointer to the MLU memory that stores the sequene offsets of each
|
|
* batch. The Shape must be [2, total_seq_len].
|
|
* @param cu_seq_lens: Input. Pointer to the MLU memory that stores the cumulative sequence length
|
|
* of each batch. The shape must be [batch + 1]. If cu_seq_lens is nullptr, Sequence length of all
|
|
* batches is max_seq_Len.
|
|
* @param batch: Batch size.
|
|
* @param max_seq_len: The maximum sequence length of input.
|
|
* @param head_num: Head number.
|
|
* @param head_size: Head size.
|
|
* @param rotary_seq_len: The rotary seq_len of sin_table and cos_table.
|
|
* @param rotary_stride: The stride of rotary_seq_len stride in sin_table and cos_table.
|
|
* @param input_seq_stride: The stride of total_seq_len in input.
|
|
* @param input_head_stride: The stride of head_num in input.
|
|
* @param output_seq_stride: The stride of total_seq_len in output.
|
|
* @param output_head_stride: The stride of head_num in output.
|
|
* @param interleaved: A boolean value indicates compute mode of rotary embedding.
|
|
* @param data_type: Data type of all inputs and outputs.
|
|
*/
|
|
KernelStatus invokeGlm6BRotaryEmbedding(cnrtQueue_t queue,
|
|
void *output,
|
|
const void *input,
|
|
const void *sin_table,
|
|
const void *cos_table,
|
|
const int *seq_offsets,
|
|
const int *cu_seq_lens,
|
|
int batch,
|
|
int max_seq_len,
|
|
int total_seq_len,
|
|
int head_num,
|
|
int head_size,
|
|
int rotary_seq_len,
|
|
int rotary_stride,
|
|
size_t input_seq_stride,
|
|
size_t input_head_stride,
|
|
size_t output_seq_stride,
|
|
size_t output_head_stride,
|
|
bool interleaved,
|
|
cnnlDataType_t data_type);
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_
|