/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
#define CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_

#include "cnnl.h"
#include "kernel_utils.h"

namespace tmo {

/**
 * @brief Apply query and key rotary embedding, key layernorm, and
 * quantize key and value into the KV cache.
 * @param queue: The queue for mlu.
 * @param input: Input/Output. Pointer to the MLU memory that stores the input,
 * the shape must be [batch_size, 1, head_num_q + head_num_kv * 2, head_size].
 * @param key_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision key
 * cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
 * head_num_kv, block_size, head_size].
 * @param value_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision
 * value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
 * head_num_kv, block_size, head_size].
 * @param key_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision key
 * cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
 * head_num_kv, block_size, head_size].
 * @param value_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision
 * value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
 * head_num_kv, block_size, head_size].
 * @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be
 * continuous. The shape must be [rotary_seq_len, rotary_dim].
 * @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be
 * continuous. The shape must be [rotary_seq_len, rotary_dim].
 * @param rope_offsets: Input. Pointer to the MLU memory that stores the sequence offsets of each
 * batch. The shape must be [batch].
 * @param gamma: Input. Pointer to the MLU memory that stores the gamma param of layernorm.
 * @param beta: Input. Pointer to the MLU memory that stores the beta param of layernorm.
 * @param key_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision
 * key. The shape must be [head_num_kv, head_size]. If key_scale_hp is nullptr,
 * that means key does not need to be quantized.
 * @param value_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision
 * value. The shape must be [head_num_kv, head_size]. If value_scale_hp is nullptr,
 * that means value does not need to be quantized.
 * @param key_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low
 * precision key. The shape must be [batch_size, head_num_kv, max_decode_len, group_num] or
 * [num_blocks, head_num_kv, block_size, group_num].
 * @param value_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low
 * precision value. The shape must be [batch_size, head_num_kv, max_decode_len, group_num] or
 * [num_blocks, head_num_kv, block_size, group_num].
 * @param cache_bs_id_hp: Input. Pointer to the MLU memory that stores the batch
 * offset of high precision cache, the shape must be [batch], if it's nullptr, the
 * default value is {0, 1, 2 ... batch - 1}.
 * @param cache_seq_offsets_hp: Input. Pointer to the MLU memory that stores the sequence
 * offset of high precision cache, the shape must be [batch].
 * @param cache_bs_id_lp: Input. Pointer to the MLU memory that stores the batch
 * offset of low precision cache, the shape must be [batch], if it's nullptr, the
 * default value is {0, 1, 2 ... batch - 1}.
 * @param cache_seq_offsets_lp: Input. Pointer to the MLU memory that stores the sequence
 * offset of low precision cache, the shape must be [batch].
 * @param slot_mapping_hp: Input. Pointer to the MLU memory that stores the slot_mapping tensor
 * which has shape [batch]. Data type of slot mapping must be int32_t.
 * @param slot_mapping_lp: Input. Pointer to the MLU memory that stores the slot_mapping tensor
 * which has shape [batch]. Data type of slot mapping must be int32_t.
 * @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table.
 * @param batch_size: Batch size.
 * @param head_num_q: Head number of query.
 * @param head_num_kv: Head number of key and value.
 * @param head_size: Head size. For simplicity, the rotary dim must be the same as head_size.
 * @param max_decode_len_hp: The maximum sequence length of high precision cache.
 * @param max_decode_len_lp: The maximum sequence length of low precision cache.
 * @param block_size_hp: Number of tokens per block of high precision cache.
 * @param block_size_lp: Number of tokens per block of low precision cache.
 * @param group_size: Number of elements sharing one quantization scale group. Presumably
 * group_num = head_size / group_size for the *_scale_lp shapes above — TODO confirm
 * against the kernel implementation.
 * @param dtype: Data type of all inputs and outputs.
 * @param eps: float number used for layernorm.
 * @note: Head_num_q and head_num_kv must be in range [1, 32].
 * Head_size must be in range [1, 128], and must be divisible by 2.
 */
KernelStatus invokeFusedRope(cnrtQueue_t queue,
                             void *input,
                             void *key_cache_hp,
                             void *value_cache_hp,
                             void *key_cache_lp,
                             void *value_cache_lp,
                             const void *sin_table,
                             const void *cos_table,
                             const void *rope_offsets,
                             const void *gamma,
                             const void *beta,
                             const void *key_scale_hp,
                             const void *value_scale_hp,
                             void *key_scale_lp,
                             void *value_scale_lp,
                             const void *cache_bs_id_hp,
                             const void *cache_seq_offsets_hp,
                             const void *cache_bs_id_lp,
                             const void *cache_seq_offsets_lp,
                             const void *slot_mapping_hp,
                             const void *slot_mapping_lp,
                             int rotary_stride,
                             int batch_size,
                             int head_num_q,
                             int head_num_kv,
                             int head_size,
                             int max_decode_len_hp,
                             int max_decode_len_lp,
                             int block_size_hp,
                             int block_size_lp,
                             int group_size,
                             cnnlDataType_t dtype,
                             float eps);

}  // namespace tmo

#endif  // CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
|