Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/rotary_embedding.mluh
2026-02-04 17:39:32 +08:00

130 lines
7.3 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_
#define CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Apply rotary embedding to the input and write the result to the output.
* @param queue: The MLU queue the operation is issued on.
* @param output: Output. Pointer to the MLU memory that stores the output,
* the shape must be [total_seq_len, head_num, head_size].
* @param input: Input. Pointer to the MLU memory that stores the input,
* the shape must be [total_seq_len, head_num, head_size].
* @param sin_table: Input. Pointer to the MLU memory that stores the sin values; it may be
* non-contiguous (see rotary_stride). If dynamic_ntk is true, the shape must be
* [batch, rotary_seq_len, rotary_dim]. If dynamic_ntk is false, the shape must be
* [rotary_seq_len, rotary_dim].
* @param cos_table: Input. Pointer to the MLU memory that stores the cos values; it may be
* non-contiguous (see rotary_stride). If dynamic_ntk is true, the shape must be
* [batch, rotary_seq_len, rotary_dim]. If dynamic_ntk is false, the shape must be
* [rotary_seq_len, rotary_dim].
* @param seq_offsets: Input. Pointer to the MLU memory that stores the sequence offsets.
* If discrete is true, the shape must be [total_seq_len] (one offset per token). If discrete is
* false, the shape must be [batch] (one offset per batch). seq_offsets may be nullptr if discrete
* is false, which means no offset is applied to any batch.
* @param cu_seq_lens: Input. Pointer to the MLU memory that stores the cumulative sequence length
* of each batch. The shape must be [batch + 1]. If cu_seq_lens is nullptr, the sequence length of
* every batch is max_seq_len.
* @param batch: Batch size.
* @param max_seq_len: The maximum sequence length of input.
* @param head_num: Head number.
* @param head_size: Head size.
* @param rotary_seq_len: The rotary seq_len of sin_table and cos_table.
* @param rotary_dim: The rotary dimension of sin_table and cos_table.
* @param rotary_stride: The stride of the rotary_seq_len dimension in sin_table and cos_table.
* @param input_seq_stride: The stride of the total_seq_len dimension in input.
* @param input_head_stride: The stride of the head_num dimension in input.
* @param output_seq_stride: The stride of the total_seq_len dimension in output.
* @param output_head_stride: The stride of the head_num dimension in output.
* @param interleaved: A boolean value that selects the compute mode of rotary embedding.
* @param discrete: A boolean value that indicates whether every input token has its own offset.
* @param dynamic_ntk: A boolean value that indicates whether each batch has its own sin_table and
* cos_table.
* @param data_type: Data type of all inputs and outputs.
* @return A KernelStatus value. NOTE(review): KernelStatus is not defined in this header
* (presumably it comes from kernel_utils.h) -- confirm its values there.
*/
KernelStatus invokeRotaryEmbedding(cnrtQueue_t queue,
void *output,
const void *input,
const void *sin_table,
const void *cos_table,
const int *seq_offsets,
const int *cu_seq_lens,
int batch,
int max_seq_len,
int head_num,
int head_size,
int rotary_seq_len,
int rotary_dim,
int rotary_stride,
size_t input_seq_stride,
size_t input_head_stride,
size_t output_seq_stride,
size_t output_head_stride,
bool interleaved,
bool discrete,
bool dynamic_ntk,
cnnlDataType_t data_type);
/**
* @brief Apply rotary embedding (the GLM-6B variant, judging by the function name) to the input
* and write the result to the output.
* @param queue: The MLU queue the operation is issued on.
* @param output: Output. Pointer to the MLU memory that stores the output,
* the shape must be [total_seq_len, head_num, head_size].
* @param input: Input. Pointer to the MLU memory that stores the input,
* the shape must be [total_seq_len, head_num, head_size].
* @param sin_table: Input. Pointer to the MLU memory that stores the sin values; it may be
* non-contiguous (see rotary_stride). The shape must be [rotary_seq_len, head_size / 2].
* @param cos_table: Input. Pointer to the MLU memory that stores the cos values; it may be
* non-contiguous (see rotary_stride). The shape must be [rotary_seq_len, head_size / 2].
* @param seq_offsets: Input. Pointer to the MLU memory that stores the sequence offsets.
* The shape must be [2, total_seq_len].
* @param cu_seq_lens: Input. Pointer to the MLU memory that stores the cumulative sequence length
* of each batch. The shape must be [batch + 1]. If cu_seq_lens is nullptr, the sequence length of
* every batch is max_seq_len.
* @param batch: Batch size.
* @param max_seq_len: The maximum sequence length of input.
* @param total_seq_len: The total sequence length, i.e. the leading dimension of input and output
* and the second dimension of seq_offsets. NOTE(review): this parameter was previously
* undocumented; the description is derived from the shapes listed above -- confirm against the
* kernel implementation.
* @param head_num: Head number.
* @param head_size: Head size.
* @param rotary_seq_len: The rotary seq_len of sin_table and cos_table.
* @param rotary_stride: The stride of the rotary_seq_len dimension in sin_table and cos_table.
* @param input_seq_stride: The stride of the total_seq_len dimension in input.
* @param input_head_stride: The stride of the head_num dimension in input.
* @param output_seq_stride: The stride of the total_seq_len dimension in output.
* @param output_head_stride: The stride of the head_num dimension in output.
* @param interleaved: A boolean value that selects the compute mode of rotary embedding.
* @param data_type: Data type of all inputs and outputs.
* @return A KernelStatus value. NOTE(review): KernelStatus is not defined in this header
* (presumably it comes from kernel_utils.h) -- confirm its values there.
*/
KernelStatus invokeGlm6BRotaryEmbedding(cnrtQueue_t queue,
void *output,
const void *input,
const void *sin_table,
const void *cos_table,
const int *seq_offsets,
const int *cu_seq_lens,
int batch,
int max_seq_len,
int total_seq_len,
int head_num,
int head_size,
int rotary_seq_len,
int rotary_stride,
size_t input_seq_stride,
size_t input_head_stride,
size_t output_seq_stride,
size_t output_head_stride,
bool interleaved,
cnnlDataType_t data_type);
} // namespace tmo
#endif // CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_