v0.10.1rc1

This commit is contained in:
2025-09-09 09:40:35 +08:00
parent d6f6ef41fe
commit 9149384e03
432 changed files with 84698 additions and 1 deletions

338
csrc/camem_allocator.cpp Normal file
View File

@@ -0,0 +1,338 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
extern "C" {
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sys/types.h>
#include "acl/acl.h"
// Global references to Python callables
// NOTE: this is borrowed reference, so we don't need to DECREF them.
// This brings the limitation that the allocator needs to be singleton.
static PyObject* g_python_malloc_callback = nullptr;
static PyObject* g_python_free_callback = nullptr;
// ---------------------------------------------------------------------------
// Helper functions:
void ensure_context(unsigned long long device) {
aclrtContext pctx;
aclrtGetCurrentContext(&pctx);
if (!pctx) {
// Ensure device context.
aclrtCreateContext(&pctx, device);
aclrtSetCurrentContext(pctx);
}
}
void create_and_map(unsigned long long device, ssize_t size, void* d_mem,
aclrtDrvMemHandle* p_memHandle) {
ensure_context(device);
// Define memory allocation properties
aclrtPhysicalMemProp prop = {};
prop.handleType = ACL_MEM_HANDLE_TYPE_NONE ;
prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
prop.memAttr = ACL_HBM_MEM_HUGE;
prop.location.id = device;
prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
prop.reserve = 0;
// Allocate memory using aclrtMallocPhysical
aclError error_code = aclrtMallocPhysical(p_memHandle, size, &prop, 0);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return;
}
error_code = aclrtMapMem(d_mem, size, 0, *p_memHandle, 0);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return;
}
}
void unmap_and_release(unsigned long long device, ssize_t size,
void* d_mem,
aclrtDrvMemHandle* p_memHandle) {
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
ensure_context(device);
aclError error_code = aclrtUnmapMem(d_mem);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return;
}
error_code = aclrtFreePhysical(*p_memHandle);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return;
}
}
PyObject* create_tuple_from_c_integers(unsigned long long a,
unsigned long long b,
unsigned long long c,
unsigned long long d) {
// Create a new tuple of size 4
PyObject* tuple = PyTuple_New(4);
if (!tuple) {
return NULL; // Return NULL on failure
}
// Convert integers to Python objects and set them in the tuple
PyTuple_SetItem(
tuple, 0,
PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d));
// Note: PyTuple_SetItem "steals" a reference to each object,
// so we do not need to Py_DECREF the PyLong objects explicitly.
return tuple; // Return the created tuple
}
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
__attribute__ ((visibility("default"))) void* my_malloc(ssize_t size, int device, aclrtStream stream) {
ensure_context(device);
// first allocation, align the size, and reserve an address, and also allocate
// a aclrtDrvMemHandle
// Define memory allocation properties
aclrtPhysicalMemProp prop = {};
prop.handleType = ACL_MEM_HANDLE_TYPE_NONE ;
prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
prop.memAttr = ACL_HBM_MEM_HUGE;
prop.location.id = device;
prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
prop.reserve = 0;
// Check if the allocation is supported
size_t granularity;
aclError error_code = aclrtMemGetAllocationGranularity(&prop,
ACL_RT_MEM_ALLOC_GRANULARITY_MINIMUM,
&granularity);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return nullptr;
}
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
void *d_mem;
error_code = aclrtReserveMemAddress(&d_mem, alignedSize, 0, nullptr, 0);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return nullptr;
}
// allocate the aclrtDrvMemHandle
aclrtDrvMemHandle* p_memHandle =
(aclrtDrvMemHandle*)malloc(sizeof(aclrtDrvMemHandle));
if (!g_python_malloc_callback) {
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
return nullptr;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* arg_tuple = create_tuple_from_c_integers(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
// Call g_python_malloc_callback
PyObject* py_result =
PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
Py_DECREF(arg_tuple);
if (!py_result) {
PyErr_Print();
PyGILState_Release(gstate);
return nullptr;
}
PyGILState_Release(gstate);
// do the final mapping
create_and_map(device, alignedSize, d_mem, p_memHandle);
return (void*)d_mem;
}
__attribute__ ((visibility("default"))) void my_free(void* ptr, ssize_t size, int device, aclrtStream stream) {
// get memory handle from the pointer
if (!g_python_free_callback) {
std::cerr << "ERROR: g_python_free_callback not set.\n";
return;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* py_ptr =
PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
PyObject* py_result =
PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return;
}
PyGILState_Release(gstate);
// recv_size == size
// recv_device == device
// Free memory
void *d_mem = (void*)recv_d_mem;
// allocate the aclrtDrvMemHandle
aclrtDrvMemHandle* p_memHandle =
(aclrtDrvMemHandle*)recv_p_memHandle;
unmap_and_release(device, size, d_mem, p_memHandle);
// free address and the handle
aclError error_code = aclrtReleaseMemAddress(d_mem);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return;
}
free(p_memHandle);
}
// ---------------------------------------------------------------------------
// Python extension boilerplate:
// Python-exposed function: init_module(python_malloc, python_free)
static PyObject* py_init_module(PyObject* self, PyObject* args) {
PyObject* malloc_callback = nullptr;
PyObject* free_callback = nullptr;
if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
return nullptr;
}
if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
return nullptr;
}
// Save the Python callables
// This module does not handle GC of these objects, so they must be kept alive
// outside of this module.
g_python_malloc_callback = malloc_callback;
g_python_free_callback = free_callback;
Py_RETURN_NONE;
}
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return nullptr;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
void *d_mem_ptr = (void*)recv_d_mem;
aclrtDrvMemHandle* p_memHandle =
(aclrtDrvMemHandle*)recv_p_memHandle;
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
Py_RETURN_NONE;
}
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return nullptr;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
void *d_mem_ptr = (void*)recv_d_mem;
aclrtDrvMemHandle* p_memHandle =
(aclrtDrvMemHandle*)recv_p_memHandle;
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
Py_RETURN_NONE;
}
static PyMethodDef module_methods[] = {
{"init_module", (PyCFunction)py_init_module, METH_VARARGS,
"Initialize module with python_malloc and python_free callables."},
{"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
"Create and map memory on the device."},
{"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
METH_VARARGS, "Unmap and release memory on the device."},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef camem_allocator_module = {
PyModuleDef_HEAD_INIT, "camem_allocator",
"CANN-mem-based allocator for NPUPluggableAllocator", -1, module_methods};
PyMODINIT_FUNC PyInit_vllm_ascend_C(void) {
// Initialize the module
PyObject* module = PyModule_Create(&camem_allocator_module);
if (!module) {
return NULL;
}
return module;
}
} // extern "C"

View File

