/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_
#define CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_

#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
 * @brief In the context stage, concatenate the result of multi head attention
 *   key and value to key_cache and value_cache.
 * @example
 *   input:
 *     cache:
 *       [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 *        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 *        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
 *     context:
 *       [[1, 2, 3, 4, 5],
 *        [6, 7, 8, 9, 10]]
 *     cache_bs_offsets: [1, 2]
 *     cache_seq_offsets: [3, 4]
 *     context_seq_offsets: [0, 1]
 *     context_lens: [4, 3]
 *   output:
 *     [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 *      [0, 0, 0, 1, 2, 3, 4, 0, 0, 0],
 *      [0, 0, 0, 0, 7, 8, 9, 0, 0, 0]]
 *   (Context row i, starting at context_seq_offsets[i] and spanning
 *    context_lens[i] elements, is written into cache row cache_bs_offsets[i]
 *    starting at position cache_seq_offsets[i].)
 * @param queue: The queue for mlu.
 * @param key_cache: Pointer to the MLU memory that stores the key cache,
 *   the shape must be [max_batch, head_num, cache_mem_len, head_size].
 *   key_cache could be nullptr.
 * @param value_cache: Pointer to the MLU memory that stores the value cache,
 *   the shape must be [max_batch, head_num, cache_mem_len, head_size].
 *   value_cache could be nullptr.
 * @param cache_bs_offsets: Pointer to the MLU memory that stores the batch
 *   offset of cache, the shape must be [batch], if it's nullptr, the
 *   default value is {0, 1, 2 ... batch - 1}.
 * @param cache_seq_offsets: Input. Pointer to the MLU memory that stores the sequence
 *   offset of cache, the shape must be [batch], if it's nullptr, the
 *   default value is 0 for every batch.
 * @param key: Pointer to the MLU memory that stores the key,
 *   the shape must be [batch, max_context_len, head_num, head_size].
 *   key could be nullptr.
 * @param value: Pointer to the MLU memory that stores the value,
 *   the shape must be [batch, max_context_len, head_num, head_size].
 *   value could be nullptr.
 * @param context_seq_offsets: Pointer to the MLU memory that stores the
 *   sequence offset of context, the shape must be [batch]. If it's nullptr,
 *   the default value is 0 for every batch. It must be nullptr when packed is true.
 * @param context_lens: Input. Pointer to the MLU memory that stores the sequence length or
 *   cumulative sequence length of context. When packed is false, the shape must be [batch], which
 *   indicates sequence length of context. When packed is true, the shape must be [batch + 1], which
 *   indicates cumulative sequence length of context.
 * @param dtype: Data type.
 * @param batch: Batch size.
 * @param head_num: Head number.
 * @param head_size: Head size.
 * @param max_context_len: The maximum sequence length of context.
 * @param cache_mem_len: The maximum sequence length of cache.
 * @param context_bs_stride: The stride of batch in context, does not work when packed is true.
 * @param context_head_stride: The stride of head_num in context.
 * @param context_seq_stride: The stride of max_context_len in context.
 * @param cache_bs_stride: The stride of batch in cache.
 * @param cache_head_stride: The stride of head_num in cache.
 * @param cache_seq_stride: The stride of cache_mem_len in cache.
 * @param packed: A boolean value indicates whether to use pack mode.
 * @return A KernelStatus value. Its possible values are not visible in this
 *   header (presumably declared in kernel_utils.h -- confirm there).
 * @note If key and key_cache are nullptr, nothing to do for key.
 *   If value and value_cache are nullptr, nothing to do for value.
 *   A negative value in cache_bs_offsets or cache_seq_offsets means nothing to do for
 *   the corresponding batch.
 */
KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue,
                                      void *key_cache,
                                      void *value_cache,
                                      const void *cache_bs_offsets,
                                      const void *cache_seq_offsets,
                                      void *key,
                                      void *value,
                                      const void *context_seq_offsets,
                                      const void *context_lens,
                                      const cnnlDataType_t dtype,
                                      const int batch,
                                      const int head_num,
                                      const int head_size,
                                      const int max_context_len,
                                      const int cache_mem_len,
                                      const int context_bs_stride,
                                      const int context_head_stride,
                                      const int context_seq_stride,
                                      const int cache_bs_stride,
                                      const int cache_head_stride,
                                      const int cache_seq_stride,
                                      const bool packed);
}  // namespace tmo

#endif  // CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_