Add Custom Kernels For LoRA Performance (#1884)

### What this PR does / why we need it?
Add two custom kernels (bgmv_shrink and bgmv_expand) to improve LoRA
performance.
### Does this PR introduce _any_ user-facing change?
No user-facing change.
### How was this patch tested?
We add unit tests for the custom AscendC kernels. See
vllm-ascend/tests/e2e/singlecard/ops/test_bgmv_expand.py and
vllm-ascend/tests/e2e/singlecard/ops/test_bgmv_shrink.py.
In tests of the Qwen2.5 7B model with vllm-ascend version v0.9.2.rc1, TTFT,
TPOT and throughput all improved by about 70%.
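
For reference, a minimal usage sketch of the two ops registered by this PR (illustrative sizes, not part of the test suite); shapes and dtypes follow the checks in the C++ binding, and `enable_custom_op()` comes from the new test files:

```python
import torch
from vllm_ascend.utils import enable_custom_op

enable_custom_op()

batch, hidden_in, rank, hidden_out, num_loras = 4, 4096, 16, 4096, 8

# bgmv_shrink(x, weight, indices, y, scale): writes scale * (A[idx] @ x) into y (float32).
x = torch.randn(batch, hidden_in, dtype=torch.float16).npu()
lora_a = torch.randn(num_loras, rank, hidden_in, dtype=torch.float16).npu()
indices = torch.zeros(batch, dtype=torch.int64).npu()
shrunk = torch.zeros(batch, rank, dtype=torch.float32).npu()
torch.ops._C.bgmv_shrink(x, lora_a, indices, shrunk, 0.5)

# bgmv_expand(x, weight, indices, y, slice_offset, slice_size):
# accumulates B[idx] @ x into y[:, slice_offset:slice_offset + slice_size].
lora_b = torch.randn(num_loras, hidden_out, rank, dtype=torch.float16).npu()
y = torch.zeros(batch, hidden_out, dtype=torch.float16).npu()
y = torch.ops._C.bgmv_expand(shrunk, lora_b, indices, y, 0, hidden_out)
```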

- vLLM version: v0.9.2
- vLLM main: 40d86ee412

---------

Signed-off-by: taoxudonghaha <justsheldon@163.com>
Author: taoxudonghaha (committed via GitHub)
Date: 2025-07-29 19:27:50 +08:00
parent 2da281ec5a
commit 540336edc9
8 changed files with 946 additions and 3 deletions


@@ -0,0 +1,369 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
#include "types.h"
template <typename scalar_t>
class BGMVExpand {
public:
using X_T = float;
using W_T = scalar_t;
using Y_T = scalar_t;
static constexpr uint64_t LORA_RANK_8 = 8;
static constexpr uint64_t LORA_RANK_16 = 16;
static constexpr uint64_t LORA_RANK_32 = 32;
static constexpr uint64_t LORA_RANK_64 = 64;
static constexpr uint64_t SUPPORTED_RANKS[] = {LORA_RANK_8, LORA_RANK_16, LORA_RANK_32, LORA_RANK_64};
static constexpr int32_t BUFFER_NUM = 2;
// The vector unit reads 8 blocks (32 bytes each and 256 bytes in total) of contiguous data each time.
static constexpr int32_t NUM_BYTES_PER_REPEAT = 256;
static constexpr int32_t NUM_BLOCKS_PER_REPEAT = 8;
// The maximum number of elements in a single iteration is 256 / sizeof(intermediate data type).
static constexpr int32_t NUM_ELEMENTS_PER_REPEAT = NUM_BYTES_PER_REPEAT / sizeof(float);
// Mask is used to control the elements that participate in computation in each iteration.
static constexpr int32_t MASK_COUNT = NUM_BYTES_PER_REPEAT / sizeof(float);
// Refer to numOutputElementsPerInputTile_ initialization for the constraints on the following constants.
static constexpr int32_t W_IN_TILE_NUM_ELEMENTS = 8192;
static constexpr int32_t Y_OUT_TILE_NUM_ELEMENTS = 4096;
static constexpr int32_t BLOCK_REDUCE_NUM_REPEATS = W_IN_TILE_NUM_ELEMENTS / NUM_ELEMENTS_PER_REPEAT;
    // BlockReduceSum generates (BLOCK_REDUCE_NUM_REPEATS * NUM_BLOCKS_PER_REPEAT) floats,
    // so we need to read them all and apply PairReduceSum.
static constexpr int32_t PAIR_REDUCE_NUM_REPEATS_16 =
(BLOCK_REDUCE_NUM_REPEATS * NUM_BLOCKS_PER_REPEAT + NUM_ELEMENTS_PER_REPEAT - 1) / NUM_ELEMENTS_PER_REPEAT;
    // The second PairReduceSum for rank=32 needs half the repeats used for rank=16;
    // the same applies for rank=64. Ranks greater than 64 are not supported.
static constexpr int32_t PAIR_REDUCE_NUM_REPEATS_32 = (PAIR_REDUCE_NUM_REPEATS_16 + 1) / 2;
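    // Worked example with float (4-byte) intermediates, derived from the constants above:
    //   NUM_ELEMENTS_PER_REPEAT = 256 / 4 = 64, BLOCK_REDUCE_NUM_REPEATS = 8192 / 64 = 128,
    //   PAIR_REDUCE_NUM_REPEATS_16 = (128 * 8 + 63) / 64 = 16, PAIR_REDUCE_NUM_REPEATS_32 = 8.
    //   With rank r, each W tile yields 128 * (64 / r) output elements, so one 4096-element
    //   output tile consumes 4096 / (128 * (64 / r)) = r / 2 input tiles (e.g. 4 tiles for r = 8).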
public:
__aicore__ inline BGMVExpand(AscendC::TPipe* pipe) : pipe_(pipe) {}
__aicore__ inline void Init(__gm__ void* x, __gm__ void* weight, __gm__ void* indices,
__gm__ void* yIn, __gm__ void* yOut,
uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank,
uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim)
{
batchSize_ = batchSize;
numTokensPerCore_ = numTokensPerCore;
maxLoRARank_ = maxLoRARank;
outputHiddenDim_ = outputHiddenDim;
sliceOffset_ = sliceOffset;
outputFullDim_ = outputFullDim;
singleLoRAWeightLen_ = maxLoRARank_ * outputHiddenDim_;
xGm_.SetGlobalBuffer((__gm__ X_T *)x);
wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
yInGm_.SetGlobalBuffer((__gm__ Y_T *)yIn);
yOutGm_.SetGlobalBuffer((__gm__ Y_T *)yOut);
indicesGm_.SetGlobalBuffer((__gm__ int64_t *)indices);
pipe_->InitBuffer(inQueueX_, 1, NUM_ELEMENTS_PER_REPEAT * sizeof(X_T));
pipe_->InitBuffer(inQueueW_, BUFFER_NUM, W_IN_TILE_NUM_ELEMENTS * sizeof(W_T));
pipe_->InitBuffer(inQueueY_, BUFFER_NUM, Y_OUT_TILE_NUM_ELEMENTS * sizeof(Y_T));
pipe_->InitBuffer(outQueueY_, BUFFER_NUM, Y_OUT_TILE_NUM_ELEMENTS * sizeof(Y_T));
pipe_->InitBuffer(dupBufferX_, NUM_ELEMENTS_PER_REPEAT * sizeof(float));
pipe_->InitBuffer(tmpBufferW_, W_IN_TILE_NUM_ELEMENTS * sizeof(float));
pipe_->InitBuffer(inBufferY_, Y_OUT_TILE_NUM_ELEMENTS * sizeof(float));
pipe_->InitBuffer(tmpBufferY_, Y_OUT_TILE_NUM_ELEMENTS * sizeof(float));
// Each compute iteration would generate not one, but several output elements.
// Therefore, the following variable would determine how many output elements are calculated in each iteration.
numOutputElementsPerInputTile_ = BLOCK_REDUCE_NUM_REPEATS * (NUM_ELEMENTS_PER_REPEAT / maxLoRARank_);
numStreamInPerOutputTile_ = Y_OUT_TILE_NUM_ELEMENTS / numOutputElementsPerInputTile_;
}
__aicore__ inline void Process()
{
int64_t blockIdx = AscendC::GetBlockIdx();
int64_t startIdx = blockIdx * numTokensPerCore_;
int64_t endIdx = startIdx + numTokensPerCore_;
if (endIdx > batchSize_) {
endIdx = batchSize_;
}
for (int64_t idx = startIdx; idx < endIdx; idx++) {
yOffset_ = outputFullDim_ * idx + sliceOffset_;
// Set up LoRA index
CopyInIndex(idx);
if (reqLoRAIndex_ < 0) {
continue;
}
reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_;
CopyInX(idx);
int32_t numStreamOut = outputHiddenDim_ / Y_OUT_TILE_NUM_ELEMENTS;
for (int32_t i = 0; i < numStreamOut; i++) {
CopyInY(i);
for (int32_t j = 0; j < numStreamInPerOutputTile_; j++) {
CopyInW(i * numStreamInPerOutputTile_ + j);
Compute(j * numOutputElementsPerInputTile_);
}
ScaleOutput();
CopyOut(i);
}
ComputeLastIteration();
}
}
private:
__aicore__ inline void CopyInIndex(const int64_t idx)
{
// Look up the LoRA index
reqLoRAIndex_ = indicesGm_.GetValue(idx);
}
__aicore__ inline void ComputeLastIteration()
{
int32_t remainingY = outputHiddenDim_ % Y_OUT_TILE_NUM_ELEMENTS;
if (remainingY == 0) {
return;
}
int32_t numStreamOut = outputHiddenDim_ / Y_OUT_TILE_NUM_ELEMENTS;
int32_t remainingW = remainingY * maxLoRARank_;
int32_t numCompleteWTileInForLastIteration = remainingW / W_IN_TILE_NUM_ELEMENTS;
int32_t remainingWForLastRepeat = remainingW % W_IN_TILE_NUM_ELEMENTS;
CopyInY(numStreamOut, remainingY);
int32_t outputIdx = 0;
for (outputIdx = 0; outputIdx < numCompleteWTileInForLastIteration; outputIdx++) {
CopyInW(numStreamOut * numStreamInPerOutputTile_ + outputIdx);
Compute(outputIdx * numOutputElementsPerInputTile_);
}
if (remainingWForLastRepeat != 0) {
CopyInW(numStreamOut * numStreamInPerOutputTile_ + numCompleteWTileInForLastIteration,
remainingWForLastRepeat);
int32_t lastRepeatCount = remainingWForLastRepeat / NUM_ELEMENTS_PER_REPEAT;
int32_t pairReduceRepeat16 =
(lastRepeatCount * NUM_BLOCKS_PER_REPEAT + NUM_ELEMENTS_PER_REPEAT - 1) / NUM_ELEMENTS_PER_REPEAT;
int32_t pairReduceRepeat32 = (pairReduceRepeat16 + 1) / 2;
int32_t lastComputeOutputElement = outputIdx * numOutputElementsPerInputTile_;
Compute(lastComputeOutputElement, lastRepeatCount, pairReduceRepeat16, pairReduceRepeat32);
}
ScaleOutput(remainingY);
CopyOut(numStreamOut, remainingY);
}
__aicore__ inline void CopyInX(const int64_t idx)
{
AscendC::LocalTensor<X_T> xLocal = inQueueX_.AllocTensor<X_T>();
if constexpr (std::is_same_v<X_T, float>) {
DataCopy(xLocal, xGm_[maxLoRARank_ * idx], maxLoRARank_);
} else {
uint16_t blockLen = static_cast<uint16_t>(maxLoRARank_ * sizeof(X_T));
DataCopyPad(xLocal, xGm_[maxLoRARank_ * idx], {1, blockLen, 0, 0}, {});
}
inQueueX_.EnQue(xLocal);
xLocal = inQueueX_.DeQue<X_T>();
AscendC::LocalTensor<float> xDup = dupBufferX_.Get<float>();
// As we are generating multiple output elements with one API invocation,
// we need to duplicate the X vector multiple times to fill one NUM_BYTES_PER_REPEAT
if constexpr (std::is_same_v<X_T, float>) {
for (int32_t i = 0; i < NUM_ELEMENTS_PER_REPEAT; i += maxLoRARank_) {
for (int32_t j = 0; j < maxLoRARank_; j++) {
float entry = xLocal.GetValue(j);
xDup.SetValue(i + j, entry);
}
}
} else {
Cast(xDup, xLocal, AscendC::RoundMode::CAST_NONE, maxLoRARank_);
pipe_barrier(PIPE_V);
for (int32_t i = maxLoRARank_; i < NUM_ELEMENTS_PER_REPEAT; i += maxLoRARank_) {
for (int32_t j = 0; j < maxLoRARank_; j++) {
float entry = xDup.GetValue(j);
xDup.SetValue(i + j, entry);
}
}
}
inQueueX_.FreeTensor(xLocal);
}
__aicore__ inline void CopyInY(int32_t progress, int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<Y_T> yInLocal = inQueueY_.AllocTensor<Y_T>();
DataCopy(yInLocal, yInGm_[yOffset_ + progress * Y_OUT_TILE_NUM_ELEMENTS], numElements);
inQueueY_.EnQue(yInLocal);
}
__aicore__ inline void CopyInW(int32_t progress, int32_t numElements = W_IN_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<W_T> wLocal = inQueueW_.AllocTensor<W_T>();
DataCopy(wLocal, wGm_[reqLoRAWeightOffset_ + progress * W_IN_TILE_NUM_ELEMENTS], numElements);
inQueueW_.EnQue(wLocal);
}
__aicore__ inline void ScaleOutput(int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<float> yLocal = tmpBufferY_.Get<float>();
AscendC::LocalTensor<Y_T> yInLocal = inQueueY_.DeQue<Y_T>();
AscendC::LocalTensor<float> yInLocalFP32 = inBufferY_.Get<float>();
Cast(yInLocalFP32, yInLocal, AscendC::RoundMode::CAST_NONE, numElements);
pipe_barrier(PIPE_V);
inQueueY_.FreeTensor(yInLocal);
Add(yLocal, yLocal, yInLocalFP32, numElements);
pipe_barrier(PIPE_V);
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.AllocTensor<Y_T>();
Cast(yOutLocal, yLocal, AscendC::RoundMode::CAST_RINT, numElements);
pipe_barrier(PIPE_V);
outQueueY_.EnQue<Y_T>(yOutLocal);
}
__aicore__ inline void Compute(int32_t progress,
int32_t blockReduceRepeatCount=BLOCK_REDUCE_NUM_REPEATS,
int32_t pairReduceRepeat16=PAIR_REDUCE_NUM_REPEATS_16,
int32_t pairReduceRepeat32=PAIR_REDUCE_NUM_REPEATS_32)
{
AscendC::LocalTensor<float> yLocal = tmpBufferY_.Get<float>();
AscendC::LocalTensor<float> xDup = dupBufferX_.Get<float>();
AscendC::LocalTensor<W_T> wLocal = inQueueW_.DeQue<W_T>();
AscendC::LocalTensor<float> wTmpTensor = tmpBufferW_.Get<float>();
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, MASK_COUNT, blockReduceRepeatCount, castParams_);
pipe_barrier(PIPE_V);
inQueueW_.FreeTensor(wLocal);
Mul(wTmpTensor, xDup, wTmpTensor, MASK_COUNT, blockReduceRepeatCount, dotProductParams_);
pipe_barrier(PIPE_V);
if (maxLoRARank_ == LORA_RANK_8) {
BlockReduceSum(yLocal[progress], wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
} else if (maxLoRARank_ == LORA_RANK_16) {
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
PairReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat16, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
} else if (maxLoRARank_ == LORA_RANK_32) {
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
PairReduceSum(wTmpTensor, wTmpTensor, pairReduceRepeat16, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
PairReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat32, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
} else if (maxLoRARank_ == LORA_RANK_64) {
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
BlockReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat16, MASK_COUNT,
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
pipe_barrier(PIPE_V);
}
}
__aicore__ inline void CopyOut(int32_t progress, int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
{
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.DeQue<Y_T>();
DataCopy(yOutGm_[yOffset_ + progress * Y_OUT_TILE_NUM_ELEMENTS], yOutLocal, numElements);
outQueueY_.FreeTensor(yOutLocal);
}
private:
AscendC::TPipe* pipe_;
AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueY_, inQueueW_;
AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueX_;
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> outQueueY_;
AscendC::TBuf<AscendC::QuePosition::VECCALC> tmpBufferW_, dupBufferX_, inBufferY_, tmpBufferY_;
AscendC::GlobalTensor<X_T> xGm_;
AscendC::GlobalTensor<W_T> wGm_;
AscendC::GlobalTensor<Y_T> yInGm_;
AscendC::GlobalTensor<Y_T> yOutGm_;
AscendC::GlobalTensor<int64_t> indicesGm_;
uint32_t batchSize_;
uint32_t numTokensPerCore_;
uint32_t maxLoRARank_;
uint32_t outputHiddenDim_;
uint32_t sliceOffset_;
uint32_t outputFullDim_;
uint32_t singleLoRAWeightLen_;
int64_t reqLoRAIndex_;
uint64_t reqLoRAWeightOffset_;
uint32_t numOutputElementsPerInputTile_;
uint32_t numStreamInPerOutputTile_;
uint64_t yOffset_;
    // The block stride is set to 1, and the 8 blocks in the same repeat are processed contiguously.
    // The repeat stride is 8, so the vector unit reads 8 consecutive blocks in the first repeat
    // and the next 8 consecutive blocks in the second repeat.
AscendC::UnaryRepeatParams castParams_ = {1, 1, 8, 4};
// For each repeat in BlockReduceSum and PairReduceSum we should move forward only one block,
// so we set dstRepStride = 1
AscendC::UnaryRepeatParams reduceSumParams_ = {1, 1, 1, 8};
// When the repeat stride is 0, the vector unit repeatedly reads and computes the first 8 consecutive blocks.
// For xDup we repeatedly use it, so we set src0RepStride = 0
AscendC::BinaryRepeatParams dotProductParams_ = {1, 1, 1, 8, 0, 8};
};
#define BGMV_EXPAND_TYPE_DECLARE(TYPE) \
extern "C" __global__ __aicore__ void bgmv_expand_##TYPE(__gm__ void* x, __gm__ void* weight, __gm__ void* indices,\
__gm__ void* yIn, __gm__ void* yOut, \
uint32_t batchSize, uint32_t numTokensPerCore, \
uint32_t maxLoRARank, uint32_t outputHiddenDim, \
uint32_t sliceOffset, uint32_t outputFullDim) \
{ \
AscendC::TPipe pipe; \
BGMVExpand<TYPE> op(&pipe); \
op.Init(x, weight, indices, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, \
outputHiddenDim, sliceOffset, outputFullDim); \
op.Process(); \
}
// declare all dtype kernel
BGMV_EXPAND_TYPE_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
BGMV_EXPAND_TYPE_DECLARE(bfloat16_t)
#endif
namespace vllm_ascend {
extern void bgmv_expand_impl(AscendType type, void* stream, void* x, void* weight, void* indices,
void* yIn, void* yOut, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank,
uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim)
{
uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
if (type == AscendType::FP16) {
bgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, indices, yIn, yOut, batchSize, numTokensPerCore,
maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim);
} else if (type == AscendType::BF16) {
#if (__CCE_AICORE__ >= 220)
bgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, yIn, yOut, batchSize,
numTokensPerCore, maxLoRARank, outputHiddenDim,
sliceOffset, outputFullDim);
#endif
} else {
return;
}
}
} // namespace vllm_ascend


@@ -0,0 +1,252 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
#include "types.h"
template <typename scalar_t>
class BGMVShrink {
public:
using X_T = scalar_t;
using W_T = scalar_t;
using Y_T = float;
static constexpr uint64_t BUFFER_NUM = 1;
static constexpr uint64_t TILE_LENGTH = 11776; // optimal performance tile length
public:
__aicore__ inline BGMVShrink(AscendC::TPipe *pipe) : pipe_(pipe) {}
__aicore__ inline void Init(__gm__ void *x, __gm__ void *weight, __gm__ void *indices, __gm__ void *y,
uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
uint32_t maxLoRARank, float scale)
{
batchSize_ = batchSize;
numTokensPerCore_ = numTokensPerCore;
inputHiddenDim_ = inputHiddenDim;
maxLoRARank_ = maxLoRARank;
scale_ = scale;
singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_;
incremental_ = inputHiddenDim_ > TILE_LENGTH;
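        // Incremental mode: when the hidden dim exceeds one tile, X is re-streamed tile by tile
        // inside the per-rank loop; otherwise X is loaded and cast to fp32 once in
        // ProcessImpl<false> and reused for every rank row.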
xGm_.SetGlobalBuffer((__gm__ X_T *)x);
yOutGm_.SetGlobalBuffer((__gm__ Y_T *)y);
wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
indicesGm_.SetGlobalBuffer((__gm__ int64_t *)indices);
pipe_->InitBuffer(inQueueX_, BUFFER_NUM, TILE_LENGTH * sizeof(X_T));
pipe_->InitBuffer(inQueueW_, BUFFER_NUM, TILE_LENGTH * sizeof(W_T));
pipe_->InitBuffer(tmpBufferX_, TILE_LENGTH * sizeof(float));
pipe_->InitBuffer(tmpBufferW_, TILE_LENGTH * sizeof(float));
pipe_->InitBuffer(outQueueY_, 1, maxLoRARank_ * sizeof(Y_T));
pipe_->InitBuffer(outBufferY_, maxLoRARank_ * sizeof(float));
}
__aicore__ inline void Process()
{
int64_t blockIdx = AscendC::GetBlockIdx();
int64_t startIdx = blockIdx * numTokensPerCore_;
int64_t endIdx = startIdx + numTokensPerCore_;
if (endIdx > batchSize_) {
endIdx = batchSize_;
}
for (int64_t idx = startIdx; idx < endIdx; idx++) {
// set up LoRA index
CopyInIndex(idx);
if (reqLoRAIndex_ < 0) {
continue;
}
reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_;
if (incremental_) {
ProcessImpl<true>(idx);
} else {
ProcessImpl<false>(idx);
}
ScaleOutput();
CopyOut(idx);
}
}
private:
template <bool INCREMENTAL_MODE>
__aicore__ inline void ProcessImpl(const int64_t idx)
{
AscendC::LocalTensor<float> yOutLocal = outBufferY_.Get<float>();
if constexpr (!INCREMENTAL_MODE) {
CopyInX(idx, 0, inputHiddenDim_);
AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, inputHiddenDim_);
pipe_barrier(PIPE_V);
inQueueX_.FreeTensor(xLocal);
}
for (int i = 0; i < maxLoRARank_; i++) {
float acc(0);
for (int32_t j = 0; j < inputHiddenDim_ / TILE_LENGTH; j++) {
if constexpr (INCREMENTAL_MODE) {
CopyInX(idx, j);
}
CopyInW(i, j);
Compute<INCREMENTAL_MODE>(acc);
}
CopyAndComputeLastIteration<INCREMENTAL_MODE>(idx, i, acc);
yOutLocal.SetValue(i, acc);
}
}
__aicore__ inline void CopyInIndex(const int64_t idx)
{
// look up the LoRA index
reqLoRAIndex_ = indicesGm_.GetValue(idx);
}
__aicore__ inline void CopyInX(const int64_t idx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
{
AscendC::LocalTensor<X_T> xLocal = inQueueX_.AllocTensor<X_T>();
DataCopy(xLocal, xGm_[inputHiddenDim_ * idx + colIdx * TILE_LENGTH], numElements);
inQueueX_.EnQue(xLocal);
}
__aicore__ inline void CopyInW(int32_t rowIdx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
{
AscendC::LocalTensor<W_T> wLocal = inQueueW_.AllocTensor<W_T>();
DataCopy(wLocal, wGm_[reqLoRAWeightOffset_ + rowIdx * inputHiddenDim_ + colIdx * TILE_LENGTH], numElements);
inQueueW_.EnQue(wLocal);
}
template <bool INCREMENTAL_MODE>
__aicore__ inline void Compute(float &acc, int32_t numElements = TILE_LENGTH)
{
AscendC::LocalTensor<W_T> wLocal = inQueueW_.DeQue<W_T>();
AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
AscendC::LocalTensor<float> wTmpTensor = tmpBufferW_.Get<float>();
if constexpr (INCREMENTAL_MODE) {
AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, numElements);
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
pipe_barrier(PIPE_V);
inQueueX_.FreeTensor(xLocal);
inQueueW_.FreeTensor(wLocal);
} else {
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
pipe_barrier(PIPE_V);
inQueueW_.FreeTensor(wLocal);
}
// dot product of the one tile of X and W
Mul(wTmpTensor, xTmpTensor, wTmpTensor, numElements);
pipe_barrier(PIPE_V);
// reduce sum generate one number, which is the summation of all the dot product
ReduceSum<float>(wTmpTensor, wTmpTensor, wTmpTensor, numElements);
pipe_barrier(PIPE_V);
acc += wTmpTensor.GetValue(0);
}
template <bool INCREMENTAL_MODE>
__aicore__ inline void CopyAndComputeLastIteration(const int64_t idx, int32_t rowIdx, float &acc)
{
int32_t colIdx = inputHiddenDim_ / TILE_LENGTH;
int32_t remaining = inputHiddenDim_ % TILE_LENGTH;
if (remaining == 0) {
return;
}
if constexpr (INCREMENTAL_MODE) {
CopyInX(idx, colIdx, remaining);
}
CopyInW(rowIdx, colIdx, remaining);
Compute<INCREMENTAL_MODE>(acc, remaining);
}
__aicore__ inline void ScaleOutput()
{
AscendC::LocalTensor<float> yLocal = outBufferY_.Get<float>();
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.AllocTensor<Y_T>();
Muls(yOutLocal, yLocal, scale_, maxLoRARank_);
pipe_barrier(PIPE_V);
outQueueY_.EnQue<Y_T>(yOutLocal);
}
__aicore__ inline void CopyOut(const int64_t idx)
{
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.DeQue<Y_T>();
DataCopy(yOutGm_[maxLoRARank_ * idx], yOutLocal, maxLoRARank_);
outQueueY_.FreeTensor(yOutLocal);
}
private:
AscendC::TPipe *pipe_;
AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueX_, inQueueW_;
AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueY_;
AscendC::TBuf<AscendC::QuePosition::VECCALC> tmpBufferX_, tmpBufferW_, outBufferY_;
AscendC::GlobalTensor<X_T> xGm_;
AscendC::GlobalTensor<W_T> wGm_;
AscendC::GlobalTensor<int64_t> indicesGm_;
AscendC::GlobalTensor<Y_T> yOutGm_;
uint32_t batchSize_;
uint32_t numTokensPerCore_;
uint32_t inputHiddenDim_;
uint32_t maxLoRARank_;
float scale_;
uint32_t singleLoRAWeightLen_;
int64_t reqLoRAIndex_;
uint64_t reqLoRAWeightOffset_;
bool incremental_;
};
#define BGMV_SHRINK_TYPE_DECLARE(TYPE) \
extern "C" __global__ __aicore__ void bgmv_shrink_##TYPE(__gm__ void* x, __gm__ void* weight, __gm__ void* indices,\
__gm__ void* y, uint32_t batchSize, \
uint32_t numTokensPerCore, uint32_t inputHiddenDim, \
uint32_t maxLoRARank, float scale) \
{ \
AscendC::TPipe pipe; \
BGMVShrink<TYPE> op(&pipe); \
op.Init(x, weight, indices, y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); \
op.Process(); \
}
// declare all dtype kernel
BGMV_SHRINK_TYPE_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
BGMV_SHRINK_TYPE_DECLARE(bfloat16_t)
#endif
namespace vllm_ascend {
extern void bgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight, void* indices,
void* y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
uint32_t maxLoRARank, float scale)
{
uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
if (type == AscendType::FP16) {
bgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, indices, y, batchSize, numTokensPerCore,
inputHiddenDim, maxLoRARank, scale);
} else if (type == AscendType::BF16) {
#if (__CCE_AICORE__ >= 220)
bgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, y, batchSize, numTokensPerCore,
inputHiddenDim, maxLoRARank, scale);
#endif
} else {
return;
}
}
} // namespace vllm_ascend


@@ -60,4 +60,32 @@ namespace vllm_ascend {
auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
return new_tensor;
}
extern void bgmv_shrink_impl(
AscendType type,
void *stream,
void *x,
void *weight,
void *indices,
void *y,
uint32_t batch_size,
uint32_t num_tokens_per_core,
uint32_t input_hidden_dim,
uint32_t lora_rank,
float scale);
extern void bgmv_expand_impl(
AscendType type,
void *stream,
void *x,
void *weight,
void *indices,
void *y,
void *y_out,
uint32_t batch_size,
uint32_t num_tokens_per_core,
uint32_t lora_rank,
uint32_t output_hidden_dim,
uint32_t slice_offset,
uint32_t output_full_dim);
}


@@ -199,6 +199,90 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
cmd.Run();
return {masked_input, mask};
}
void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y, double scale)
{
at::ScalarType scalar_type = x.scalar_type();
TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
"weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
"the first dimension of x, y, indices should be same");
TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out");
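    // Note: the kernel writes float32 results, so y is expected to be a float32 tensor of
    // shape [batch_size, lora_rank]; its previous contents are overwritten.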
void* x_ptr = x.data_ptr();
void* weight_ptr = weight.data_ptr();
void* indices_ptr = indices.data_ptr();
void* y_ptr = y.data_ptr();
int batch_size = x.size(0);
int input_hidden_token = x.size(1);
uint32_t lora_rank = y.size(1);
float scale_f = static_cast<float>(scale);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("bgmv_shrink");
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, input_hidden_token,
lora_rank, scale_f]() -> int {
auto dtype = get_dtype_from_torch(scalar_type);
int device_id = 0;
int64_t aiv_num = 0;
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
bgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, num_tokens_per_core,
input_hidden_token, lora_rank, scale_f);
return 0;
});
cmd.Run();
return;
}
at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y,
int64_t slice_offset, int64_t slice_size)
{
at::ScalarType scalar_type = y.scalar_type();
TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
"weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
"the first dimension of x, y, indices should be same");
TORCH_CHECK(x.size(1) <= slice_size, "hidden in should be smaller than hidden out");
TORCH_CHECK(slice_offset >= 0, "slice offset should be no smaller than 0");
    TORCH_CHECK((slice_size + slice_offset) <= y.size(1),
                "slice_size + slice_offset should not exceed the second dimension of y");
at::Tensor y_out = y;
void* x_ptr = x.data_ptr();
void* weight_ptr = weight.data_ptr();
void* indices_ptr = indices.data_ptr();
void* y_ptr = y.data_ptr();
void* y_out_ptr = y_out.data_ptr();
int batch_size = x.size(0);
int lora_rank = x.size(1);
int output_full_dim = y.size(1);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("bgmv_expand");
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size, lora_rank,
slice_offset, slice_size, output_full_dim]() -> int {
auto dtype = get_dtype_from_torch(scalar_type);
int device_id = 0;
int64_t aiv_num = 0;
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
bgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size,
num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
return 0;
});
cmd.Run();
return y_out;
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(_C, ops)
@@ -223,6 +307,14 @@ TORCH_LIBRARY_EXPAND(_C, ops)
" int added_vocab_start_index, "
" int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
ops.def(
"bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
" int slice_offset, int slice_size) -> Tensor");
ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
}
REGISTER_EXTENSION(_C)


@@ -0,0 +1,41 @@
import torch
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def bgmv_expand_cpu_impl(x: torch.Tensor, w: torch.Tensor,
                         indices: torch.Tensor, y: torch.Tensor,
                         slice_offset: int, slice_size: int) -> torch.Tensor:
W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze()
    y[:, slice_offset:slice_offset + slice_size] += z.to(y.dtype)
return y
@torch.inference_mode()
def test_bgmv_expand() -> None:
B = 1
x = torch.randn([B, 16], dtype=torch.float)
w = torch.randn([64, 128, 16], dtype=torch.float16)
indices = torch.zeros([B], dtype=torch.int64)
y = torch.randn([B, 128 * 3], dtype=torch.float16)
x_npu = x.npu()
w_npu = w.npu()
indices_npu = indices.npu()
y_npu = y.npu()
y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128)
y_out_npu = torch.ops._C.bgmv_expand(x_npu, w_npu, indices_npu, y_npu, 0,
128)
# Compare the results.
torch.testing.assert_close(y_out_npu.cpu(),
y_out,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)


@@ -0,0 +1,40 @@
import torch
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def bgmv_shrink_cpu_impl(x: torch.Tensor, w: torch.Tensor,
                         indices: torch.Tensor, y: torch.Tensor,
                         scaling: float) -> torch.Tensor:
W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze()
y[:, :] += z * scaling
return y
@torch.inference_mode()
def test_bgmv_shrink() -> None:
B = 1
x = torch.randn([B, 128], dtype=torch.float16)
w = torch.randn([64, 16, 128], dtype=torch.float16)
indices = torch.zeros([B], dtype=torch.int64)
y = torch.zeros([B, 16])
x_npu = x.npu()
w_npu = w.npu()
indices_npu = indices.npu()
y_npu = y.npu()
y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5)
torch.ops._C.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)
# Compare the results.
torch.testing.assert_close(y_npu.cpu(),
y,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)


@@ -0,0 +1,112 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def bgmv_shrink(inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0):
return torch.ops._C.bgmv_shrink(
inputs,
lora_a_weights,
lora_indices_tensor,
output_tensor,
scaling,
)
def bgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):
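    # Note: the AscendC expand kernel always accumulates into output_tensor
    # (its ScaleOutput step adds the existing values), so add_inputs is not forwarded.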
return torch.ops._C.bgmv_expand(
inputs,
lora_b_weights,
lora_indices_tensor,
output_tensor,
0,
output_tensor.size(1),
)
def bgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True):
return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
lora_indices_tensor, output_tensor,
slice_offset, slice_size)
def sgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
seq_len_tensor)
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
scaling)
def sgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False):
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
seq_len_tensor)
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
add_inputs)
def sgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False):
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
seq_len_tensor)
bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
slice_offset, slice_size, add_inputs)

View File

@@ -3,9 +3,18 @@
from typing import Callable, Optional, Tuple, Union
import torch
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm_ascend.utils import is_310p
if is_310p():
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
else:
from vllm_ascend.lora.punica_wrapper.lora_ops import (
bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase