/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #ifndef CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_ #define CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_ #include "cnnl.h" #include "kernel_utils.h" namespace tmo { /** * @brief Apply rotary embedding. * @param queue: The queue for mlu. * @param output: Output. Pointer to the MLU memory that stores the output, * the shape must be [total_seq_len, head_num, head_size] * @param input: Input. Pointer to the MLU memory that stores the input * the shape must be [total_seq_len, head_num, head_size]. * @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be * continous. If dynamic_ntk is true, the shape must be [batch, rotary_seq_len, rotary_dim]. If * dynamic_ntk is false, the shape must be [rotary_seq_len, rotary_dim]. * @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be * continous. If dynamic_ntk is true, the shape must be [batch, rotary_seq_len, rotary_dim]. If * dynamic_ntk is false, the shape must be [rotary_seq_len, rotary_dim]. * @param seq_offsets: Input. Pointer to the MLU memory that stores the sequene offsets of each * batch. If discrete is true, the shape must be [total_seq_len]. If discrete is false, the shape * must be [batch]. Seq_offsets could be nullptr if discrete is false, which means no offset for * each batch. * @param cu_seq_lens: Input. Pointer to the MLU memory that stores the cumulative sequence length * of each batch. The shape must be [batch + 1]. If cu_seq_lens is nullptr, Sequence length of all * batches is max_seq_Len. * @param batch: Batch size. * @param max_seq_len: The maximum sequence length of input. * @param head_num: Head number. * @param head_size: Head size. * @param rotary_seq_len: The rotary seq_len of sin_table and cos_table. * @param rotary_dim: The rotary dimension of sin_table and cos_table. * @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table. * @param input_seq_stride: The stride of total_seq_len in input. * @param input_head_stride: The stride of head_num in input. * @param output_seq_stride: The stride of total_seq_len in output. * @param output_head_stride: The stride of head_num in output. * @param interleaved: A boolean value indicates compute mode of rotary embedding. * @param discrete: A boolean value indicates whether all input tokens have offsets. * @param dynamic_ntk: A boolean value indicates whether all batches have different sin_table and * cos_table. * @param data_type: Data type of all inputs and outputs. */ KernelStatus invokeRotaryEmbedding(cnrtQueue_t queue, void *output, const void *input, const void *sin_table, const void *cos_table, const int *seq_offsets, const int *cu_seq_lens, int batch, int max_seq_len, int head_num, int head_size, int rotary_seq_len, int rotary_dim, int rotary_stride, size_t input_seq_stride, size_t input_head_stride, size_t output_seq_stride, size_t output_head_stride, bool interleaved, bool discrete, bool dynamic_ntk, cnnlDataType_t data_type); /** * @brief Apply rotary embedding. * @param queue: The queue for mlu. * @param output: Output. Pointer to the MLU memory that stores the output, * the shape must be [total_seq_len, head_num, head_size] * @param input: Input. Pointer to the MLU memory that stores the input * the shape must be [total_seq_len, head_num, head_size]. * @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be * continous. The shape must be [rotary_seq_len, head_size / 2]. * @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be * continous. The shape must be [rotary_seq_len, head_size / 2]. * @param seq_offsets: Input. Pointer to the MLU memory that stores the sequene offsets of each * batch. The Shape must be [2, total_seq_len]. * @param cu_seq_lens: Input. Pointer to the MLU memory that stores the cumulative sequence length * of each batch. The shape must be [batch + 1]. If cu_seq_lens is nullptr, Sequence length of all * batches is max_seq_Len. * @param batch: Batch size. * @param max_seq_len: The maximum sequence length of input. * @param head_num: Head number. * @param head_size: Head size. * @param rotary_seq_len: The rotary seq_len of sin_table and cos_table. * @param rotary_stride: The stride of rotary_seq_len stride in sin_table and cos_table. * @param input_seq_stride: The stride of total_seq_len in input. * @param input_head_stride: The stride of head_num in input. * @param output_seq_stride: The stride of total_seq_len in output. * @param output_head_stride: The stride of head_num in output. * @param interleaved: A boolean value indicates compute mode of rotary embedding. * @param data_type: Data type of all inputs and outputs. */ KernelStatus invokeGlm6BRotaryEmbedding(cnrtQueue_t queue, void *output, const void *input, const void *sin_table, const void *cos_table, const int *seq_offsets, const int *cu_seq_lens, int batch, int max_seq_len, int total_seq_len, int head_num, int head_size, int rotary_seq_len, int rotary_stride, size_t input_seq_stride, size_t input_head_stride, size_t output_seq_stride, size_t output_head_stride, bool interleaved, cnnlDataType_t data_type); } // namespace tmo #endif // CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_