Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_linear_cache.mluh
2026-02-04 17:39:32 +08:00

107 lines
5.5 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_
#define CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief In the context stage, concatenate the multi-head attention key and
* value results into key_cache and value_cache.
* @example
* input:
* cache:
* [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
* [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
* [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
* context:
* [[1, 2, 3, 4, 5],
* [6, 7, 8, 9, 10]]
* cache_bs_offsets: [1, 2]
* cache_seq_offsets: [3, 4]
* context_seq_offsets: [0, 1]
* context_lens: [4, 3]
* output:
* [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
* [0, 0, 0, 1, 2, 3, 4, 0, 0, 0],
* [0, 0, 0, 0, 7, 8, 9, 0, 0, 0]]
* In the example, context row i is read starting at context_seq_offsets[i] for
* context_lens[i] elements and written to cache row cache_bs_offsets[i]
* starting at column cache_seq_offsets[i].
* @param queue: The queue for mlu.
* @param key_cache: Pointer to the MLU memory that stores the key cache,
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
* key_cache could be nullptr.
* @param value_cache: Pointer to the MLU memory that stores the value cache,
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
* value_cache could be nullptr.
* @param cache_bs_offsets: Pointer to the MLU memory that stores the batch
* offset of cache, the shape must be [batch]. If it is nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets: Pointer to the MLU memory that stores the sequence
* offset of cache, the shape must be [batch]. If it is nullptr, the
* default value is 0 for every batch.
* @param key: Pointer to the MLU memory that stores the key,
* the shape must be [batch, max_context_len, head_num, head_size].
* key could be nullptr.
* @param value: Pointer to the MLU memory that stores the value,
* the shape must be [batch, max_context_len, head_num, head_size].
* value could be nullptr.
* @param context_seq_offsets: Pointer to the MLU memory that stores the
* sequence offset of context, the shape must be [batch]. If it is nullptr,
* the default value is 0 for every batch. It must be nullptr when packed is true.
* @param context_lens: Pointer to the MLU memory that stores the sequence length or
* cumulative sequence length of context. When packed is false, the shape must be [batch], which
* indicates the sequence length of context. When packed is true, the shape must be [batch + 1],
* which indicates the cumulative sequence length of context.
* @param dtype: Data type.
* @param batch: Batch size.
* @param head_num: Head number.
* @param head_size: Head size.
* @param max_context_len: The maximum sequence length of context.
* @param cache_mem_len: The maximum sequence length of cache.
* @param context_bs_stride: The stride of batch in context, does not work when packed is true.
* @param context_head_stride: The stride of head_num in context.
* @param context_seq_stride: The stride of max_context_len in context.
* @param cache_bs_stride: The stride of batch in cache.
* @param cache_head_stride: The stride of head_num in cache.
* @param cache_seq_stride: The stride of cache_mem_len in cache.
* @param packed: A boolean value that indicates whether to use pack mode.
* @return A KernelStatus value. The type is not declared in this header
* (presumably it comes from kernel_utils.h -- confirm there for the possible values).
* @note If key and key_cache are nullptr, there is nothing to do for key.
* If value and value_cache are nullptr, there is nothing to do for value.
* A negative value in cache_bs_offsets or cache_seq_offsets means there is nothing to do
* for the corresponding batch.
* NOTE(review): this header does not state the element type of the offset/length
* arrays (passed as const void *) or whether the strides are counted in elements
* or bytes -- confirm against the kernel implementation.
*/
KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue,
void *key_cache,
void *value_cache,
const void *cache_bs_offsets,
const void *cache_seq_offsets,
void *key,
void *value,
const void *context_seq_offsets,
const void *context_lens,
const cnnlDataType_t dtype,
const int batch,
const int head_num,
const int head_size,
const int max_context_len,
const int cache_mem_len,
const int context_bs_stride,
const int context_head_stride,
const int context_seq_stride,
const int cache_bs_stride,
const int cache_head_stride,
const int cache_seq_stride,
const bool packed);
} // namespace tmo
#endif // CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_