@@ -0,0 +1,369 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
#include "types.h"
template <typename scalar_t>
class BGMVExpand {
public:
using X_T = float;
using W_T = scalar_t;
using Y_T = scalar_t;
static constexpr uint64_t LORA_RANK_8 = 8;
static constexpr uint64_t LORA_RANK_16 = 16;
static constexpr uint64_t LORA_RANK_32 = 32;
static constexpr uint64_t LORA_RANK_64 = 64;
static constexpr uint64_t SUPPORTED_RANKS[] = {LORA_RANK_8, LORA_RANK_16, LORA_RANK_32, LORA_RANK_64};
static constexpr int32_t BUFFER_NUM = 2;
// The vector unit reads 8 blocks (32 bytes each and 256 bytes in total) of contiguous data each time.
static constexpr int32_t NUM_BYTES_PER_REPEAT = 256;
static constexpr int32_t NUM_BLOCKS_PER_REPEAT = 8;
// The maximum number of elements in a single iteration is 256 / sizeof(intermediate data type).
static constexpr int32_t NUM_ELEMENTS_PER_REPEAT = NUM_BYTES_PER_REPEAT / sizeof(float);
// Mask is used to control the elements that participate in computation in each iteration.
static constexpr int32_t MASK_COUNT = NUM_BYTES_PER_REPEAT / sizeof(float);
// Refer to numOutputElementsPerInputTile_ initialization for the constraints on the following constants.
static constexpr int32_t W_IN_TILE_NUM_ELEMENTS = 8192;
static constexpr int32_t Y_OUT_TILE_NUM_ELEMENTS = 4096;
static constexpr int32_t BLOCK_REDUCE_NUM_REPEATS = W_IN_TILE_NUM_ELEMENTS / NUM_ELEMENTS_PER_REPEAT;
// BlockReduceSum would generate(BLOCK_REDUCE_NUM_REPEATS * NUM_BLOCKS_PER_REPEAT)floats.
// So need to read them all and apply PairReduceSum
static constexpr int32_t PAIR_REDUCE_NUM_REPEATS_16 =
(BLOCK_REDUCE_NUM_REPEATS * NUM_BLOCKS_PER_REPEAT + NUM_ELEMENTS_PER_REPEAT - 1) / NUM_ELEMENTS_PER_REPEAT;
// The second PairReduceSum for rank=32, needs half of the repetition that happened for rank=16.
// Same for rank=64, we do not support ranks greater than 64.
static constexpr int32_t PAIR_REDUCE_NUM_REPEATS_32 = (PAIR_REDUCE_NUM_REPEATS_16 + 1) / 2;
public:
__aicore__ inline BGMVExpand(AscendC::TPipe* pipe) : pipe_(pipe) {}
__aicore__ inline void Init(__gm__ void* x, __gm__ void* weight, __gm__ void* indices,
uint32_t indicesSize, __gm__ void* yIn, __gm__ void* yOut,
uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank,
uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim)
{
batchSize_ = batchSize;
numTokensPerCore_ = numTokensPerCore;
maxLoRARank_ = maxLoRARank;
outputHiddenDim_ = outputHiddenDim;
sliceOffset_ = sliceOffset;
outputFullDim_ = outputFullDim;
singleLoRAWeightLen_ = maxLoRARank_ * outputHiddenDim_;
xGm_.SetGlobalBuffer((__gm__ X_T *)x);
wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
yInGm_.SetGlobalBuffer((__gm__ Y_T *)yIn);
yOutGm_.SetGlobalBuffer((__gm__ Y_T *)yOut);
indicesGm_.SetGlobalBuffer((__gm__ int64_t *)indices, indicesSize);
pipe_->InitBuffer(inQueueX_, 1, NUM_ELEMENTS_PER_REPEAT * sizeof(X_T));
pipe_->InitBuffer(inQueueW_, BUFFER_NUM, W_IN_TILE_NUM_ELEMENTS * sizeof(W_T));
pipe_->InitBuffer(inQueueY_, BUFFER_NUM, Y_OUT_TILE_NUM_ELEMENTS * sizeof(Y_T));
pipe_->InitBuffer(outQueueY_, BUFFER_NUM, Y_OUT_TILE_NUM_ELEMENTS * sizeof(Y_T));
pipe_->InitBuffer(dupBufferX_, NUM_ELEMENTS_PER_REPEAT * sizeof(float));
pipe_->InitBuffer(tmpBufferW_, W_IN_TILE_NUM_ELEMENTS * sizeof(float));
pipe_->InitBuffer(inBufferY_, Y_OUT_TILE_NUM_ELEMENTS * sizeof(float));
pipe_->InitBuffer(tmpBufferY_, Y_OUT_TILE_NUM_ELEMENTS * sizeof(float));
// Each compute iteration would generate not one, but several output elements.
// Therefore, the following variable would determine how many output elements are calculated in each iteration.
numOutputElementsPerInputTile_ = BLOCK_REDUCE_NUM_REPEATS * (NUM_ELEMENTS_PER_REPEAT / maxLoRARank_);
numStreamInPerOutputTile_ = Y_OUT_TILE_NUM_ELEMENTS / numOutputElementsPerInputTile_;
}
__aicore__ inline void Process()
{
int64_t blockIdx = AscendC::GetBlockIdx();
int64_t startIdx = blockIdx * numTokensPerCore_;
int64_t endIdx = startIdx + numTokensPerCore_;
if (endIdx > batchSize_) {
endIdx = batchSize_;
}
for (int64_t idx = startIdx; idx < endIdx; idx++) {
yOffset_ = outputFullDim_ * idx + sliceOffset_;
// Set up LoRA index
CopyInIndex(idx);
if (reqLoRAIndex_ < 0) {
continue;
}
reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_;
CopyInX(idx);
int32_t numStreamOut = outputHiddenDim_ / Y_OUT_TILE_NUM_ELEMENTS;
for (int32_t i = 0; i < numStreamOut; i++) {
CopyInY(i);
for (int32_t j = 0; j < numStreamInPerOutputTile_; j++) {
CopyInW(i * numStreamInPerOutputTile_ + j);
Compute(j * numOutputElementsPerInputTile_);
}
ScaleOutput();
CopyOut(i);
}
ComputeLastIteration();
}
}
private:
__aicore__ inline void CopyInIndex(const int64_t idx)
{
// Look up the LoRA index
reqLoRAIndex_ = indicesGm_.GetValue(idx);
}
__aicore__ inline void ComputeLastIteration()
{
int32_t remainingY = outputHiddenDim_ % Y_OUT_TILE_NUM_ELEMENTS;
if (remainingY == 0) {
return;
}
int32_t numStreamOut = outputHiddenDim_ / Y_OUT_TILE_NUM_ELEMENTS;
int32_t remainingW = remainingY * maxLoRARank_;
int32_t numCompleteWTileInForLastIteration = remainingW / W_IN_TILE_NUM_ELEMENTS;
int32_t remainingWForLastRepeat = remainingW % W_IN_TILE_NUM_ELEMENTS;
CopyInY(numStreamOut, remainingY);
int32_t outputIdx = 0;
for (outputIdx = 0; outputIdx < numCompleteWTileInForLastIteration; outputIdx++) {
CopyInW(numStreamOut * numStreamInPerOutputTile_ + outputIdx);
Compute(outputIdx * numOutputElementsPerInputTile_);
}
if (remainingWForLastRepeat != 0) {
CopyInW(numStreamOut * numStreamInPerOutputTile_ + numCompleteWTileInForLastIteration,
remainingWForLastRepeat);
int32_t lastRepeatCount = remainingWForLastRepeat / NUM_ELEMENTS_PER_REPEAT;
int32_t pairReduceRepeat16 =
(lastRepeatCount * NUM_BLOCKS_PER_REPEAT + NUM_ELEMENTS_PER_REPEAT - 1) / NUM_ELEMENTS_PER_REPEAT;
int32_t pairReduceRepeat32 = (pairReduceRepeat16 + 1) / 2;
int32_t lastComputeOutputElement = outputIdx * numOutputElementsPerInputTile_;
Compute(lastComputeOutputElement, lastRepeatCount, pairReduceRepeat16, pairReduceRepeat32);
}
ScaleOutput(remainingY);
CopyOut(numStreamOut, remainingY);
}
__aicore__ inline void CopyInX(const int64_t idx)
{
AscendC::LocalTensor<X_T> xLocal = inQueueX_.AllocTensor<X_T>();
if constexpr (std::is_same_v<X_T, float>) {
DataCopy(xLocal, xGm_[maxLoRARank_ * idx], maxLoRARank_);
} else {
uint16_t blockLen = static_cast<uint16_t>(maxLoRARank_ * sizeof(X_T));
DataCopyPad(xLocal, xGm_[maxLoRARank_ * idx], {1, blockLen, 0, 0}, {});
}
inQueueX_.EnQue(xLocal);
xLocal = inQueueX_.DeQue<X_T>();
AscendC::LocalTensor<float> xDup = dupBufferX_.Get<float>();
// As we are generating multiple output elements with one API invocation,
// we need to duplicate the X vector multiple times to fill one NUM_BYTES_PER_REPEAT
if constexpr (std::is_same_v<X_T, float>) {
for (int32_t i = 0; i < NUM_ELEMENTS_PER_REPEAT; i += maxLoRARank_) {
for (int32_t j = 0; j < maxLoRARank_; j++) {
float entry = xLocal.GetValue(j);
xDup.SetValue(i + j, entry);
}
}
} else {
Cast(xDup, xLocal, AscendC::RoundMode::CAST_NONE, maxLoRARank_);
pipe_barrier(PIPE_V);
for (int32_t i = maxLoRARank_; i < NUM_ELEMENTS_PER_REPEAT; i += maxLoRARank_) {
for (int32_t j = 0; j < maxLoRARank_; j++) {
float entry = xDup.GetValue(j);
xDup.SetValue(i + j, entry);
}
}
}
inQueueX_.FreeTensor(xLocal);
}
__aicore__ inline void CopyInY(int32_t progress, int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<Y_T> yInLocal = inQueueY_.AllocTensor<Y_T>();
DataCopy(yInLocal, yInGm_[yOffset_ + progress * Y_OUT_TILE_NUM_ELEMENTS], numElements);
inQueueY_.EnQue(yInLocal);
}
__aicore__ inline void CopyInW(int32_t progress, int32_t numElements = W_IN_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<W_T> wLocal = inQueueW_.AllocTensor<W_T>();
DataCopy(wLocal, wGm_[reqLoRAWeightOffset_ + progress * W_IN_TILE_NUM_ELEMENTS], numElements);
inQueueW_.EnQue(wLocal);
}
__aicore__ inline void ScaleOutput(int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<float> yLocal = tmpBufferY_.Get<float>();
AscendC::LocalTensor<Y_T> yInLocal = inQueueY_.DeQue<Y_T>();
AscendC::LocalTensor<float> yInLocalFP32 = inBufferY_.Get<float>();
Cast(yInLocalFP32, yInLocal, AscendC::RoundMode::CAST_NONE, numElements);
pipe_barrier(PIPE_V);
inQueueY_.FreeTensor(yInLocal);
Add(yLocal, yLocal, yInLocalFP32, numElements);
pipe_barrier(PIPE_V);
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.AllocTensor<Y_T>();
Cast(yOutLocal, yLocal, AscendC::RoundMode::CAST_RINT, numElements);
pipe_barrier(PIPE_V);
outQueueY_.EnQue<Y_T>(yOutLocal);
}
__aicore__ inline void Compute(int32_t progress,
int32_t blockReduceRepeatCount=BLOCK_REDUCE_NUM_REPEATS,
int32_t pairReduceRepeat16=PAIR_REDUCE_NUM_REPEATS_16,
int32_t pairReduceRepeat32=PAIR_REDUCE_NUM_REPEATS_32)
{
AscendC::LocalTensor<float> yLocal = tmpBufferY_.Get<float>();
AscendC::LocalTensor<float> xDup = dupBufferX_.Get<float>();
AscendC::LocalTensor<W_T> wLocal = inQueueW_.DeQue<W_T>();
AscendC::LocalTensor<float> wTmpTensor = tmpBufferW_.Get<float>();
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, MASK_COUNT, blockReduceRepeatCount, castParams_);
pipe_barrier(PIPE_V);
inQueueW_.FreeTensor(wLocal);
Mul(wTmpTensor, xDup, wTmpTensor, MASK_COUNT, blockReduceRepeatCount, dotProductParams_);
pipe_barrier(PIPE_V);
if (maxLoRARank_ == LORA_RANK_8) {
BlockReduceSum(yLocal[progress], wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
} else if (maxLoRARank_ == LORA_RANK_16) {
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
PairReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat16, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
} else if (maxLoRARank_ == LORA_RANK_32) {
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
PairReduceSum(wTmpTensor, wTmpTensor, pairReduceRepeat16, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
PairReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat32, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
} else if (maxLoRARank_ == LORA_RANK_64) {
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
BlockReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat16, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
}
}
__aicore__ inline void CopyOut(int32_t progress, int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.DeQue<Y_T>();
DataCopy(yOutGm_[yOffset_ + progress * Y_OUT_TILE_NUM_ELEMENTS], yOutLocal, numElements);
outQueueY_.FreeTensor(yOutLocal);
}
private:
AscendC::TPipe* pipe_;
AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueY_, inQueueW_;
AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueX_;
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> outQueueY_;
AscendC::TBuf<AscendC::QuePosition::VECCALC> tmpBufferW_, dupBufferX_, inBufferY_, tmpBufferY_;
AscendC::GlobalTensor<X_T> xGm_;
AscendC::GlobalTensor<W_T> wGm_;
AscendC::GlobalTensor<Y_T> yInGm_;
AscendC::GlobalTensor<Y_T> yOutGm_;
AscendC::GlobalTensor<int64_t> indicesGm_;
uint32_t batchSize_;
uint32_t numTokensPerCore_;
uint32_t maxLoRARank_;
uint32_t outputHiddenDim_;
uint32_t sliceOffset_;
uint32_t outputFullDim_;
uint32_t singleLoRAWeightLen_;
int64_t reqLoRAIndex_;
uint64_t reqLoRAWeightOffset_;
uint32_t numOutputElementsPerInputTile_;
uint32_t numStreamInPerOutputTile_;
uint64_t yOffset_;
// The block stride is set to 1, and 8 blocks in the same repeat are processed continuously.
// The repeat stride is 8, so the vector unit reads 8 consecutive blocks in the first repeat,
// reads next 8 consecutive blocks in the second repeat.
AscendC::UnaryRepeatParams castParams_ = {1, 1, 8, 4};
// For each repeat in BlockReduceSum and PairReduceSum we should move forward only one block,
// so we set dstRepStride = 1
AscendC::UnaryRepeatParams reduceSumParams_ = {1, 1, 1, 8};
// When the repeat stride is 0, the vector unit repeatedly reads and computes the first 8 consecutive blocks.
// For xDup we repeatedly use it, so we set src0RepStride = 0
AscendC::BinaryRepeatParams dotProductParams_ = {1, 1, 1, 8, 0, 8};
};
#define BGMV_EXPAND_TYPE_DECLARE(TYPE) \
extern "C" __global__ __aicore__ void bgmv_expand_##TYPE(__gm__ void* x, __gm__ void* weight, __gm__ void* indices,\
uint32_t indicesSize, __gm__ void* yIn, __gm__ void* yOut,\
uint32_t batchSize, uint32_t numTokensPerCore, \
uint32_t maxLoRARank, uint32_t outputHiddenDim, \
uint32_t sliceOffset, uint32_t outputFullDim) \
{ \
AscendC::TPipe pipe; \
BGMVExpand<TYPE> op(&pipe); \
op.Init(x, weight, indices, indicesSize, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, \
outputHiddenDim, sliceOffset, outputFullDim); \
op.Process(); \
}
// declare all dtype kernel
BGMV_EXPAND_TYPE_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
BGMV_EXPAND_TYPE_DECLARE(bfloat16_t)
#endif
namespace vllm_ascend {
extern void bgmv_expand_impl(AscendType type, void* stream, void* x, void* weight, void* indices, uint32_t indicesSize,
void* yIn, void* yOut, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank,
uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim)
{
uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
if (type == AscendType::FP16) {
bgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, yIn, yOut, batchSize, numTokensPerCore,
maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim);
} else if (type == AscendType::BF16) {
#if (__CCE_AICORE__ >= 220)
bgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, yIn, yOut, batchSize,
numTokensPerCore, maxLoRARank, outputHiddenDim,
sliceOffset, outputFullDim);
#endif
} else {
return;
}
}
} // namespace vllm_ascend

View File

@@ -0,0 +1,252 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
#include "types.h"
template <typename scalar_t>
class BGMVShrink {
public:
using X_T = scalar_t;
using W_T = scalar_t;
using Y_T = float;
static constexpr uint64_t BUFFER_NUM = 1;
static constexpr uint64_t TILE_LENGTH = 11776; // optimal performance tile length
public:
__aicore__ inline BGMVShrink(AscendC::TPipe *pipe) : pipe_(pipe) {}
__aicore__ inline void Init(__gm__ void *x, __gm__ void *weight, __gm__ void *indices, uint32_t indicesSize, __gm__ void *y,
uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
uint32_t maxLoRARank, float scale)
{
batchSize_ = batchSize;
numTokensPerCore_ = numTokensPerCore;
inputHiddenDim_ = inputHiddenDim;
maxLoRARank_ = maxLoRARank;
scale_ = scale;
singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_;
incremental_ = inputHiddenDim_ > TILE_LENGTH;
xGm_.SetGlobalBuffer((__gm__ X_T *)x);
yOutGm_.SetGlobalBuffer((__gm__ Y_T *)y);
wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
indicesGm_.SetGlobalBuffer((__gm__ int64_t *)indices, indicesSize);
pipe_->InitBuffer(inQueueX_, BUFFER_NUM, TILE_LENGTH * sizeof(X_T));
pipe_->InitBuffer(inQueueW_, BUFFER_NUM, TILE_LENGTH * sizeof(W_T));
pipe_->InitBuffer(tmpBufferX_, TILE_LENGTH * sizeof(float));
pipe_->InitBuffer(tmpBufferW_, TILE_LENGTH * sizeof(float));
pipe_->InitBuffer(outQueueY_, 1, maxLoRARank_ * sizeof(Y_T));
pipe_->InitBuffer(outBufferY_, maxLoRARank_ * sizeof(float));
}
__aicore__ inline void Process()
{
int64_t blockIdx = AscendC::GetBlockIdx();
int64_t startIdx = blockIdx * numTokensPerCore_;
int64_t endIdx = startIdx + numTokensPerCore_;
if (endIdx > batchSize_) {
endIdx = batchSize_;
}
for (int64_t idx = startIdx; idx < endIdx; idx++) {
// set up LoRA index
CopyInIndex(idx);
if (reqLoRAIndex_ < 0) {
continue;
}
reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_;
if (incremental_) {
ProcessImpl<true>(idx);
} else {
ProcessImpl<false>(idx);
}
ScaleOutput();
CopyOut(idx);
}
}
private:
template <bool INCREMENTAL_MODE>
__aicore__ inline void ProcessImpl(const int64_t idx)
{
AscendC::LocalTensor<float> yOutLocal = outBufferY_.Get<float>();
if constexpr (!INCREMENTAL_MODE) {
CopyInX(idx, 0, inputHiddenDim_);
AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, inputHiddenDim_);
pipe_barrier(PIPE_V);
inQueueX_.FreeTensor(xLocal);
}
for (int i = 0; i < maxLoRARank_; i++) {
float acc(0);
for (int32_t j = 0; j < inputHiddenDim_ / TILE_LENGTH; j++) {
if constexpr (INCREMENTAL_MODE) {
CopyInX(idx, j);
}
CopyInW(i, j);
Compute<INCREMENTAL_MODE>(acc);
}
CopyAndComputeLastIteration<INCREMENTAL_MODE>(idx, i, acc);
yOutLocal.SetValue(i, acc);
}
}
__aicore__ inline void CopyInIndex(const int64_t idx)
{
// look up the LoRA index
reqLoRAIndex_ = indicesGm_.GetValue(idx);
}
__aicore__ inline void CopyInX(const int64_t idx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
{
AscendC::LocalTensor<X_T> xLocal = inQueueX_.AllocTensor<X_T>();
DataCopy(xLocal, xGm_[inputHiddenDim_ * idx + colIdx * TILE_LENGTH], numElements);
inQueueX_.EnQue(xLocal);
}
__aicore__ inline void CopyInW(int32_t rowIdx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
{
AscendC::LocalTensor<W_T> wLocal = inQueueW_.AllocTensor<W_T>();
DataCopy(wLocal, wGm_[reqLoRAWeightOffset_ + rowIdx * inputHiddenDim_ + colIdx * TILE_LENGTH], numElements);
inQueueW_.EnQue(wLocal);
}
template <bool INCREMENTAL_MODE>
__aicore__ inline void Compute(float &acc, int32_t numElements = TILE_LENGTH)
{
AscendC::LocalTensor<W_T> wLocal = inQueueW_.DeQue<W_T>();
AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
AscendC::LocalTensor<float> wTmpTensor = tmpBufferW_.Get<float>();
if constexpr (INCREMENTAL_MODE) {
AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, numElements);
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
pipe_barrier(PIPE_V);
inQueueX_.FreeTensor(xLocal);
inQueueW_.FreeTensor(wLocal);
} else {
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
pipe_barrier(PIPE_V);
inQueueW_.FreeTensor(wLocal);
}
// dot product of the one tile of X and W
Mul(wTmpTensor, xTmpTensor, wTmpTensor, numElements);
pipe_barrier(PIPE_V);
// reduce sum generate one number, which is the summation of all the dot product
ReduceSum<float>(wTmpTensor, wTmpTensor, wTmpTensor, numElements);
pipe_barrier(PIPE_V);
acc += wTmpTensor.GetValue(0);
}
template <bool INCREMENTAL_MODE>
__aicore__ inline void CopyAndComputeLastIteration(const int64_t idx, int32_t rowIdx, float &acc)
{
int32_t colIdx = inputHiddenDim_ / TILE_LENGTH;
int32_t remaining = inputHiddenDim_ % TILE_LENGTH;
if (remaining == 0) {
return;
}
if constexpr (INCREMENTAL_MODE) {
CopyInX(idx, colIdx, remaining);
}
CopyInW(rowIdx, colIdx, remaining);
Compute<INCREMENTAL_MODE>(acc, remaining);
}
__aicore__ inline void ScaleOutput()
{
AscendC::LocalTensor<float> yLocal = outBufferY_.Get<float>();
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.AllocTensor<Y_T>();
Muls(yOutLocal, yLocal, scale_, maxLoRARank_);
pipe_barrier(PIPE_V);
outQueueY_.EnQue<Y_T>(yOutLocal);
}
__aicore__ inline void CopyOut(const int64_t idx)
{
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.DeQue<Y_T>();
DataCopy(yOutGm_[maxLoRARank_ * idx], yOutLocal, maxLoRARank_);
outQueueY_.FreeTensor(yOutLocal);
}
private:
AscendC::TPipe *pipe_;
AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueX_, inQueueW_;
AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueY_;
AscendC::TBuf<AscendC::QuePosition::VECCALC> tmpBufferX_, tmpBufferW_, outBufferY_;
AscendC::GlobalTensor<X_T> xGm_;
AscendC::GlobalTensor<W_T> wGm_;
AscendC::GlobalTensor<int64_t> indicesGm_;
AscendC::GlobalTensor<Y_T> yOutGm_;
uint32_t batchSize_;
uint32_t numTokensPerCore_;
uint32_t inputHiddenDim_;
uint32_t maxLoRARank_;
float scale_;
uint32_t singleLoRAWeightLen_;
int64_t reqLoRAIndex_;
uint64_t reqLoRAWeightOffset_;
bool incremental_;
};
#define BGMV_SHRINK_TYPE_DECLARE(TYPE) \
extern "C" __global__ __aicore__ void bgmv_shrink_##TYPE(__gm__ void* x, __gm__ void* weight, __gm__ void* indices,\
uint32_t indicesSize, __gm__ void* y, uint32_t batchSize, \
uint32_t numTokensPerCore, uint32_t inputHiddenDim, \
uint32_t maxLoRARank, float scale) \
{ \
AscendC::TPipe pipe; \
BGMVShrink<TYPE> op(&pipe); \
op.Init(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); \
op.Process(); \
}
// declare all dtype kernel
BGMV_SHRINK_TYPE_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
BGMV_SHRINK_TYPE_DECLARE(bfloat16_t)
#endif
namespace vllm_ascend {
extern void bgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight, void* indices, uint32_t indicesSize,
void* y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
uint32_t maxLoRARank, float scale)
{
uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
if (type == AscendType::FP16) {
bgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore,
inputHiddenDim, maxLoRARank, scale);
} else if (type == AscendType::BF16) {
#if (__CCE_AICORE__ >= 220)
bgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore,
inputHiddenDim, maxLoRARank, scale);
#endif
} else {
return;
}
}
} // namespace vllm_ascend

View File

