Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mluh
2026-02-04 17:39:32 +08:00

64 lines
2.5 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_EMBEDDING_MLUH_
#define CSRC_KERNELS_EMBEDDING_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Embedding lookup restricted to one slice of the vocabulary.
*        For every id in input_ids that falls in the half-open range
*        [vocab_offset, vocab_offset + vocab_size), row (id - vocab_offset) of
*        filter is written to the output position corresponding to that id.
*        For ids outside the range, the corresponding output row is filled
*        with 0.
* @example
* filter:
* [[1, 2, 3, 4],
* [5, 6, 7, 8],
* [4, 3, 2, 1]]
* input_ids:
* [[1, 5, 6, 7, 8, 9]]
* vocab_offset = 5
* vocab_size = 3
* input_size = 4
* total_seq = 6
* output:
* [[0, 0, 0, 0], [1, 2, 3, 4], [5, 6, 7, 8],
* [4, 3, 2, 1], [0, 0, 0, 0], [0, 0, 0, 0]]
* Ids 5, 6 and 7 are in range and select filter rows 0, 1 and 2; ids 1, 8
* and 9 are out of range and produce zero rows. Note that the lower bound is
* inclusive (id == vocab_offset hits row 0) and the upper bound is exclusive
* (id == vocab_offset + vocab_size misses).
* @param queue: The queue for mlu.
* @param filter: Input. Pointer to the MLU memory that stores the embedding table,
*                the shape must be [vocab_size, input_size].
* @param input_ids: Input. Pointer to the MLU memory that stores the token id,
*                   the shape must be [batch, seq].
*                   NOTE(review): the integer type of the ids is not stated in
*                   this header; confirm against the kernel implementation.
* @param output: Output. Pointer to the MLU memory that stores the output,
*                the shape must be [batch, seq, input_size].
* @param dtype: Data type.
*               NOTE(review): presumably the element type of filter and output;
*               the supported values are not listed here, confirm against the
*               implementation.
* @param vocab_offset: Embedding table offset, i.e. the id that maps to row 0 of
*                      filter (5 in the example above).
* @param vocab_size: Embedding table size, i.e. the number of rows in filter.
* @param total_vocab_size: Total embedding table size.
*                          NOTE(review): presumably the size of the complete
*                          vocabulary of which filter holds one slice; how the
*                          kernel uses it is not visible from this header.
* @param input_size: Embedding dim, i.e. the number of elements per filter row
*                    and per output row.
* @param total_seq: Total sequence length, i.e. the number of ids to look up
*                   (6 in the example above).
*                   NOTE(review): presumably batch * seq; confirm with callers.
* @return A KernelStatus value.
*         NOTE(review): KernelStatus is not declared in this header (presumably
*         it comes from kernel_utils.h); check there for the possible values.
*/
KernelStatus invokeEmbedding(cnrtQueue_t queue,
void *filter,
void *input_ids,
void *output,
const cnnlDataType_t dtype,
int vocab_offset,
int vocab_size,
int total_vocab_size,
int input_size,
int total_seq);
} // namespace tmo
#endif // CSRC_KERNELS_EMBEDDING_MLUH_