forked from EngineX-Cambricon/enginex-mlu370-vllm
64 lines
2.5 KiB
Plaintext
64 lines
2.5 KiB
Plaintext
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
#ifndef CSRC_KERNELS_EMBEDDING_MLUH_
|
|
#define CSRC_KERNELS_EMBEDDING_MLUH_
|
|
|
|
#include "cnnl.h"
|
|
#include "kernel_utils.h"
|
|
namespace tmo {
|
|
/**
|
|
* @brief Look up table for ids which greater than vocab_offset and less than
|
|
* vocab_offset + vocab_size, and write the results back to the position
|
|
* corresponding to the ids. For ids that are not in the range, write 0
|
|
* to the corresponding position.
|
|
* @example
|
|
* filter:
|
|
* [[1, 2, 3, 4],
|
|
* [5, 6, 7, 8],
|
|
* [4, 3, 2, 1]]
|
|
* input_ids:
|
|
* [[1, 5, 6, 7, 8, 9]]
|
|
* vocab_offset = 5
|
|
* vocab_size = 3
|
|
* input_size = 4
|
|
* total_seq = 6
|
|
* output:
|
|
* [[0, 0, 0, 0], [1, 2, 3, 4], [5, 6, 7, 8],
|
|
* [4, 3, 2, 1], [0, 0, 0, 0], [0, 0, 0, 0]]
|
|
* @param queue: The queue for mlu.
|
|
* @param filter: Input. Pointer to the MLU memory that stores the embedding table,
|
|
* the shape must be [vocab_size, input_size].
|
|
* @param input_ids: Input. Pointer to the MLU memory that stores the token id,
|
|
* the shape must be [batch, seq].
|
|
* @param output: Output. Pointer to the MLU memory that stores the output,
|
|
* the shape must be [batch, seq, input_size].
|
|
* @param dtype: Data type.
|
|
* @param vocab_offset: embedding table offset.
|
|
* @param vocab_size: embedding table size.
|
|
* @param total_vocab_size: total embedding table size.
|
|
* @param input_size: embedding dim.
|
|
* @param total_seq: Total sequence length.
|
|
*/
|
|
KernelStatus invokeEmbedding(cnrtQueue_t queue,
|
|
void *filter,
|
|
void *input_ids,
|
|
void *output,
|
|
const cnnlDataType_t dtype,
|
|
int vocab_offset,
|
|
int vocab_size,
|
|
int total_vocab_size,
|
|
int input_size,
|
|
int total_seq);
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_KERNELS_EMBEDDING_MLUH_
|