@@ -0,0 +1,378 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*/
#include "kernel_operator.h"
#include "kernel_tensor_impl.h"
#include "kernel_type.h"
#include "types.h"
#include "utils.h"
using vllm_ascend::AccType;
template<typename scalar_t>
class GetMaskedInputAndMask {
public:
__aicore__ inline GetMaskedInputAndMask() {}
__aicore__ inline ~GetMaskedInputAndMask() {
pipe.Reset();
}
__aicore__ inline void Init(
__gm__ scalar_t* input,
__gm__ scalar_t* masked_input,
__gm__ bool* mask_out,
const int64_t org_vocab_start_index,
const int64_t org_vocab_end_index,
const int64_t num_org_vocab_padding,
const int64_t added_vocab_start_index,
const int64_t added_vocab_end_index,
const int64_t size)
{
// Initialize basic parameters
input_ = input;
masked_input_ = masked_input;
mask_out_ = mask_out;
org_vocab_start_index_ = org_vocab_start_index;
org_vocab_end_index_ = org_vocab_end_index;
size_ = ((size + 31) / 32) * 32;
added_offset_ = added_vocab_start_index -
(org_vocab_end_index - org_vocab_start_index) -
num_org_vocab_padding;
added_vocab_start_index_ = added_vocab_start_index;
added_vocab_end_index_ = added_vocab_end_index;
// Initialize global tensors
inputGlobal.SetGlobalBuffer(input);
maskedOutputGlobal.SetGlobalBuffer(masked_input);
maskOutGlobal.SetGlobalBuffer(mask_out);
// Initialize queues
pipe.InitBuffer(inQueue, 1, size_ * sizeof(scalar_t));
pipe.InitBuffer(outQueue, 1, size_ * sizeof(scalar_t));
pipe.InitBuffer(maskQueue, 1, size_ * sizeof(bool));
// Initialize calculation buffers
// NOTE: calc_buf_1 and calc_buf_2 are also used for int16 casting on older archs.
pipe.InitBuffer(calc_buf_1, size_ * sizeof(float));
pipe.InitBuffer(calc_buf_2, size_ * sizeof(float));
// Initialize result queues
pipe.InitBuffer(result_ge_que, BUFFER_NUM, size_ * sizeof(float));
pipe.InitBuffer(result_le_que, BUFFER_NUM, size_ * sizeof(float));
pipe.InitBuffer(result_org_mask_que, BUFFER_NUM, size_ * sizeof(float));
pipe.InitBuffer(result_add_mask_que, BUFFER_NUM, size_ * sizeof(float));
// Initialize temporary buffers
pipe.InitBuffer(start_buf, size_ * sizeof(float));
pipe.InitBuffer(end_buf, size_ * sizeof(float));
pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float)); // Also used for half intermediate in casting
pipe.InitBuffer(validOffset_buf, size_ * sizeof(float));
pipe.InitBuffer(vocabMask_buf_, size_ * sizeof(int8_t));
pipe.InitBuffer(ones_buf_, size_ * sizeof(float));
}
__aicore__ inline void Process()
{
CopyIn();
Compute();
CopyOut();
}
private:
__aicore__ inline void CopyIn()
{
AscendC::LocalTensor<scalar_t> inputLocal = inQueue.AllocTensor<scalar_t>();
AscendC::DataCopy(inputLocal, inputGlobal, size_);
inQueue.EnQue(inputLocal);
}
__aicore__ inline void CompareWithValue(
AscendC::LocalTensor<int8_t>& result,
const AscendC::LocalTensor<float>& input,
const AscendC::LocalTensor<float>& compare_value,
bool is_greater_equal) {
AscendC::LocalTensor<float> compute_buf = calc_buf_1.Get<float>();
if (is_greater_equal) {
AscendC::Max(compute_buf, input, compare_value, size_);
AscendC::Sub(compute_buf, compare_value, compute_buf, size_);
} else {
AscendC::Max(compute_buf, input, compare_value, size_);
AscendC::Sub(compute_buf, compute_buf, compare_value, size_);
}
AscendC::Abs(compute_buf, compute_buf, size_);
AscendC::Mins(compute_buf, compute_buf, MIN_ACCURACY_FP32, size_);
AscendC::Muls(compute_buf, compute_buf, MAX_MUL_1_FP32, size_);
AscendC::Muls(compute_buf, compute_buf, MAX_MUL_1_FP32, size_);
AscendC::Muls(compute_buf, compute_buf, MAX_MUL_2_FP32, size_);
AscendC::Adds(compute_buf, compute_buf, NEGATIVE_ONE_FP32, size_);
AscendC::Abs(compute_buf, compute_buf, size_);
AscendC::LocalTensor<half> compute_buf_fp16 = calc_buf_2.Get<half>();
AscendC::Cast(compute_buf_fp16, compute_buf, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(result, compute_buf_fp16, AscendC::RoundMode::CAST_NONE, size_);
}
__aicore__ inline void ComputeRangeMask(
AscendC::LocalTensor<int8_t>& range_mask,
const AscendC::LocalTensor<float>& input,
const float start_value,
const float end_value) {
AscendC::LocalTensor<float> start_value_tensor = start_buf.Get<float>();
AscendC::LocalTensor<float> end_value_tensor = end_buf.Get<float>();
AscendC::Duplicate(start_value_tensor, start_value, size_);
AscendC::Duplicate(end_value_tensor, end_value, size_);
AscendC::LocalTensor<int8_t> ge_result = result_ge_que.AllocTensor<int8_t>();
AscendC::LocalTensor<int8_t> lt_result = result_le_que.AllocTensor<int8_t>();
CompareWithValue(ge_result, start_value_tensor, input, true);
CompareWithValue(lt_result, input, end_value_tensor, false);
#if (__CCE_AICORE__ >= 220)
AscendC::And(range_mask, ge_result, lt_result, size_);
#else
{
// WORKAROUND for older arch
// No direct int8->int16 cast. Use half as intermediate.
// No direct int8 And. Use int16 And.
AscendC::LocalTensor<int16_t> ge_result_i16 = calc_buf_1.Get<int16_t>();
AscendC::LocalTensor<int16_t> lt_result_i16 = calc_buf_2.Get<int16_t>();
AscendC::LocalTensor<int16_t> range_mask_i16 = ge_result_i16;
// Use a temporary buffer for half type
AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();
// 1. Cast inputs: int8_t -> half -> int16_t
AscendC::Cast(tmp_half, ge_result, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(ge_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(tmp_half, lt_result, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(lt_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
// 2. Perform And on int16_t tensors
AscendC::And(range_mask_i16, ge_result_i16, lt_result_i16, size_);
// 3. Cast result back: int16_t -> half -> int8_t
AscendC::Cast(tmp_half, range_mask_i16, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(range_mask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
}
#endif
}
__aicore__ inline void Compute() {
AscendC::LocalTensor<scalar_t> inputLocal = inQueue.DeQue<scalar_t>();
AscendC::LocalTensor<scalar_t> maskedLocal = outQueue.AllocTensor<scalar_t>();
AscendC::LocalTensor<int8_t> maskLocal = maskQueue.AllocTensor<int8_t>();
AscendC::LocalTensor<float> inputFloat = inputFloat_buf.Get<float>();
AscendC::Cast(inputFloat, inputLocal, AscendC::RoundMode::CAST_NONE, size_);
AscendC::LocalTensor<int8_t> orgVocabMask = result_org_mask_que.AllocTensor<int8_t>();
ComputeRangeMask(orgVocabMask,
inputFloat,
static_cast<float>(org_vocab_start_index_),
static_cast<float>(org_vocab_end_index_));
AscendC::LocalTensor<int8_t> addedVocabMask = result_add_mask_que.AllocTensor<int8_t>();
ComputeRangeMask(addedVocabMask,
inputFloat,
static_cast<float>(added_vocab_start_index_),
static_cast<float>(added_vocab_end_index_));
AscendC::LocalTensor<float> validOffset = validOffset_buf.Get<float>();
AscendC::LocalTensor<float> constOrgStartIndex = start_buf.Get<float>();
AscendC::Duplicate(constOrgStartIndex, float(org_vocab_start_index_), size_);
AscendC::LocalTensor<half> orgVocabMask_fp16;
AscendC::LocalTensor<float> orgVocabMask_fp32;
AscendC::Cast(orgVocabMask_fp16, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(orgVocabMask_fp32, orgVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Mul(validOffset, constOrgStartIndex, orgVocabMask_fp32, size_);
AscendC::LocalTensor<float> addedOffset;
AscendC::LocalTensor<float> addedOffsetTensor = end_buf.Get<float>();
AscendC::Duplicate(addedOffsetTensor, float(added_offset_), size_);
AscendC::LocalTensor<half> addedVocabMask_fp16;
AscendC::LocalTensor<float> addedVocabMask_fp32;
AscendC::Cast(addedVocabMask_fp16, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(addedVocabMask_fp32, addedVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Mul(addedOffset, addedOffsetTensor, addedVocabMask_fp32, size_);
AscendC::Add(validOffset, validOffset, addedOffset, size_);
AscendC::LocalTensor<int8_t> vocabMask = vocabMask_buf_.Get<int8_t>();
#if (__CCE_AICORE__ >= 220)
AscendC::Or(vocabMask,
orgVocabMask,
addedVocabMask,
size_);
#else
{
// WORKAROUND for older arch
// No direct int8->int16 cast. Use half as intermediate.
// No direct int8 Or. Use int16 Or.
AscendC::LocalTensor<int16_t> orgVocabMask_i16 = calc_buf_1.Get<int16_t>();
AscendC::LocalTensor<int16_t> addedVocabMask_i16 = calc_buf_2.Get<int16_t>();
AscendC::LocalTensor<int16_t> vocabMask_i16 = orgVocabMask_i16;
// Use a temporary buffer for half type. inputFloat_buf is free now.
AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();
// 1. Cast inputs: int8_t -> half -> int16_t
AscendC::Cast(tmp_half, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(orgVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(tmp_half, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(addedVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
// 2. Perform Or on int16_t tensors
AscendC::Or(vocabMask_i16, orgVocabMask_i16, addedVocabMask_i16, size_);
// 3. Cast result back: int16_t -> half -> int8_t
AscendC::Cast(tmp_half, vocabMask_i16, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(vocabMask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
}
#endif
AscendC::Sub(inputFloat, inputFloat, validOffset, size_);
AscendC::LocalTensor<half> vocabMask_fp16;
AscendC::LocalTensor<float> vocabMask_fp32;
AscendC::Cast(vocabMask_fp16, vocabMask, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(vocabMask_fp32, vocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Mul(inputFloat, inputFloat, vocabMask_fp32, size_);
AscendC::Cast(maskedLocal, inputFloat, AscendC::RoundMode::CAST_CEIL, size_);
outQueue.EnQue(maskedLocal);
AscendC::LocalTensor<float> ones_tensor = ones_buf_.Get<float>();
AscendC::Duplicate(ones_tensor, (float)1, size_);
AscendC::LocalTensor<float> maskLocal_fp32;
AscendC::Sub(maskLocal_fp32, ones_tensor, vocabMask_fp32, size_);
AscendC::LocalTensor<half> maskLocal_fp16;
AscendC::Cast(maskLocal_fp16, maskLocal_fp32, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(maskLocal, maskLocal_fp16, AscendC::RoundMode::CAST_NONE, size_);
maskQueue.EnQue(maskLocal);
inQueue.FreeTensor(inputLocal);
}
__aicore__ inline void CopyOut()
{
AscendC::LocalTensor<scalar_t> maskedLocal = outQueue.DeQue<scalar_t>();
AscendC::LocalTensor<bool> maskLocal = maskQueue.DeQue<bool>();
AscendC::DataCopy(maskedOutputGlobal, maskedLocal, size_);
AscendC::DataCopy(maskOutGlobal, maskLocal, size_);
outQueue.FreeTensor(maskedLocal);
maskQueue.FreeTensor(maskLocal);
}
private:
static constexpr int32_t BUFFER_NUM = 2;
AscendC::TPipe pipe;
AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueue;
AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue, maskQueue;
AscendC::GlobalTensor<scalar_t> inputGlobal, maskedOutputGlobal;
AscendC::GlobalTensor<bool> maskOutGlobal;
AscendC::TBuf<AscendC::TPosition::VECCALC> calc_buf_1;
AscendC::TBuf<AscendC::TPosition::VECCALC> calc_buf_2;
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_ge_que;
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_le_que;
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_org_mask_que;
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_add_mask_que;
// Temporary buffers
AscendC::TBuf<AscendC::TPosition::VECCALC> start_buf;
AscendC::TBuf<AscendC::TPosition::VECCALC> end_buf;
AscendC::TBuf<AscendC::TPosition::VECCALC> inputFloat_buf;
AscendC::TBuf<AscendC::TPosition::VECCALC> validOffset_buf;
AscendC::TBuf<AscendC::TPosition::VECCALC> vocabMask_buf_;
AscendC::TBuf<AscendC::TPosition::VECCALC> ones_buf_;
__gm__ scalar_t *input_, *masked_input_;
__gm__ bool *mask_out_;
int64_t size_;
int64_t org_vocab_start_index_, org_vocab_end_index_;
int64_t added_vocab_start_index_, added_vocab_end_index_;
int64_t added_offset_;
static constexpr float MIN_ACCURACY_FP32 = 1.1754943508222875e-38;
static constexpr float MAX_MUL_1_FP32 = 1125899906842624;
static constexpr float MAX_MUL_2_FP32 = 67108864;
static constexpr float NEGATIVE_ONE_FP32 = -1.0f;
};
extern "C" __global__ __aicore__ void get_masked_input_and_mask_kernel(
__gm__ int32_t* input,
__gm__ int32_t* masked_input,
__gm__ bool* mask_out,
const int64_t org_vocab_start_index,
const int64_t org_vocab_end_index,
const int64_t num_org_vocab_padding,
const int64_t added_vocab_start_index,
const int64_t added_vocab_end_index,
const int64_t size,
const uint32_t loop_cnt,
const uint32_t aiv_num)
{
{
GetMaskedInputAndMask<int32_t> op{};
for (int64_t i = AscendC::GetBlockIdx(); i < loop_cnt; i += aiv_num) {
op.Init(input + i * size/loop_cnt,
masked_input + i * size/loop_cnt,
mask_out + i * size/loop_cnt,
org_vocab_start_index, org_vocab_end_index,
num_org_vocab_padding, added_vocab_start_index,
added_vocab_end_index, size/loop_cnt);
op.Process();
}
} // op destructor called here
}
namespace vllm_ascend {
void get_masked_input_and_mask_impl(
void* stream,
void* input,
void* masked_input,
void* mask_out,
const int64_t org_vocab_start_index,
const int64_t org_vocab_end_index,
const int64_t num_org_vocab_padding,
const int64_t added_vocab_start_index,
const int64_t added_vocab_end_index,
const int64_t size,
const uint32_t loop_cnt,
const uint32_t aiv_num)
{
get_masked_input_and_mask_kernel<<<aiv_num, nullptr, stream>>>(
static_cast<int32_t*>(input),
static_cast<int32_t*>(masked_input),
static_cast<bool*>(mask_out),
org_vocab_start_index,
org_vocab_end_index,
num_org_vocab_padding,
added_vocab_start_index,
added_vocab_end_index,
size,
loop_cnt,
aiv_num);
}
} // namespace vllm_ascend

View File

@@ -0,0 +1,372 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
#include <stdio.h>
#include "types.h"
#include "utils.h"
using vllm_ascend::AccType;
using vllm_ascend::local_mem_copy;
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
// NOTE(ganyi): we use 512B as load stride for pipe, need to find another way to
// retrieve this size from runtime for more Soc support
#if (__CCE_AICORE__ >= 220)
static int constexpr loadSize = 512;
#else
static int constexpr loadSize = 1024 * 4;
#endif
using dst_t = scalar_t;
using acc_t = typename AccType<scalar_t>::type;
// only half tensor have cast instruct to int8, hardcode acc_dst_t as half
using local_scalar_t = AscendC::LocalTensor<scalar_t>;
using local_acc_t = AscendC::LocalTensor<acc_t>;
using local_dst_t = AscendC::LocalTensor<dst_t>;
public:
__aicore__ inline RotaryEmbedding()
{
}
// Allocate buffers for input and output queue and the temp buffer used during kernel compute process,
// this init process happens only in the kernel compute on a single vector core.
__aicore__ inline void init(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst,
__gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache,
const int rotDim, const int64_t dstQueryStride,
const int64_t dstKeyStride, const int64_t queryStride, const int64_t keyStride,
const int numHeads, const int numKvHeads, const int headSize, AscendC::TPipe *pipe)
{
pipe_ = pipe;
rotDim_ = rotDim;
// query stride and key stride is used to handle the strided tensor which is not contiguous on num_tokens dim
queryStride_ = queryStride;
keyStride_ = keyStride;
dstQueryStride_ = dstQueryStride;
dstKeyStride_ = dstKeyStride;
numHeads_ = numHeads;
numKvHeads_ = numKvHeads;
headSize_ = headSize;
embedDim_ = rotDim / 2;
pipe_->InitBuffer(inQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
pipe_->InitBuffer(inQueSinCos_, 1 /* buffer_num */, rotDim_ * sizeof(scalar_t) /* buffer_size */);
pipe_->InitBuffer(outQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
// 2 temporary calculation buffer
calcTmpBufferOffset_ = 0;
// 1 upcast buffer for bf16 (headSize)
upcastInputBufferOffset_ = calcTmpBufferOffset_ + sizeof(acc_t) * embedDim_ * 2;
// 1 upcast temp buffer for bf16 (2 * embed_dim)
upcastTempBufferOffset_ = upcastInputBufferOffset_ + sizeof(acc_t) * headSize_;
// 2 sin cos upcast buffer for bf16
cosSinUpcastBufferOffset_ = upcastTempBufferOffset_ + sizeof(acc_t) * 2 * embedDim_;
// 2. bf16 path: needs 2 cos sin upcast buffer size
// 3. fp16 path: needs 2 temporary calculation buffer size
tempBufferSize_ = cosSinUpcastBufferOffset_ + 2 * embedDim_ * sizeof(acc_t);
// need to consider upcast the bf16 to fp32, so we might need 4 buffer just in case
// 2 temporary buffer, 2 input buffer, 1 cos buffer, 1 sin buffer, 2 scale buffer (headSize), 2 zp
// buffer(headSize int8), 1 dst_temp buffer(headSize, int32)
pipe_->InitBuffer(calcBuf_, tempBufferSize_ /* buffer_size */);
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
pipe_->InitBuffer(copyBuf_, loadSize);
}
}
__aicore__ inline void update_mem_offset(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst,
__gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache,
const int rotDim, const int64_t dstQueryStride, const int64_t dstKeyStride,
const int64_t queryStride, const int64_t keyStride, const int numHeads,
const int numKvHeads, const int headSize, const int64_t idx)
{
int64_t pos = positions[idx];
cosSin_.SetGlobalBuffer(cosSinCache + pos * rotDim_, rotDim_);
query_.SetGlobalBuffer(query + queryStride * idx, headSize * numHeads_);
key_.SetGlobalBuffer(key + keyStride * idx, headSize * numKvHeads_);
queryDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(queryDst) + dstQueryStride * idx,
headSize * numHeads_);
keyDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(keyDst) + dstKeyStride * idx, headSize * numKvHeads_);
}
// compute per head for neox on bf16
template <typename acc_t_, typename std::enable_if<!std::is_same_v<acc_t_, scalar_t>, void>::type * = nullptr>
__aicore__ inline void
neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor<acc_t_> sin, AscendC::LocalTensor<acc_t_> cos,
AscendC::LocalTensor<acc_t_> upcastInputBuffer, AscendC::LocalTensor<acc_t_> calcTmpBuffer)
{
// slice dst
local_dst_t dstX = dst;
local_dst_t dstY = dst[embedDim_];
// slice src
local_scalar_t srcX = src;
local_scalar_t srcY = src[embedDim_];
// slice temp buffer
local_acc_t calcTmpBufferX = calcTmpBuffer;
local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_];
// slice upcast input buffer
local_acc_t upcastBufferX = upcastInputBuffer;
local_acc_t upcastBufferY = upcastBufferX[embedDim_];
// dst x calc
Cast(upcastInputBuffer, src, AscendC::RoundMode::CAST_NONE, headSize_);
Mul(calcTmpBufferX, upcastBufferX, cos, embedDim_);
Mul(calcTmpBufferY, upcastBufferY, sin, embedDim_);
Sub(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_);
Cast(dstX, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_);
// dst y calc
Mul(calcTmpBufferX, upcastBufferX, sin, embedDim_);
Mul(calcTmpBufferY, upcastBufferY, cos, embedDim_);
Add(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_);
Cast(dstY, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_);
}
// compute per head output for neox
template <typename acc_t_, typename std::enable_if<std::is_same_v<acc_t_, scalar_t>, void>::type * = nullptr>
__aicore__ inline void
neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor<acc_t_> sin, AscendC::LocalTensor<acc_t_> cos,
AscendC::LocalTensor<acc_t_> upcastInputBuffer, AscendC::LocalTensor<acc_t_> calcTmpBuffer)
{
// slice dst buffer
local_dst_t dstX = dst;
local_dst_t dstY = dst[embedDim_];
// slice src buffer
local_scalar_t srcX = src;
local_scalar_t srcY = src[embedDim_];
// slice temp buffer
local_acc_t calcTmpBufferX = calcTmpBuffer;
local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_];
// dst x calc
Mul(calcTmpBufferX, srcX, cos, embedDim_);
Mul(calcTmpBufferY, srcY, sin, embedDim_);
Sub(dstX, calcTmpBufferX, calcTmpBufferY, embedDim_);
// dst y calc
Mul(calcTmpBufferX, srcX, sin, embedDim_);
Mul(calcTmpBufferY, srcY, cos, embedDim_);
Add(dstY, calcTmpBufferX, calcTmpBufferY, embedDim_);
}
__aicore__ inline void compute_qk(AscendC::GlobalTensor<scalar_t> srcG, AscendC::GlobalTensor<dst_t> dstG,
local_acc_t localCos, local_acc_t localSin, local_acc_t upcastInputBuffer,
local_acc_t calcTmpBuffer, int loopCnt, int tailHeads, int loadStride,
int headNumPerLoad)
{
for (int loopNum = 0; loopNum < loopCnt; ++loopNum) {
local_scalar_t src = inQue_.AllocTensor<scalar_t>();
local_dst_t dst = outQue_.AllocTensor<dst_t>();
AscendC::DataCopy(src, srcG[loopNum * loadStride], loadStride);
inQue_.EnQue(src);
local_scalar_t srcDeque = inQue_.DeQue<scalar_t>();
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
int elem_num = loadStride / sizeof(scalar_t);
AscendC::LocalTensor<acc_t> upBuffer = copyBuf_.GetWithOffset<acc_t>(elem_num, 0);
Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num);
Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num);
} else {
local_mem_copy(dst, srcDeque, loadStride);
}
for (int i = 0; i < headNumPerLoad; ++i) {
neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer,
calcTmpBuffer);
}
outQue_.EnQue(dst);
local_dst_t dstDeque = outQue_.DeQue<dst_t>();
AscendC::DataCopy(dstG[loopNum * loadStride], dstDeque, loadStride);
outQue_.FreeTensor(dstDeque);
inQue_.FreeTensor(srcDeque);
}
// process tail
{
local_scalar_t src = inQue_.AllocTensor<scalar_t>();
local_dst_t dst = outQue_.AllocTensor<dst_t>();
AscendC::DataCopy(src, srcG[loopCnt * loadStride], tailHeads * headSize_);
inQue_.EnQue(src);
local_scalar_t srcDeque = inQue_.DeQue<scalar_t>();
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
int elem_num = tailHeads * headSize_ / sizeof(scalar_t);
AscendC::LocalTensor<acc_t> upBuffer = copyBuf_.GetWithOffset<acc_t>(elem_num, 0);
Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num);
Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num);
} else {
local_mem_copy(dst, srcDeque, tailHeads * headSize_);
}
for (int i = 0; i < tailHeads; ++i) {
neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer,
calcTmpBuffer);
}
outQue_.EnQue(dst);
local_dst_t dstDeque = outQue_.DeQue<dst_t>();
AscendC::DataCopy(dstG[loopCnt * loadStride], dstDeque, tailHeads * headSize_);
outQue_.FreeTensor(dstDeque);
inQue_.FreeTensor(srcDeque);
}
}
__aicore__ inline void compute_function()
{
local_scalar_t cosSinLocal = inQueSinCos_.AllocTensor<scalar_t>();
AscendC::DataCopy(cosSinLocal, cosSin_, embedDim_ * 2);
inQueSinCos_.EnQue(cosSinLocal);
local_scalar_t localSinCosDeque = inQueSinCos_.DeQue<scalar_t>();
local_scalar_t localCos = localSinCosDeque;
local_scalar_t localSin = localSinCosDeque[embedDim_];
local_acc_t calcTmpBuffer;
local_acc_t upcastInputBuffer;
local_acc_t upcastTempBuffer;
local_acc_t cosSinUpcastBuffer;
local_acc_t scaleBuffer;
local_acc_t offsetBuffer;
calcTmpBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, calcTmpBufferOffset_);
upcastInputBuffer = calcBuf_.GetWithOffset<acc_t>(headSize_, upcastInputBufferOffset_);
upcastTempBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, upcastTempBufferOffset_);
cosSinUpcastBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, cosSinUpcastBufferOffset_);
local_acc_t cosAccBuffer;
local_acc_t sinAccBuffer;
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
Cast(cosSinUpcastBuffer, localSinCosDeque, AscendC::RoundMode::CAST_NONE, 2 * embedDim_);
cosAccBuffer = cosSinUpcastBuffer;
sinAccBuffer = cosSinUpcastBuffer[embedDim_];
} else {
cosAccBuffer = localCos;
sinAccBuffer = localSin;
}
constexpr const int loadSizeByElem = loadSize / sizeof(scalar_t);
int64_t headNumPerLoad = loadSizeByElem / headSize_;
int64_t loopCnt = numHeads_ / headNumPerLoad;
int64_t tailHeads = numHeads_ - loopCnt * headNumPerLoad;
int64_t loadStride = headNumPerLoad * headSize_;
int64_t loopCntKv = numKvHeads_ / headNumPerLoad;
int64_t tailHeadsKv = numKvHeads_ - loopCntKv * headNumPerLoad;
compute_qk(query_, queryDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer,
calcTmpBuffer, loopCnt, tailHeads, loadStride, headNumPerLoad);
compute_qk(key_, keyDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer, calcTmpBuffer,
loopCntKv, tailHeadsKv, loadStride, headNumPerLoad);
inQueSinCos_.FreeTensor(localSinCosDeque);
}
private:
AscendC::TPipe *pipe_;
AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQue_, inQueSinCos_;
AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQue_;
AscendC::TBuf<AscendC::TPosition::VECCALC> calcBuf_;
AscendC::TBuf<AscendC::TPosition::VECCALC> copyBuf_;
AscendC::GlobalTensor<dst_t> queryDst_;
AscendC::GlobalTensor<dst_t> keyDst_;
AscendC::GlobalTensor<scalar_t> query_;
AscendC::GlobalTensor<scalar_t> key_;
AscendC::GlobalTensor<scalar_t> cosSin_;
int rotDim_;
int embedDim_;
int64_t queryStride_;
int64_t keyStride_;
int64_t dstQueryStride_;
int64_t dstKeyStride_;
int numHeads_;
int numKvHeads_;
int headSize_;
int calcTmpBufferOffset_;
int upcastInputBufferOffset_;
int upcastTempBufferOffset_;
int cosSinUpcastBufferOffset_;
int tempBufferSize_;
};
// Note: Need to use macro to instaniate all the target functions here, for the current build system dose not support template call in cpp
// We use C style symbol here for kernel compilation, cpp style kernel entry may lead to compilation failure
#define ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, NEOX) \
extern "C" __global__ __aicore__ void rope_custom_##NEOX##_##TYPE( \
__gm__ int64_t* positions, __gm__ void* queryDst, __gm__ void* keyDst, __gm__ TYPE* query, __gm__ TYPE* key, \
__gm__ TYPE* cosSinCache, const int rotDim, const int64_t queryStride, const int64_t keyStride, \
const int64_t dstQueryStride, const int64_t dstKeyStride, const int numHeads, const int numKvHeads, \
const int headSize, const int64_t numTokens, const int loopNum, const int coreNum) \
{ \
AscendC::TPipe pipe; \
RotaryEmbedding<TYPE, NEOX> op{}; \
op.init(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \
queryStride, keyStride, numHeads, numKvHeads, headSize, &pipe); \
for (int64_t i = AscendC::GetBlockIdx(); i < numTokens; i += coreNum) { \
op.update_mem_offset(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \
queryStride, keyStride, numHeads, numKvHeads, headSize, i); \
op.compute_function(); \
} \
}
#define ROPE_CUSTOM_KERNEL_DECLARE(TYPE) \
ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, true); \
ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, false);
// Declare all the kernel entry here
ROPE_CUSTOM_KERNEL_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t)
#endif
namespace vllm_ascend {
#define ROTARY_EMBEDDING_KERNEL_CALL(TYPE) \
if (isNeox) \
rope_custom_true_##TYPE<<<blockDim, nullptr, stream>>>( \
positions, queryDst, keyDst, reinterpret_cast<TYPE *>(query), reinterpret_cast<TYPE *>(key), \
reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \
numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim); \
else \
rope_custom_false_##TYPE<<<blockDim, nullptr, stream>>>( \
positions, queryDst, keyDst, reinterpret_cast<TYPE *>(query), reinterpret_cast<TYPE *>(key), \
reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \
numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim);
// maximum number for runtime to launch a ascendc kernel.
// we use this to constrain the maximum number of block size
static const int64_t maxParallelSize = 65535;
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
uint32_t aivNum)
{
int blockDim = maxParallelSize > numTokens ? numTokens : maxParallelSize;
if (type == AscendType::FP16) {
ROTARY_EMBEDDING_KERNEL_CALL(half);
}
#if (__CCE_AICORE__ >= 220)
else if (type == AscendType::BF16) {
ROTARY_EMBEDDING_KERNEL_CALL(bfloat16_t);
}
#endif
else {
return;
}
}
} // namespace vllm_ascend

