Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mluh
2026-02-04 17:39:32 +08:00

120 lines
7.3 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
#define CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Apply query and key rotary embedding, key layernorm and
* quantize key and value to kv cache.
* @param queue: The queue for mlu.
* @param input: Input/Output. Pointer to the MLU memory that stores the input,
* the shape must be [batch_size, 1, head_num_q + head_num_kv * 2, head_size].
* @param key_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision key
* cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param value_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision
* value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param key_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision key
* cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param value_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision
* value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be
* contiguous. The shape must be [rotary_seq_len, rotary_dim].
* @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be
* contiguous. The shape must be [rotary_seq_len, rotary_dim].
* @param rope_offsets: Input. Pointer to the MLU memory that stores the sequence offsets of each
* batch. The shape must be [batch].
* @param gamma: Input. Pointer to the MLU memory that stores the gamma param of layernorm.
* @param beta: Input. Pointer to the MLU memory that stores the beta param of layernorm.
* @param key_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision
* key. The shape must be [head_num_kv, head_size]. If key_scale_hp is nullptr,
* that means key does not need to be quantized.
* @param value_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision
* value. The shape must be [head_num_kv, head_size]. If value_scale_hp is nullptr,
* that means value does not need to be quantized.
* @param key_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low
* precision key. The shape must be [batch_size, head_num_kv, max_decode_len, group_num] or
* [num_blocks, head_num_kv, block_size, group_num].
* @param value_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low
* precision value. The shape must be [batch_size, head_num_kv, max_decode_len, group_num] or
* [num_blocks, head_num_kv, block_size, group_num].
* @param cache_bs_id_hp: Input. Pointer to the MLU memory that stores the batch
* offset of high precision cache, the shape must be [batch], if it's nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets_hp: Input. Pointer to the MLU memory that stores the sequence
* offset of high precision cache, the shape must be [batch].
* @param cache_bs_id_lp: Input. Pointer to the MLU memory that stores the batch
* offset of low precision cache, the shape must be [batch], if it's nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets_lp: Input. Pointer to the MLU memory that stores the sequence
* offset of low precision cache, the shape must be [batch].
* @param slot_mapping_hp: Input. Pointer to the MLU memory that stores the slot_mapping tensor
* which has shape [batch]. Data type of slot mapping must be int32_t.
* @param slot_mapping_lp: Input. Pointer to the MLU memory that stores the slot_mapping tensor
* which has shape [batch]. Data type of slot mapping must be int32_t.
* @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table.
* @param batch_size: Batch size.
* @param head_num_q: Head number of query.
* @param head_num_kv: Head number of key and value.
* @param head_size: Head size. For simplicity, the rotary dim must be the same as head_size.
* @param max_decode_len_hp: The maximum sequence length of high precision cache.
* @param max_decode_len_lp: The maximum sequence length of low precision cache.
* @param block_size_hp: Number of tokens per block of high precision cache.
* @param block_size_lp: Number of tokens per block of low precision cache.
* @param group_size: Quantization group size along head_size for the low precision cache
* scales (presumably group_num = head_size / group_size — confirm against the kernel
* implementation).
* @param dtype: Data type of all inputs and outputs.
* @param eps: float number used for layernorm.
* @note: Head_num_q and head_num_kv must be in range [1, 32].
* Head_size must be in range [1, 128], and must be divided by 2.
*/
KernelStatus invokeFusedRope(cnrtQueue_t queue,
void *input,
void *key_cache_hp,
void *value_cache_hp,
void *key_cache_lp,
void *value_cache_lp,
const void *sin_table,
const void *cos_table,
const void *rope_offsets,
const void *gamma,
const void *beta,
const void *key_scale_hp,
const void *value_scale_hp,
void *key_scale_lp,
void *value_scale_lp,
const void *cache_bs_id_hp,
const void *cache_seq_offsets_hp,
const void *cache_bs_id_lp,
const void *cache_seq_offsets_lp,
const void *slot_mapping_hp,
const void *slot_mapping_lp,
int rotary_stride,
int batch_size,
int head_num_q,
int head_num_kv,
int head_size,
int max_decode_len_hp,
int max_decode_len_lp,
int block_size_hp,
int block_size_lp,
int group_size,
cnnlDataType_t dtype,
float eps);
} // namespace tmo
#endif // CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_