forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
63
torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mluh
Normal file
63
torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mluh
Normal file
@@ -0,0 +1,63 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_EMBEDDING_MLUH_
|
||||
#define CSRC_KERNELS_EMBEDDING_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Look up the table for ids that are greater than or equal to vocab_offset and less than
|
||||
* vocab_offset + vocab_size, and write the results back to the position
|
||||
* corresponding to the ids. For ids that are not in the range, write 0
|
||||
* to the corresponding position.
|
||||
* @example
|
||||
* filter:
|
||||
* [[1, 2, 3, 4],
|
||||
* [5, 6, 7, 8],
|
||||
* [4, 3, 2, 1]]
|
||||
* input_ids:
|
||||
* [[1, 5, 6, 7, 8, 9]]
|
||||
* vocab_offset = 5
|
||||
* vocab_size = 3
|
||||
* input_size = 4
|
||||
* total_seq = 6
|
||||
* output:
|
||||
* [[0, 0, 0, 0], [1, 2, 3, 4], [5, 6, 7, 8],
|
||||
* [4, 3, 2, 1], [0, 0, 0, 0], [0, 0, 0, 0]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param filter: Input. Pointer to the MLU memory that stores the embedding table,
|
||||
* the shape must be [vocab_size, input_size].
|
||||
* @param input_ids: Input. Pointer to the MLU memory that stores the token ids,
|
||||
* the shape must be [batch, seq].
|
||||
* @param output: Output. Pointer to the MLU memory that stores the output,
|
||||
* the shape must be [batch, seq, input_size].
|
||||
* @param dtype: Data type.
|
||||
* @param vocab_offset: Embedding table offset, i.e. the id that maps to row 0 of filter.
|
||||
* @param vocab_size: Embedding table size, i.e. the number of rows in filter.
|
||||
* @param total_vocab_size: total embedding table size.
|
||||
* @param input_size: embedding dim.
|
||||
* @param total_seq: Total sequence length.
|
||||
*/
|
||||
KernelStatus invokeEmbedding(cnrtQueue_t queue,  // MLU queue passed to the kernel.
|
||||
void *filter,  // Input: embedding table in MLU memory, shape [vocab_size, input_size].
|
||||
void *input_ids,  // Input: token ids in MLU memory, shape [batch, seq].
|
||||
void *output,  // Output: MLU memory, shape [batch, seq, input_size].
|
||||
const cnnlDataType_t dtype,  // Data type; presumably that of filter/output -- TODO confirm.
|
||||
int vocab_offset,  // Ids in [vocab_offset, vocab_offset + vocab_size) are looked up; others yield 0.
|
||||
int vocab_size,  // Number of rows in filter.
|
||||
int total_vocab_size,  // Total embedding table size; presumably the full vocabulary -- TODO confirm.
|
||||
int input_size,  // Embedding dimension (length of each filter row).
|
||||
int total_seq);  // Total sequence length; in the documented example it equals the id count.
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_EMBEDDING_MLUH_
|
||||
Reference in New Issue
Block a user