View File

@@ -0,0 +1,389 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
#include "types.h"
template <typename scalar_t>
class SGMVExpand {
public:
using X_T = float;
using W_T = scalar_t;
using Y_T = scalar_t;
static constexpr uint64_t LORA_RANK_8 = 8;
static constexpr uint64_t LORA_RANK_16 = 16;
static constexpr uint64_t LORA_RANK_32 = 32;
static constexpr uint64_t LORA_RANK_64 = 64;
static constexpr uint64_t SUPPORTED_RANKS[] = {LORA_RANK_8, LORA_RANK_16, LORA_RANK_32, LORA_RANK_64};
static constexpr int32_t BUFFER_NUM = 2;
// The vector unit reads 8 blocks (32 bytes each and 256 bytes in total) of contiguous data each time.
static constexpr int32_t NUM_BYTES_PER_REPEAT = 256;
static constexpr int32_t NUM_BLOCKS_PER_REPEAT = 8;
// The maximum number of elements in a single iteration is 256 / sizeof(intermediate data type).
static constexpr int32_t NUM_ELEMENTS_PER_REPEAT = NUM_BYTES_PER_REPEAT / sizeof(float);
// Mask is used to control the elements that participate in computation in each iteration.
static constexpr int32_t MASK_COUNT = NUM_BYTES_PER_REPEAT / sizeof(float);
// Refer to numOutputElementsPerInputTile_ initialization for the constraints on the following constants.
static constexpr int32_t W_IN_TILE_NUM_ELEMENTS = 8192;
static constexpr int32_t Y_OUT_TILE_NUM_ELEMENTS = 4096;
static constexpr int32_t BLOCK_REDUCE_NUM_REPEATS = W_IN_TILE_NUM_ELEMENTS / NUM_ELEMENTS_PER_REPEAT;
// BlockReduceSum would generate(BLOCK_REDUCE_NUM_REPEATS * NUM_BLOCKS_PER_REPEAT)floats.
// So need to read them all and apply PairReduceSum
static constexpr int32_t PAIR_REDUCE_NUM_REPEATS_16 =
(BLOCK_REDUCE_NUM_REPEATS * NUM_BLOCKS_PER_REPEAT + NUM_ELEMENTS_PER_REPEAT - 1) / NUM_ELEMENTS_PER_REPEAT;
// The second PairReduceSum for rank=32, needs half of the repetition that happened for rank=16.
// Same for rank=64, we do not support ranks greater than 64.
static constexpr int32_t PAIR_REDUCE_NUM_REPEATS_32 = (PAIR_REDUCE_NUM_REPEATS_16 + 1) / 2;
public:
__aicore__ inline SGMVExpand(AscendC::TPipe* pipe) : pipe_(pipe) {}
__aicore__ inline void Init(__gm__ void* x, __gm__ void* weight, __gm__ void* loraIndices, uint32_t loraIndicesSize,
__gm__ void* seqLen, uint32_t seqLenSize, __gm__ void* yIn, __gm__ void* yOut,
uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank,
uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim)
{
batchSize_ = batchSize;
numTokensPerCore_ = numTokensPerCore;
maxLoRARank_ = maxLoRARank;
outputHiddenDim_ = outputHiddenDim;
sliceOffset_ = sliceOffset;
outputFullDim_ = outputFullDim;
singleLoRAWeightLen_ = maxLoRARank_ * outputHiddenDim_;
xGm_.SetGlobalBuffer((__gm__ X_T *)x);
wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
yInGm_.SetGlobalBuffer((__gm__ Y_T *)yIn);
yOutGm_.SetGlobalBuffer((__gm__ Y_T *)yOut);
loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices, loraIndicesSize);
seqLenGm_.SetGlobalBuffer((__gm__ int64_t *)seqLen, seqLenSize);
pipe_->InitBuffer(inQueueX_, 1, NUM_ELEMENTS_PER_REPEAT * sizeof(X_T));
pipe_->InitBuffer(inQueueW_, BUFFER_NUM, W_IN_TILE_NUM_ELEMENTS * sizeof(W_T));
pipe_->InitBuffer(inQueueY_, BUFFER_NUM, Y_OUT_TILE_NUM_ELEMENTS * sizeof(Y_T));
pipe_->InitBuffer(outQueueY_, BUFFER_NUM, Y_OUT_TILE_NUM_ELEMENTS * sizeof(Y_T));
pipe_->InitBuffer(dupBufferX_, NUM_ELEMENTS_PER_REPEAT * sizeof(float));
pipe_->InitBuffer(tmpBufferW_, W_IN_TILE_NUM_ELEMENTS * sizeof(float));
pipe_->InitBuffer(inBufferY_, Y_OUT_TILE_NUM_ELEMENTS * sizeof(float));
pipe_->InitBuffer(tmpBufferY_, Y_OUT_TILE_NUM_ELEMENTS * sizeof(float));
// Each compute iteration would generate not one, but several output elements.
// Therefore, the following variable would determine how many output elements are calculated in each iteration.
numOutputElementsPerInputTile_ = BLOCK_REDUCE_NUM_REPEATS * (NUM_ELEMENTS_PER_REPEAT / maxLoRARank_);
numStreamInPerOutputTile_ = Y_OUT_TILE_NUM_ELEMENTS / numOutputElementsPerInputTile_;
}
__aicore__ inline void Process()
{
int64_t blockIdx = AscendC::GetBlockIdx();
int64_t startIdx = blockIdx * numTokensPerCore_;
int64_t endIdx = startIdx + numTokensPerCore_;
if (endIdx > batchSize_) {
endIdx = batchSize_;
}
for (int64_t idx = startIdx; idx < endIdx; idx++) {
yOffset_ = outputFullDim_ * idx + sliceOffset_;
// Set up LoRA index
CopyInIndex(idx);
if (reqLoRAIndex_ < 0) {
continue;
}
reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_;
CopyInX(idx);
int32_t numStreamOut = outputHiddenDim_ / Y_OUT_TILE_NUM_ELEMENTS;
for (int32_t i = 0; i < numStreamOut; i++) {
CopyInY(i);
for (int32_t j = 0; j < numStreamInPerOutputTile_; j++) {
CopyInW(i * numStreamInPerOutputTile_ + j);
Compute(j * numOutputElementsPerInputTile_);
}
ScaleOutput();
CopyOut(i);
}
ComputeLastIteration();
}
}
private:
__aicore__ inline void CopyInIndex(const int64_t idx)
{
// Look up the LoRA index
int64_t weightIdx = idx;
uint64_t i = 0;
for (; i < seqLenGm_.GetSize(); i++) {
int64_t repeatValue = seqLenGm_.GetValue(i);
if (weightIdx >= repeatValue) {
weightIdx -= repeatValue;
continue;
}
break;
}
reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? loraIndicesGm_.GetValue(i) : -1;
}
__aicore__ inline void ComputeLastIteration()
{
int32_t remainingY = outputHiddenDim_ % Y_OUT_TILE_NUM_ELEMENTS;
if (remainingY == 0) {
return;
}
int32_t numStreamOut = outputHiddenDim_ / Y_OUT_TILE_NUM_ELEMENTS;
int32_t remainingW = remainingY * maxLoRARank_;
int32_t numCompleteWTileInForLastIteration = remainingW / W_IN_TILE_NUM_ELEMENTS;
int32_t remainingWForLastRepeat = remainingW % W_IN_TILE_NUM_ELEMENTS;
CopyInY(numStreamOut, remainingY);
int32_t outputIdx = 0;
for (outputIdx = 0; outputIdx < numCompleteWTileInForLastIteration; outputIdx++) {
CopyInW(numStreamOut * numStreamInPerOutputTile_ + outputIdx);
Compute(outputIdx * numOutputElementsPerInputTile_);
}
if (remainingWForLastRepeat != 0) {
CopyInW(numStreamOut * numStreamInPerOutputTile_ + numCompleteWTileInForLastIteration,
remainingWForLastRepeat);
int32_t lastRepeatCount = remainingWForLastRepeat / NUM_ELEMENTS_PER_REPEAT;
int32_t pairReduceRepeat16 =
(lastRepeatCount * NUM_BLOCKS_PER_REPEAT + NUM_ELEMENTS_PER_REPEAT - 1) / NUM_ELEMENTS_PER_REPEAT;
int32_t pairReduceRepeat32 = (pairReduceRepeat16 + 1) / 2;
int32_t lastComputeOutputElement = outputIdx * numOutputElementsPerInputTile_;
Compute(lastComputeOutputElement, lastRepeatCount, pairReduceRepeat16, pairReduceRepeat32);
}
ScaleOutput(remainingY);
CopyOut(numStreamOut, remainingY);
}
__aicore__ inline void CopyInX(const int64_t idx)
{
AscendC::LocalTensor<X_T> xLocal = inQueueX_.AllocTensor<X_T>();
if constexpr (std::is_same_v<X_T, float>) {
DataCopy(xLocal, xGm_[maxLoRARank_ * idx], maxLoRARank_);
} else {
uint16_t blockLen = static_cast<uint16_t>(maxLoRARank_ * sizeof(X_T));
DataCopyPad(xLocal, xGm_[maxLoRARank_ * idx], {1, blockLen, 0, 0}, {});
}
inQueueX_.EnQue(xLocal);
xLocal = inQueueX_.DeQue<X_T>();
AscendC::LocalTensor<float> xDup = dupBufferX_.Get<float>();
// As we are generating multiple output elements with one API invocation,
// we need to duplicate the X vector multiple times to fill one NUM_BYTES_PER_REPEAT
if constexpr (std::is_same_v<X_T, float>) {
for (int32_t i = 0; i < NUM_ELEMENTS_PER_REPEAT; i += maxLoRARank_) {
for (int32_t j = 0; j < maxLoRARank_; j++) {
float entry = xLocal.GetValue(j);
xDup.SetValue(i + j, entry);
}
}
} else {
Cast(xDup, xLocal, AscendC::RoundMode::CAST_NONE, maxLoRARank_);
pipe_barrier(PIPE_V);
for (int32_t i = maxLoRARank_; i < NUM_ELEMENTS_PER_REPEAT; i += maxLoRARank_) {
for (int32_t j = 0; j < maxLoRARank_; j++) {
float entry = xDup.GetValue(j);
xDup.SetValue(i + j, entry);
}
}
}
inQueueX_.FreeTensor(xLocal);
}
__aicore__ inline void CopyInY(int32_t progress, int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<Y_T> yInLocal = inQueueY_.AllocTensor<Y_T>();
DataCopy(yInLocal, yInGm_[yOffset_ + progress * Y_OUT_TILE_NUM_ELEMENTS], numElements);
inQueueY_.EnQue(yInLocal);
}
__aicore__ inline void CopyInW(int32_t progress, int32_t numElements = W_IN_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<W_T> wLocal = inQueueW_.AllocTensor<W_T>();
DataCopy(wLocal, wGm_[reqLoRAWeightOffset_ + progress * W_IN_TILE_NUM_ELEMENTS], numElements);
inQueueW_.EnQue(wLocal);
}
__aicore__ inline void ScaleOutput(int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<float> yLocal = tmpBufferY_.Get<float>();
AscendC::LocalTensor<Y_T> yInLocal = inQueueY_.DeQue<Y_T>();
AscendC::LocalTensor<float> yInLocalFP32 = inBufferY_.Get<float>();
Cast(yInLocalFP32, yInLocal, AscendC::RoundMode::CAST_NONE, numElements);
pipe_barrier(PIPE_V);
inQueueY_.FreeTensor(yInLocal);
Add(yLocal, yLocal, yInLocalFP32, numElements);
pipe_barrier(PIPE_V);
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.AllocTensor<Y_T>();
Cast(yOutLocal, yLocal, AscendC::RoundMode::CAST_RINT, numElements);
pipe_barrier(PIPE_V);
outQueueY_.EnQue<Y_T>(yOutLocal);
}
__aicore__ inline void Compute(int32_t progress,
int32_t blockReduceRepeatCount=BLOCK_REDUCE_NUM_REPEATS,
int32_t pairReduceRepeat16=PAIR_REDUCE_NUM_REPEATS_16,
int32_t pairReduceRepeat32=PAIR_REDUCE_NUM_REPEATS_32)
{
AscendC::LocalTensor<float> yLocal = tmpBufferY_.Get<float>();
AscendC::LocalTensor<float> xDup = dupBufferX_.Get<float>();
AscendC::LocalTensor<W_T> wLocal = inQueueW_.DeQue<W_T>();
AscendC::LocalTensor<float> wTmpTensor = tmpBufferW_.Get<float>();
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, MASK_COUNT, blockReduceRepeatCount, castParams_);
pipe_barrier(PIPE_V);
inQueueW_.FreeTensor(wLocal);
Mul(wTmpTensor, xDup, wTmpTensor, MASK_COUNT, blockReduceRepeatCount, dotProductParams_);
pipe_barrier(PIPE_V);
if (maxLoRARank_ == LORA_RANK_8) {
BlockReduceSum(yLocal[progress], wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
} else if (maxLoRARank_ == LORA_RANK_16) {
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
PairReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat16, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
} else if (maxLoRARank_ == LORA_RANK_32) {
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
PairReduceSum(wTmpTensor, wTmpTensor, pairReduceRepeat16, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
PairReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat32, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
} else if (maxLoRARank_ == LORA_RANK_64) {
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
BlockReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat16, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
}
}
__aicore__ inline void CopyOut(int32_t progress, int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.DeQue<Y_T>();
DataCopy(yOutGm_[yOffset_ + progress * Y_OUT_TILE_NUM_ELEMENTS], yOutLocal, numElements);
outQueueY_.FreeTensor(yOutLocal);
}
private:
AscendC::TPipe* pipe_;
AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueY_, inQueueW_;
AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueX_;
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> outQueueY_;
AscendC::TBuf<AscendC::QuePosition::VECCALC> tmpBufferW_, dupBufferX_, inBufferY_, tmpBufferY_;
AscendC::GlobalTensor<X_T> xGm_;
AscendC::GlobalTensor<W_T> wGm_;
AscendC::GlobalTensor<Y_T> yInGm_;
AscendC::GlobalTensor<Y_T> yOutGm_;
AscendC::GlobalTensor<int64_t> loraIndicesGm_;
AscendC::GlobalTensor<int64_t> seqLenGm_;
uint32_t batchSize_;
uint32_t numTokensPerCore_;
uint32_t maxLoRARank_;
uint32_t outputHiddenDim_;
uint32_t sliceOffset_;
uint32_t outputFullDim_;
uint32_t singleLoRAWeightLen_;
int64_t reqLoRAIndex_;
uint64_t reqLoRAWeightOffset_;
uint32_t numOutputElementsPerInputTile_;
uint32_t numStreamInPerOutputTile_;
uint64_t yOffset_;
// The block stride is set to 1, and 8 blocks in the same repeat are processed continuously.
// The repeat stride is 8, so the vector unit reads 8 consecutive blocks in the first repeat,
// reads next 8 consecutive blocks in the second repeat.
AscendC::UnaryRepeatParams castParams_ = {1, 1, 8, 4};
// For each repeat in BlockReduceSum and PairReduceSum we should move forward only one block,
// so we set dstRepStride = 1
AscendC::UnaryRepeatParams reduceSumParams_ = {1, 1, 1, 8};
// When the repeat stride is 0, the vector unit repeatedly reads and computes the first 8 consecutive blocks.
// For xDup we repeatedly use it, so we set src0RepStride = 0
AscendC::BinaryRepeatParams dotProductParams_ = {1, 1, 1, 8, 0, 8};
};
#define SGMV_EXPAND_TYPE_DECLARE(TYPE) \
extern "C" __global__ __aicore__ void sgmv_expand_##TYPE(__gm__ void* x, __gm__ void* weight, \
__gm__ void* loraIndices, uint32_t loraIndicesSize, \
__gm__ void* seqLen, uint32_t seqLenSize, \
__gm__ void* yIn, __gm__ void* yOut, \
uint32_t batchSize, uint32_t numTokensPerCore, \
uint32_t maxLoRARank, uint32_t outputHiddenDim, \
uint32_t sliceOffset, uint32_t outputFullDim) \
{ \
AscendC::TPipe pipe; \
SGMVExpand<TYPE> op(&pipe); \
op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, \
yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, \
outputHiddenDim, sliceOffset, outputFullDim); \
op.Process(); \
}
// declare all dtype kernel
SGMV_EXPAND_TYPE_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
SGMV_EXPAND_TYPE_DECLARE(bfloat16_t)
#endif
namespace vllm_ascend {
extern void sgmv_expand_impl(AscendType type, void* stream, void* x, void* weight,
void* loraIndices, uint32_t loraIndicesSize,
void* seqLen, uint32_t seqLenSize,
void* yIn, void* yOut, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank,
uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim)
{
uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
if (type == AscendType::FP16) {
sgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize,
yIn, yOut, batchSize,
numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset,
outputFullDim);
} else if (type == AscendType::BF16) {
#if (__CCE_AICORE__ >= 220)
sgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize,
seqLen, seqLenSize, yIn, yOut, batchSize,
numTokensPerCore, maxLoRARank, outputHiddenDim,
sliceOffset, outputFullDim);
#endif
} else {
return;
}
}
} // namespace vllm_ascend

