/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #ifndef CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_ #define CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_ #include "cnnl.h" #include "kernel_utils.h" namespace tmo { /** * @brief In the context stage, concate the result of multi head attention * key and value to key_cache and value_cache. * @example * input: * cache: * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], * [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], * [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * context: * [[1, 2, 3, 4, 5], * [6, 7, 8, 9, 10]] * cache_bs_offsets: [1, 2] * cache_seq_offsets: [3, 4] * context_seq_offsets: [0, 1] * context_lens: [4, 3] * output: * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], * [0, 0, 0, 1, 2, 3, 4, 0, 0, 0], * [0, 0, 0, 0, 7, 8, 9, 0, 0, 0]] * @param queue: The queue for mlu. * @param key_cache: Pointer to the MLU memory that stores the key cache, * the shape must be [max_batch, head_num, cache_mem_len, head_size]. * key_cache could be nullptr. * @param value_cache: Pointer to the MLU memory that stores the value cache, * the shape must be [max_batch, head_num, cache_mem_len, head_size]. * value_cache could be nullptr. * @param cache_bs_offsets: Pointer to the MLU memory that stores the batch * offset of cache, the shape must be [batch], if it's nullptr, the * default value is {0, 1, 2 ... batch - 1}. * @param cache_seq_offsets: Input. Pointer to the MLU memory that stores the sequence * offset of cache, the shape must be [batch], if it's nullptr, the * default value is 0 for every batch. * @param key: Pointer to the MLU memory that stores the key, * the shape must be [batch, max_contxt_len, head_num, head_size]. * key could be nullptr. * @param value: Pointer to the MLU memory that stores the value, * the shape must be [batch, max_contxt_len, head_num, head_size]. * value could be nullptr. * @param context_seq_offsets: Pointer to the MLU memory that stores the * sequence offset of context, the shape must be [batch]. if it's nullptr, * the default value is 0 for every batch. It must be nullptr when packed is true. * @param context_lens: Input. Pointer to the MLU memory that stores the sequence length or * cumulative sequence length of context. when packed is false, the shape must be [batch], which * indicates sequence length of context. when packed is true, the shape must be [batch + 1], which * indicates cumulative sequence length of context. * @param dtype: Data type. * @param batch: Batch size. * @param head_num: Head number. * @param head_size: Head size. * @param max_contxt_len: The maximum sequence length of context. * @param cache_mem_len: The maximum sequence length of cache. * @param contxt_bs_stride: The stride of batch in context, does not work when packed is true. * @param contxt_head_stride: The stride of head_num in context. * @param contxt_seq_stride: The stride of max_contxt_len in context. * @param cache_bs_stride: The stride of batch in cache. * @param cache_head_stride: The stride of head_num in cache. * @param cache_seq_stride: The stride of cache_mem_len in cache. * @param packed: A boolean value indicates whether to use pack mode. * @note If key and key_cache are nullptr, nothing todo for key. If value and value_cache are nullptr, nothing todo for value. A negative value in cache_bs_offsets or cache_seq_offsets means nothing to do for the corresponding batch. */ KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue, void *key_cache, void *value_cache, const void *cache_bs_offsets, const void *cache_seq_offsets, void *key, void *value, const void *context_seq_offsets, const void *context_lens, const cnnlDataType_t dtype, const int batch, const int head_num, const int head_size, const int max_context_len, const int cache_mem_len, const int context_bs_stride, const int context_head_stride, const int context_seq_stride, const int cache_bs_stride, const int cache_head_stride, const int cache_seq_stride, const bool packed); } // namespace tmo #endif // CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_