forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
106
torch_mlu_ops-v1.3.2/csrc/kernels/reshape_linear_cache.mluh
Normal file
106
torch_mlu_ops-v1.3.2/csrc/kernels/reshape_linear_cache.mluh
Normal file
@@ -0,0 +1,106 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_
|
||||
#define CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief In the context stage, concatenate the result of multi-head attention
|
||||
* key and value to key_cache and value_cache.
|
||||
* @example
|
||||
* input:
|
||||
* cache:
|
||||
* [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
* [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
* [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
|
||||
* context:
|
||||
* [[1, 2, 3, 4, 5],
|
||||
* [6, 7, 8, 9, 10]]
|
||||
* cache_bs_offsets: [1, 2]
|
||||
* cache_seq_offsets: [3, 4]
|
||||
* context_seq_offsets: [0, 1]
|
||||
* context_lens: [4, 3]
|
||||
* output:
|
||||
* [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
* [0, 0, 0, 1, 2, 3, 4, 0, 0, 0],
|
||||
* [0, 0, 0, 0, 7, 8, 9, 0, 0, 0]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param key_cache: Pointer to the MLU memory that stores the key cache,
|
||||
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
|
||||
* key_cache could be nullptr.
|
||||
* @param value_cache: Pointer to the MLU memory that stores the value cache,
|
||||
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
|
||||
* value_cache could be nullptr.
|
||||
* @param cache_bs_offsets: Pointer to the MLU memory that stores the batch
|
||||
* offset of cache, the shape must be [batch], if it's nullptr, the
|
||||
* default value is {0, 1, 2 ... batch - 1}.
|
||||
* @param cache_seq_offsets: Input. Pointer to the MLU memory that stores the sequence
|
||||
* offset of cache, the shape must be [batch], if it's nullptr, the
|
||||
* default value is 0 for every batch.
|
||||
* @param key: Pointer to the MLU memory that stores the key,
|
||||
* the shape must be [batch, max_context_len, head_num, head_size].
|
||||
* key could be nullptr.
|
||||
* @param value: Pointer to the MLU memory that stores the value,
|
||||
* the shape must be [batch, max_context_len, head_num, head_size].
|
||||
* value could be nullptr.
|
||||
* @param context_seq_offsets: Pointer to the MLU memory that stores the
|
||||
* sequence offset of context, the shape must be [batch]. If it's nullptr,
|
||||
* the default value is 0 for every batch. It must be nullptr when packed is true.
|
||||
* @param context_lens: Input. Pointer to the MLU memory that stores the sequence length or
|
||||
* cumulative sequence length of context. When packed is false, the shape must be [batch], which
|
||||
* indicates sequence length of context. When packed is true, the shape must be [batch + 1], which
|
||||
* indicates cumulative sequence length of context.
|
||||
* @param dtype: Data type.
|
||||
* @param batch: Batch size.
|
||||
* @param head_num: Head number.
|
||||
* @param head_size: Head size.
|
||||
* @param max_context_len: The maximum sequence length of context.
|
||||
* @param cache_mem_len: The maximum sequence length of cache.
|
||||
* @param context_bs_stride: The stride of batch in context, does not work when packed is true.
|
||||
* @param context_head_stride: The stride of head_num in context.
|
||||
* @param context_seq_stride: The stride of max_context_len in context.
|
||||
* @param cache_bs_stride: The stride of batch in cache.
|
||||
* @param cache_head_stride: The stride of head_num in cache.
|
||||
* @param cache_seq_stride: The stride of cache_mem_len in cache.
|
||||
* @param packed: A boolean value indicating whether to use packed mode.
|
||||
* @note If key and key_cache are nullptr, there is nothing to do for key.
|
||||
If value and value_cache are nullptr, there is nothing to do for value.
|
||||
A negative value in cache_bs_offsets or cache_seq_offsets means nothing to do for
|
||||
the corresponding batch.
|
||||
*/
|
||||
// Copies the valid part of each batch's context key/value into the linear
// key/value cache at the per-batch (batch offset, sequence offset) position.
// This is a declaration only; no definition appears in this header.
// NOTE(review): the unit of the *_stride arguments (elements vs. bytes) is not
// stated in this header - confirm against the implementation.
KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue,  // queue for mlu
|
||||
void *key_cache,  // [max_batch, head_num, cache_mem_len, head_size]; may be nullptr
|
||||
void *value_cache,  // [max_batch, head_num, cache_mem_len, head_size]; may be nullptr
|
||||
const void *cache_bs_offsets,  // [batch]; nullptr means {0, 1, ..., batch - 1}; negative entry skips that batch
|
||||
const void *cache_seq_offsets,  // [batch]; nullptr means 0 for every batch; negative entry skips that batch
|
||||
void *key,  // [batch, max_context_len, head_num, head_size]; may be nullptr
|
||||
void *value,  // [batch, max_context_len, head_num, head_size]; may be nullptr
|
||||
const void *context_seq_offsets,  // [batch]; nullptr means 0 for every batch; must be nullptr when packed
|
||||
const void *context_lens,  // [batch] lengths, or [batch + 1] cumulative lengths when packed
|
||||
const cnnlDataType_t dtype,  // data type (presumably of key/value and the caches - confirm)
|
||||
const int batch,  // batch size
|
||||
const int head_num,  // head number
|
||||
const int head_size,  // head size
|
||||
const int max_context_len,  // maximum sequence length of context
|
||||
const int cache_mem_len,  // maximum sequence length of cache
|
||||
const int context_bs_stride,  // stride of batch in context; not used when packed is true
|
||||
const int context_head_stride,  // stride of head_num in context
|
||||
const int context_seq_stride,  // stride of the sequence dimension in context
|
||||
const int cache_bs_stride,  // stride of batch in cache
|
||||
const int cache_head_stride,  // stride of head_num in cache
|
||||
const int cache_seq_stride,  // stride of cache_mem_len in cache
|
||||
const bool packed);  // true selects packed mode (cumulative context_lens)
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_
|
||||
Reference in New Issue
Block a user