View File

@@ -0,0 +1,275 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
#include "types.h"
template <typename scalar_t>
class SGMVShrink {
public:
using X_T = scalar_t;
using W_T = scalar_t;
using Y_T = float;
static constexpr uint64_t BUFFER_NUM = 1;
static constexpr uint64_t TILE_LENGTH = 11776; // optimal performance tile length
public:
__aicore__ inline SGMVShrink(AscendC::TPipe *pipe) : pipe_(pipe) {}
__aicore__ inline void Init(__gm__ void *x, __gm__ void *weight, __gm__ void *loraIndices, uint32_t loraIndicesSize,
__gm__ void *seqLen, uint32_t seqLenSize,
__gm__ void *y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
uint32_t maxLoRARank, float scale)
{
batchSize_ = batchSize;
numTokensPerCore_ = numTokensPerCore;
inputHiddenDim_ = inputHiddenDim;
maxLoRARank_ = maxLoRARank;
scale_ = scale;
singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_;
incremental_ = inputHiddenDim_ > TILE_LENGTH;
xGm_.SetGlobalBuffer((__gm__ X_T *)x);
yOutGm_.SetGlobalBuffer((__gm__ Y_T *)y);
wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices, loraIndicesSize);
seqLenGm_.SetGlobalBuffer((__gm__ int64_t *)seqLen, seqLenSize);
pipe_->InitBuffer(inQueueX_, BUFFER_NUM, TILE_LENGTH * sizeof(X_T));
pipe_->InitBuffer(inQueueW_, BUFFER_NUM, TILE_LENGTH * sizeof(W_T));
pipe_->InitBuffer(tmpBufferX_, TILE_LENGTH * sizeof(float));
pipe_->InitBuffer(tmpBufferW_, TILE_LENGTH * sizeof(float));
pipe_->InitBuffer(outQueueY_, 1, maxLoRARank_ * sizeof(Y_T));
pipe_->InitBuffer(outBufferY_, maxLoRARank_ * sizeof(float));
}
__aicore__ inline void Process()
{
int64_t blockIdx = AscendC::GetBlockIdx();
int64_t startIdx = blockIdx * numTokensPerCore_;
int64_t endIdx = startIdx + numTokensPerCore_;
if (endIdx > batchSize_) {
endIdx = batchSize_;
}
for (int64_t idx = startIdx; idx < endIdx; idx++) {
// set up LoRA index
CopyInIndex(idx);
if (reqLoRAIndex_ < 0) {
continue;
}
reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_;
if (incremental_) {
ProcessImpl<true>(idx);
} else {
ProcessImpl<false>(idx);
}
ScaleOutput();
CopyOut(idx);
}
}
private:
template <bool INCREMENTAL_MODE>
__aicore__ inline void ProcessImpl(const int64_t idx)
{
AscendC::LocalTensor<float> yOutLocal = outBufferY_.Get<float>();
if constexpr (!INCREMENTAL_MODE) {
CopyInX(idx, 0, inputHiddenDim_);
AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, inputHiddenDim_);
pipe_barrier(PIPE_V);
inQueueX_.FreeTensor(xLocal);
}
for (int i = 0; i < maxLoRARank_; i++) {
float acc(0);
for (int32_t j = 0; j < inputHiddenDim_ / TILE_LENGTH; j++) {
if constexpr (INCREMENTAL_MODE) {
CopyInX(idx, j);
}
CopyInW(i, j);
Compute<INCREMENTAL_MODE>(acc);
}
CopyAndComputeLastIteration<INCREMENTAL_MODE>(idx, i, acc);
yOutLocal.SetValue(i, acc);
}
}
__aicore__ inline void CopyInIndex(const int64_t idx)
{
// look up the LoRA index
int64_t weightIdx = idx;
uint64_t i = 0;
for (; i < seqLenGm_.GetSize(); i++) {
int64_t repeatValue = seqLenGm_.GetValue(i);
if (weightIdx >= repeatValue) {
weightIdx -= repeatValue;
continue;
}
break;
}
reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? loraIndicesGm_.GetValue(i) : -1;
}
__aicore__ inline void CopyInX(const int64_t idx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
{
AscendC::LocalTensor<X_T> xLocal = inQueueX_.AllocTensor<X_T>();
DataCopy(xLocal, xGm_[inputHiddenDim_ * idx + colIdx * TILE_LENGTH], numElements);
inQueueX_.EnQue(xLocal);
}
__aicore__ inline void CopyInW(int32_t rowIdx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
{
AscendC::LocalTensor<W_T> wLocal = inQueueW_.AllocTensor<W_T>();
DataCopy(wLocal, wGm_[reqLoRAWeightOffset_ + rowIdx * inputHiddenDim_ + colIdx * TILE_LENGTH], numElements);
inQueueW_.EnQue(wLocal);
}
template <bool INCREMENTAL_MODE>
__aicore__ inline void Compute(float &acc, int32_t numElements = TILE_LENGTH)
{
AscendC::LocalTensor<W_T> wLocal = inQueueW_.DeQue<W_T>();
AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
AscendC::LocalTensor<float> wTmpTensor = tmpBufferW_.Get<float>();
if constexpr (INCREMENTAL_MODE) {
AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, numElements);
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
pipe_barrier(PIPE_V);
inQueueX_.FreeTensor(xLocal);
inQueueW_.FreeTensor(wLocal);
} else {
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
pipe_barrier(PIPE_V);
inQueueW_.FreeTensor(wLocal);
}
// dot product of the one tile of X and W
Mul(wTmpTensor, xTmpTensor, wTmpTensor, numElements);
pipe_barrier(PIPE_V);
// reduce sum generate one number, which is the summation of all the dot product
ReduceSum<float>(wTmpTensor, wTmpTensor, wTmpTensor, numElements);
pipe_barrier(PIPE_V);
acc += wTmpTensor.GetValue(0);
}
template <bool INCREMENTAL_MODE>
__aicore__ inline void CopyAndComputeLastIteration(const int64_t idx, int32_t rowIdx, float &acc)
{
int32_t colIdx = inputHiddenDim_ / TILE_LENGTH;
int32_t remaining = inputHiddenDim_ % TILE_LENGTH;
if (remaining == 0) {
return;
}
if constexpr (INCREMENTAL_MODE) {
CopyInX(idx, colIdx, remaining);
}
CopyInW(rowIdx, colIdx, remaining);
Compute<INCREMENTAL_MODE>(acc, remaining);
}
__aicore__ inline void ScaleOutput()
{
AscendC::LocalTensor<float> yLocal = outBufferY_.Get<float>();
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.AllocTensor<Y_T>();
Muls(yOutLocal, yLocal, scale_, maxLoRARank_);
pipe_barrier(PIPE_V);
outQueueY_.EnQue<Y_T>(yOutLocal);
}
__aicore__ inline void CopyOut(const int64_t idx)
{
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.DeQue<Y_T>();
DataCopy(yOutGm_[maxLoRARank_ * idx], yOutLocal, maxLoRARank_);
outQueueY_.FreeTensor(yOutLocal);
}
private:
AscendC::TPipe *pipe_;
AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueX_, inQueueW_;
AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueY_;
AscendC::TBuf<AscendC::QuePosition::VECCALC> tmpBufferX_, tmpBufferW_, outBufferY_;
AscendC::GlobalTensor<X_T> xGm_;
AscendC::GlobalTensor<W_T> wGm_;
AscendC::GlobalTensor<int64_t> loraIndicesGm_;
AscendC::GlobalTensor<int64_t> seqLenGm_;
AscendC::GlobalTensor<Y_T> yOutGm_;
uint32_t batchSize_;
uint32_t numTokensPerCore_;
uint32_t inputHiddenDim_;
uint32_t maxLoRARank_;
float scale_;
uint32_t singleLoRAWeightLen_;
int64_t reqLoRAIndex_;
uint64_t reqLoRAWeightOffset_;
bool incremental_;
};
#define SGMV_SHRINK_TYPE_DECLARE(TYPE) \
extern "C" __global__ __aicore__ void sgmv_shrink_##TYPE(__gm__ void* x, __gm__ void* weight, \
__gm__ void* loraIndices, uint32_t loraIndicesSize, \
__gm__ void* seqLen, uint32_t seqLenSize, \
__gm__ void* y, uint32_t batchSize, \
uint32_t numTokensPerCore, uint32_t inputHiddenDim, \
uint32_t maxLoRARank, float scale) \
{ \
AscendC::TPipe pipe; \
SGMVShrink<TYPE> op(&pipe); \
op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, \
y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); \
op.Process(); \
}
// declare all dtype kernel
SGMV_SHRINK_TYPE_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
SGMV_SHRINK_TYPE_DECLARE(bfloat16_t)
#endif
namespace vllm_ascend {
extern void sgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight,
void* loraIndices, uint32_t loraIndicesSize,
void* seqLen, uint32_t seqLenSize,
void* y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
uint32_t maxLoRARank, float scale)
{
uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
if (type == AscendType::FP16) {
sgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize,
y, batchSize,
numTokensPerCore, inputHiddenDim, maxLoRARank,
scale);
} else if (type == AscendType::BF16) {
#if (__CCE_AICORE__ >= 220)
sgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize,
seqLen, seqLenSize,
y, batchSize,
numTokensPerCore, inputHiddenDim, maxLoRARank,
scale);
#endif
} else {
return;
}
}
} // namespace vllm_ascend

