/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
#define CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_

#include "cnnl.h"
#include "kernel_utils.h"

namespace tmo {

/**
 * @brief Apply rotary embedding to the query and key, apply layernorm to the
 * key, and quantize the key and value into the KV cache.
 * @param queue: The MLU queue.
 * @param input: Input/Output. Pointer to the MLU memory that stores the input;
 * the shape must be [batch_size, 1, head_num_q + head_num_kv * 2, head_size].
 * @param key_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision key
 * cache; the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
 * head_num_kv, block_size, head_size].
 * @param value_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision
 * value cache; the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
 * head_num_kv, block_size, head_size].
 * @param key_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision key
 * cache; the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
 * head_num_kv, block_size, head_size].
 * @param value_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision
 * value cache; the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
 * head_num_kv, block_size, head_size].
 * @param sin_table: Input. Pointer to the MLU memory that stores the sin values; it may not be
 * contiguous. The shape must be [rotary_seq_len, rotary_dim].
 * @param cos_table: Input. Pointer to the MLU memory that stores the cos values; it may not be
 * contiguous. The shape must be [rotary_seq_len, rotary_dim].
 * @param rope_offsets: Input. Pointer to the MLU memory that stores the sequence offset of each
 * batch. The shape must be [batch].
 * @param gamma: Input. Pointer to the MLU memory that stores the gamma parameter of layernorm.
 * @param beta: Input. Pointer to the MLU memory that stores the beta parameter of layernorm.
 * @param key_scale_hp: Input. Pointer to the MLU memory that stores the scales of the high
 * precision key. The shape must be [head_num_kv, head_size]. If key_scale_hp is nullptr,
 * the key does not need to be quantized.
 * @param value_scale_hp: Input. Pointer to the MLU memory that stores the scales of the high
 * precision value. The shape must be [head_num_kv, head_size]. If value_scale_hp is nullptr,
 * the value does not need to be quantized.
 * @param key_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of the low
 * precision key. The shape must be [batch_size, head_num_kv, max_decode_len, group_num] or
 * [num_blocks, head_num_kv, block_size, group_num].
 * @param value_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of the low
 * precision value. The shape must be [batch_size, head_num_kv, max_decode_len, group_num] or
 * [num_blocks, head_num_kv, block_size, group_num].
 * @param cache_bs_id_hp: Input. Pointer to the MLU memory that stores the batch
 * offsets of the high precision cache; the shape must be [batch]. If it is nullptr, the
 * default value is {0, 1, 2, ..., batch - 1}.
 * @param cache_seq_offsets_hp: Input. Pointer to the MLU memory that stores the sequence
 * offsets of the high precision cache; the shape must be [batch].
 * @param cache_bs_id_lp: Input. Pointer to the MLU memory that stores the batch
 * offsets of the low precision cache; the shape must be [batch]. If it is nullptr, the
 * default value is {0, 1, 2, ..., batch - 1}.
 * @param cache_seq_offsets_lp: Input. Pointer to the MLU memory that stores the sequence
 * offsets of the low precision cache; the shape must be [batch].
 * @param slot_mapping_hp: Input. Pointer to the MLU memory that stores the slot_mapping tensor,
 * which has shape [batch]. The data type of slot mapping must be int32_t.
 * @param slot_mapping_lp: Input. Pointer to the MLU memory that stores the slot_mapping tensor,
 * which has shape [batch]. The data type of slot mapping must be int32_t.
 * @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table.
 * @param batch_size: Batch size.
 * @param head_num_q: Head number of query.
 * @param head_num_kv: Head number of key and value.
 * @param head_size: Head size. For simplicity, the rotary dim must be the same as head_size.
 * @param max_decode_len_hp: The maximum sequence length of the high precision cache.
 * @param max_decode_len_lp: The maximum sequence length of the low precision cache.
 * @param block_size_hp: Number of tokens per block of the high precision cache.
 * @param block_size_lp: Number of tokens per block of the low precision cache.
 * @param group_size: Quantization group size for the low precision key/value scales.
 * @param dtype: Data type of all inputs and outputs.
 * @param eps: Float number used for layernorm.
 * @note: head_num_q and head_num_kv must be in the range [1, 32].
 * head_size must be in the range [1, 128] and must be divisible by 2.
 */
KernelStatus invokeFusedRope(cnrtQueue_t queue, void *input, void *key_cache_hp,
                             void *value_cache_hp, void *key_cache_lp, void *value_cache_lp,
                             const void *sin_table, const void *cos_table,
                             const void *rope_offsets, const void *gamma, const void *beta,
                             const void *key_scale_hp, const void *value_scale_hp,
                             void *key_scale_lp, void *value_scale_lp,
                             const void *cache_bs_id_hp, const void *cache_seq_offsets_hp,
                             const void *cache_bs_id_lp, const void *cache_seq_offsets_lp,
                             const void *slot_mapping_hp, const void *slot_mapping_lp,
                             int rotary_stride, int batch_size, int head_num_q, int head_num_kv,
                             int head_size, int max_decode_len_hp, int max_decode_len_lp,
                             int block_size_hp, int block_size_lp, int group_size,
                             cnnlDataType_t dtype, float eps);
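
// Usage sketch (illustrative only, kept as a comment so this header stays
// declaration-only). All sizes, buffer names, and nullptr choices below are
// hypothetical assumptions, not values mandated by this header; they merely
// show the argument ordering. RoPE rotates each channel pair by a
// position-dependent angle read from sin_table/cos_table; the exact pairing
// layout is kernel-defined. This sketch assumes a continuous (non-paged) high
// precision cache, an unused low precision path (nullptr), a contiguous
// rotary table (rotary_stride == head_size), and half-precision data.
//
//   constexpr int batch_size = 4, head_num_q = 32, head_num_kv = 8;
//   constexpr int head_size = 128, max_decode_len = 4096, group_size = 64;
//   tmo::KernelStatus status = tmo::invokeFusedRope(
//       queue, input, key_cache_hp, value_cache_hp,
//       /*key_cache_lp=*/nullptr, /*value_cache_lp=*/nullptr,
//       sin_table, cos_table, rope_offsets, gamma, beta,
//       /*key_scale_hp=*/nullptr, /*value_scale_hp=*/nullptr,  // skip quantization
//       /*key_scale_lp=*/nullptr, /*value_scale_lp=*/nullptr,
//       /*cache_bs_id_hp=*/nullptr, cache_seq_offsets_hp,      // default batch ids
//       /*cache_bs_id_lp=*/nullptr, /*cache_seq_offsets_lp=*/nullptr,
//       /*slot_mapping_hp=*/nullptr, /*slot_mapping_lp=*/nullptr,
//       /*rotary_stride=*/head_size, batch_size, head_num_q, head_num_kv,
//       head_size, /*max_decode_len_hp=*/max_decode_len, /*max_decode_len_lp=*/0,
//       /*block_size_hp=*/0, /*block_size_lp=*/0, group_size,
//       CNNL_DTYPE_HALF, /*eps=*/1e-5f);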

}  // namespace tmo

#endif  // CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_