/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef CSRC_KERNELS_EMBEDDING_MLUH_
#define CSRC_KERNELS_EMBEDDING_MLUH_

#include "cnnl.h"
#include "kernel_utils.h"

namespace tmo {

/**
 * @brief Looks up the embedding rows for ids that fall inside this table's
 *        vocabulary range [vocab_offset, vocab_offset + vocab_size), and
 *        writes each row to the output position corresponding to the id.
 *        For ids outside that range, zeros are written to the corresponding
 *        position instead.
 *
 *        An in-range id selects row (id - vocab_offset) of @p filter. The
 *        lower bound of the range is inclusive and the upper bound is
 *        exclusive, as the example below shows: id 5 selects row 0, and
 *        id 8 produces zeros.
 *
 * @par Example
 * @code
 * filter:
 *   [[1, 2, 3, 4],
 *    [5, 6, 7, 8],
 *    [4, 3, 2, 1]]
 * input_ids:
 *   [[1, 5, 6, 7, 8, 9]]
 * vocab_offset = 5
 * vocab_size = 3
 * input_size = 4
 * total_seq = 6
 * output:
 *   [[0, 0, 0, 0], [1, 2, 3, 4], [5, 6, 7, 8],
 *    [4, 3, 2, 1], [0, 0, 0, 0], [0, 0, 0, 0]]
 * @endcode
 *
 * @param queue: The MLU queue used by this call.
 * @param filter: Input. Pointer to the MLU memory that stores the embedding
 *        table; the shape must be [vocab_size, input_size].
 * @param input_ids: Input. Pointer to the MLU memory that stores the token
 *        ids; the shape must be [batch, seq].
 * @param output: Output. Pointer to the MLU memory that receives the result;
 *        the shape must be [batch, seq, input_size].
 * @param dtype: Data type. NOTE(review): presumably the element type of
 *        filter and output. The element type expected for input_ids is not
 *        stated by this declaration; confirm against the implementation.
 * @param vocab_offset: First id covered by this embedding table (inclusive).
 * @param vocab_size: Number of rows in this embedding table; ids in
 *        [vocab_offset, vocab_offset + vocab_size) are looked up.
 * @param total_vocab_size: Total embedding table size. NOTE(review):
 *        presumably the size of the full vocabulary when the table is split
 *        into slices, with this filter holding only the slice described
 *        above. Its exact use is not visible here; confirm against the
 *        implementation.
 * @param input_size: Embedding dimension, i.e. the length of each row of
 *        filter and of each output vector.
 * @param total_seq: Total sequence length, i.e. the number of ids to process.
 *        NOTE(review): presumably batch * seq (6 = 1 * 6 in the example);
 *        confirm.
 * @return A KernelStatus value reporting the outcome of the call. Its values
 *         are defined where KernelStatus is declared (not visible here).
 */
KernelStatus invokeEmbedding(cnrtQueue_t queue,
                             void *filter,
                             void *input_ids,
                             void *output,
                             const cnnlDataType_t dtype,
                             int vocab_offset,
                             int vocab_size,
                             int total_vocab_size,
                             int input_size,
                             int total_seq);

}  // namespace tmo

#endif  // CSRC_KERNELS_EMBEDDING_MLUH_