25
csrc/kernels/types.h Normal file
View File

@@ -0,0 +1,25 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace vllm_ascend {
enum struct AscendType {
FP16 = 0,
BF16 = 1,
FP32 = 2,
};
}

51
csrc/kernels/utils.h Normal file
View File

@@ -0,0 +1,51 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "kernel_type.h"
namespace vllm_ascend {
template <typename scalar_t> struct AccType;
#if (__CCE_AICORE__ >= 220)
template <> struct AccType<bfloat16_t> {
using type = float;
};
#endif
template <> struct AccType<half> {
using type = half;
};
template <> struct AccType<float> {
using type = float;
};
template <> struct AccType<int8_t> {
using type = int;
};
template <typename scalar_t>
__aicore__ inline void local_mem_copy(AscendC::LocalTensor<scalar_t> dst, AscendC::LocalTensor<scalar_t> src, int size)
{
constexpr int loadSize = 256 / sizeof(scalar_t);
int loopCnt = size / loadSize;
int tailSize = size % loadSize;
if (loopCnt)
AscendC::Copy(dst, src, loadSize, loopCnt, {1, 1, 8, 8});
AscendC::Copy(dst[loopCnt * loadSize], src[loopCnt * loadSize], tailSize, 1, {1, 1, 8, 8});
}
} // namespace vllm_ascend

127
csrc/ops.h Normal file
View File

@@ -0,0 +1,127 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <optional>
#include <torch/library.h>
#include <vector>
#include "kernels/types.h"
#include "torch_npu/csrc/aten/common/from_blob.h"
namespace vllm_ascend {
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
uint32_t aivNum);
extern void get_masked_input_and_mask_impl(
void* stream,
void* input,
void* masked_input,
void* mask_out,
const int64_t org_vocab_start_index,
const int64_t org_vocab_end_index,
const int64_t num_org_vocab_padding,
const int64_t added_vocab_start_index,
const int64_t added_vocab_end_index,
const int64_t size,
const uint32_t loop_cnt,
const uint32_t aiv_num);
torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
if (!tensor.is_privateuseone()) {
throw std::runtime_error("Tensor must be on NPU device");
}
// Get the raw data pointer
void* data_ptr = tensor.data_ptr();
// Get tensor sizes and strides
std::vector<int64_t> sizes = tensor.sizes().vec();
std::vector<int64_t> strides = tensor.strides().vec();
// Get tensor options (dtype, device)
auto options = tensor.options();
// Create a new tensor from the raw data pointer
auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
return new_tensor;
}
extern void bgmv_shrink_impl(
AscendType type,
void *stream,
void *x,
void *weight,
void *indices,
uint32_t indicesSize,
void *y,
uint32_t batch_size,
uint32_t num_tokens_per_core,
uint32_t input_hidden_dim,
uint32_t lora_rank,
float scale);
extern void bgmv_expand_impl(
AscendType type,
void *stream,
void *x,
void *weight,
void *indices,
uint32_t indicesSize,
void *y,
void *y_out,
uint32_t batch_size,
uint32_t num_tokens_per_core,
uint32_t lora_rank,
uint32_t output_hidden_dim,
uint32_t slice_offset,
uint32_t output_full_dim);
extern void sgmv_shrink_impl(
AscendType type,
void *stream,
void *x,
void *weight,
void *loraIndices,
uint32_t loraIndicesSize,
void *seqLen,
uint32_t seqLenSize,
void *y,
uint32_t batch_size,
uint32_t num_tokens_per_core,
uint32_t input_hidden_dim,
uint32_t lora_rank,
float scale);
extern void sgmv_expand_impl(
AscendType type,
void *stream,
void *x,
void *weight,
void *loraIndices,
uint32_t loraIndicesSize,
void *seqLen,
uint32_t seqLenSize,
void *y,
void *y_out,
uint32_t batch_size,
uint32_t num_tokens_per_core,
uint32_t lora_rank,
uint32_t output_hidden_dim,
uint32_t slice_offset,
uint32_t output_full_dim);
}

428
csrc/torch_binding.cpp Normal file
View File

@@ -0,0 +1,428 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/version.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/npu/Module.h>
#include <pybind11/pybind11.h>
#include "acl/acl.h"
#include "ops.h"
#include "utils.h"
namespace vllm_ascend {
AscendType get_dtype_from_torch(at::ScalarType scalarType)
{
if (scalarType == at::ScalarType::Float) {
return AscendType::FP32;
} else if (scalarType == at::ScalarType::BFloat16) {
return AscendType::BF16;
} else {
return AscendType::FP16;
}
}
std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
{
int32_t deviceId = 0;
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}
TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);
at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
int rot_dim = cos_sin_cache.size(1);
int seq_dim_idx = positions_ndim - 1;
int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
void *query_dst_ptr = query_dst.data_ptr();
void *key_dst_ptr = key_dst.data_ptr();
void *query_ptr = query.data_ptr();
void *key_ptr = key.data_ptr();
void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);
int64_t dst_query_stride = query_dst.stride(0);
int64_t dst_key_stride = key_dst.stride(0);
at::ScalarType scalar_type = query.scalar_type();
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("rotary_embedding");
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr,
query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int {
auto dtype_num = get_dtype_from_torch(scalar_type);
int device_id = 0;
int64_t aiv_num = 0;
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
uint32_t loop_cnt = (num_tokens + aiv_num - 1) / aiv_num;
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr,
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride,
dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aiv_num);
return 0;
});
cmd.Run();
return {query_dst, key_dst};
}
std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
at::Tensor &input,
const int64_t org_vocab_start_index,
const int64_t org_vocab_end_index,
const int64_t num_org_vocab_padding,
const int64_t added_vocab_start_index,
const int64_t added_vocab_end_index)
/*
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/vocab_parallel_embedding.py#L161-L198
Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 520 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
Parameters:
org_vocab_start_index //base embeddings start
org_vocab_end_index //base embeddings end
num_org_vocab_padding //base embeddings padding
added_vocab_start_index //LoRA embeddings start
added_vocab_end_index //LoRA embeddings end
*/
{
// Input validation
TORCH_CHECK(input.dim() >= 1, "input must have at least 1 dimension");
TORCH_CHECK(org_vocab_start_index >= 0, "org_vocab_start_index must be non-negative");
TORCH_CHECK(org_vocab_end_index >= org_vocab_start_index, "org_vocab_end_index must be greater than org_vocab_start_index");
TORCH_CHECK(num_org_vocab_padding >= 0, "num_org_vocab_padding must be non-negative");
TORCH_CHECK(added_vocab_start_index >= org_vocab_end_index, "added_vocab_start_index must be greater than org_vocab_end_index");
TORCH_CHECK(added_vocab_end_index >= added_vocab_start_index, "added_vocab_end_index must be greater than added_vocab_start_index");
// Get total number of elements
int64_t size = input.numel();
// Create output tensors
at::Tensor masked_input = at::empty_like(input);
at::Tensor mask = at::empty_like(input).to(at::kBool);
// Get data pointers
void *input_ptr = input.data_ptr();
void *masked_input_ptr = masked_input.data_ptr();
void *mask_ptr = mask.data_ptr();
// Get current stream
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
// Get scalar type
at::ScalarType scalar_type = input.scalar_type();
// Create and configure OpCommand
at_npu::native::OpCommand cmd;
cmd.Name("get_masked_input_and_mask");
cmd.SetCustomHandler([scalar_type, size, stream,
input_ptr, masked_input_ptr, mask_ptr,
org_vocab_start_index, org_vocab_end_index,
num_org_vocab_padding, added_vocab_start_index,
added_vocab_end_index]() -> int {
int device_id = 0;
int64_t aiv_num = 0;
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
uint32_t loop_cnt = (size + aiv_num - 1) / aiv_num;
// Call implementation
get_masked_input_and_mask_impl(
stream,
input_ptr,
masked_input_ptr,
mask_ptr,
org_vocab_start_index,
org_vocab_end_index,
num_org_vocab_padding,
added_vocab_start_index,
added_vocab_end_index,
size,
loop_cnt,
aiv_num);
return 0;
});
cmd.Run();
return {masked_input, mask};
}
void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y, double scale)
{
at::ScalarType scalar_type = x.scalar_type();
TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
"weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
"the first dimension of x, y, indices should be same");
TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out");
void* x_ptr = x.data_ptr();
void* weight_ptr = weight.data_ptr();
void* indices_ptr = indices.data_ptr();
int indices_size = indices.size(0);
void* y_ptr = y.data_ptr();
int batch_size = x.size(0);
int input_hidden_token = x.size(1);
uint32_t lora_rank = y.size(1);
float scale_f = static_cast<float>(scale);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("bgmv_shrink");
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, batch_size, input_hidden_token,
lora_rank, scale_f]() -> int {
auto dtype = get_dtype_from_torch(scalar_type);
int device_id = 0;
int64_t aiv_num = 0;
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
bgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, batch_size, num_tokens_per_core,
input_hidden_token, lora_rank, scale_f);
return 0;
});
cmd.Run();
return;
}
at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y,
int64_t slice_offset, int64_t slice_size)
{
at::ScalarType scalar_type = y.scalar_type();
TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
"weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
"the first dimension of x, y, indices should be same");
TORCH_CHECK(x.size(1) <= slice_size, "hidden in should be smaller than hidden out");
TORCH_CHECK(slice_offset >= 0, "slice offset should be no smaller than 0");
TORCH_CHECK((slice_size + slice_offset) <= y.size(1),
"slice_size + slice_offset should be smaller than the second dimension of y")
at::Tensor y_out = y;
void* x_ptr = x.data_ptr();
void* weight_ptr = weight.data_ptr();
void* indices_ptr = indices.data_ptr();
int indices_size = indices.size(0);
void* y_ptr = y.data_ptr();
void* y_out_ptr = y_out.data_ptr();
int batch_size = x.size(0);
int lora_rank = x.size(1);
int output_full_dim = y.size(1);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("bgmv_expand");
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, y_out_ptr, batch_size, lora_rank,
slice_offset, slice_size, output_full_dim]() -> int {
auto dtype = get_dtype_from_torch(scalar_type);
int device_id = 0;
int64_t aiv_num = 0;
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
bgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, y_out_ptr, batch_size,
num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
return 0;
});
cmd.Run();
return y_out;
}
void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
at::Tensor &y, double scale)
{
at::ScalarType scalar_type = x.scalar_type();
TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
"weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out");
void* x_ptr = x.data_ptr();
void* weight_ptr = weight.data_ptr();
void* lora_indices_ptr = lora_indices.data_ptr();
void* seq_len_ptr = seq_len.data_ptr();
int lora_indices_size = lora_indices.size(0);
int seq_len_size = seq_len.size(0);
void* y_ptr = y.data_ptr();
int batch_size = x.size(0);
int input_hidden_token = x.size(1);
uint32_t lora_rank = y.size(1);
float scale_f = static_cast<float>(scale);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("sgmv_shrink");
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
seq_len_ptr, seq_len_size, y_ptr,
batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
auto dtype = get_dtype_from_torch(scalar_type);
int device_id = 0;
int64_t aiv_num = 0;
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size,
y_ptr, batch_size,
num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
return 0;
});
cmd.Run();
return;
}
at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
at::Tensor &y, int64_t slice_offset, int64_t slice_size)
{
at::ScalarType scalar_type = y.scalar_type();
TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
"weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
TORCH_CHECK(x.size(1) <= slice_size, "hidden in should be smaller than hidden out");
TORCH_CHECK(slice_offset >= 0, "slice offset should be no smaller than 0");
TORCH_CHECK((slice_size + slice_offset) <= y.size(1),
"slice_size + slice_offset should be smaller than the second dimension of y")
at::Tensor y_out = y;
void* x_ptr = x.data_ptr();
void* weight_ptr = weight.data_ptr();
void* lora_indices_ptr = lora_indices.data_ptr();
void* seq_len_ptr = seq_len.data_ptr();
int lora_indices_size = lora_indices.size(0);
int seq_len_size = seq_len.size(0);
void* y_ptr = y.data_ptr();
void* y_out_ptr = y_out.data_ptr();
int batch_size = x.size(0);
int lora_rank = x.size(1);
int output_full_dim = y.size(1);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("sgmv_expand");
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
auto dtype = get_dtype_from_torch(scalar_type);
int device_id = 0;
int64_t aiv_num = 0;
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
return 0;
});
cmd.Run();
return y_out;
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(_C, ops)
{
// vLLM-Ascend custom ops
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
// Rotary embedding
// Apply GPT-NeoX style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
ops.def(
"get_masked_input_and_mask(Tensor input, "
" int org_vocab_start_index, "
" int org_vocab_end_index, "
" int num_org_vocab_padding, "
" int added_vocab_start_index, "
" int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
ops.def(
"bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
" int slice_offset, int slice_size) -> Tensor");
ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
ops.def("sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");
ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
ops.def(
"sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
" int slice_offset, int slice_size) -> Tensor");
ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
}
REGISTER_EXTENSION(_C)

102
csrc/torch_binding_meta.cpp Normal file
View File

@@ -0,0 +1,102 @@
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/version.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/npu/Module.h>
#include "utils.h"
/*
* How to write a meta implementation for a custom operator (meta kernel):
*
* Meta implementations are used for shape and dtype inference, tracing, and export.
* They do NOT perform any real computation or allocate device memory.
* Instead, they return empty tensors with the correct shapes, dtypes, and device types.
*
* Steps to write a meta implementation:
* 1. The function signature should match the operator's schema, but only use the arguments
* necessary to infer output shapes and dtypes.
* 2. Use input tensor shapes, dtypes, and any relevant arguments to compute the output shapes.
* 3. Return empty tensors (e.g., at::empty_symint, at::empty_like) with the correct shape and dtype.
* 4. Do NOT perform any real computation or data movement.
* 5. Register the meta implementation with the "Meta" dispatch key using TORCH_LIBRARY_IMPL or similar.
*
* Example:
* std::tuple<at::Tensor, at::Tensor> my_op_meta(
* at::Tensor &input, int64_t some_param) {
* // Infer output shape based on input and parameters
* auto out_shape = ...;
* at::Tensor out = at::empty_symint(out_shape, input.options());
* // Return empty tensor(s) with correct shape/dtype
* return {out, ...};
* }
*
* See below for real examples.
*/
namespace vllm_ascend {
namespace meta {
std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
at::Tensor &positions,
at::Tensor &query,
at::Tensor &key,
int64_t head_size,
at::Tensor &cos_sin_cache,
bool is_neox) {
auto num_tokens = positions.sym_numel();
auto query_hidden_size = query.sym_numel() / num_tokens;
auto key_hidden_size = key.sym_numel() / num_tokens;
auto num_heads = query_hidden_size / head_size;
auto num_kv_heads = key_hidden_size / head_size;
at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());
return {query_dst, key_dst};
}
std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
at::Tensor &input,
const int64_t org_vocab_start_index,
const int64_t org_vocab_end_index,
const int64_t num_org_vocab_padding,
const int64_t added_vocab_start_index,
const int64_t added_vocab_end_index) {
at::Tensor masked_input = at::empty_like(input);
at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool));
return {masked_input, mask};
}
at::Tensor bgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y,
int64_t slice_offset, int64_t slice_size) {
at::Tensor y_out = at::empty_like(y);
return y_out;
}
at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
at::Tensor &y, int64_t slice_offset, int64_t slice_size) {
at::Tensor y_out = at::empty_like(y);
return y_out;
}
} // namespace meta
} // namespace vllm_ascend
namespace {
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
// the custom kernel been captured into aclgraph
TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
// Rotary embedding meta implementation
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
// Masked input and mask meta implementation
ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
// Bgmv expand
ops.impl("bgmv_expand", &vllm_ascend::meta::bgmv_expand_meta);
// Sgmv expand
ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
}
}

31
csrc/utils.h Normal file
View File

@@ -0,0 +1,31 @@
#pragma once
#include "kernels/types.h"
#include <c10/core/ScalarType.h>
#include <Python.h>
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}