[Bugfix][LoRA][Operator] Fix LoRA custom operators accuracy issue (#2672)

### What this PR does / why we need it?
Fix the LoRA accuracy issue introduced by the custom AscendC operators
`bgmv_shrink`, `sgmv_shrink`, `bgmv_expand`, and `sgmv_expand`.

The bug details:
- If kernel code needs to call `GlobalTensor.GetSize`, the buffer size must be passed as the second argument (`bufferSize`) when `GlobalTensor.SetGlobalBuffer` is called first.
- Otherwise, `GlobalTensor.GetSize` returns an uninitialized (effectively random) value.
- See [this
doc](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1alpha002/apiref/ascendcopapi/atlasascendc_api_07_00024.html) for the API reference; a minimal sketch of the failure mode follows this list.
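
A minimal kernel-side sketch of the bug and the fix (AscendC API; `loraIndices` and the new `loraIndicesSize` parameter are taken from this PR's diff):

```cpp
AscendC::GlobalTensor<int64_t> loraIndicesGm_;

// Buggy: SetGlobalBuffer called without bufferSize, so the tensor's size
// field is never initialized and GetSize() returns an arbitrary value.
loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices);
uint32_t bad = loraIndicesGm_.GetSize();   // undefined / random

// Fixed: pass the element count as the second argument, as this PR does.
loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices, loraIndicesSize);
uint32_t good = loraIndicesGm_.GetSize();  // == loraIndicesSize
```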

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
```bash
pytest -sv tests/e2e/singlecard/test_ilama_lora.py
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
```

- vLLM version: v0.10.1.1
- vLLM main: a344a5aa0a

---------

Signed-off-by: paulyu12 <paulyu0307@gmail.com>
Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: paulyu12 <paulyu0307@gmail.com>

Representative hunks from the `sgmv_expand` kernel:

```diff
@@ -53,8 +53,8 @@ public:
 public:
     __aicore__ inline SGMVExpand(AscendC::TPipe* pipe) : pipe_(pipe) {}
-    __aicore__ inline void Init(__gm__ void* x, __gm__ void* weight, __gm__ void* loraIndices,
-                                __gm__ void* seqLen, __gm__ void* yIn, __gm__ void* yOut,
+    __aicore__ inline void Init(__gm__ void* x, __gm__ void* weight, __gm__ void* loraIndices, uint32_t loraIndicesSize,
+                                __gm__ void* seqLen, uint32_t seqLenSize, __gm__ void* yIn, __gm__ void* yOut,
                                 uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank,
                                 uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim)
     {
@@ -70,8 +70,8 @@ public:
         wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
         yInGm_.SetGlobalBuffer((__gm__ Y_T *)yIn);
         yOutGm_.SetGlobalBuffer((__gm__ Y_T *)yOut);
-        loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices);
-        seqLenGm_.SetGlobalBuffer((__gm__ int64_t *)seqLen);
+        loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices, loraIndicesSize);
+        seqLenGm_.SetGlobalBuffer((__gm__ int64_t *)seqLen, seqLenSize);
         pipe_->InitBuffer(inQueueX_, 1, NUM_ELEMENTS_PER_REPEAT * sizeof(X_T));
         pipe_->InitBuffer(inQueueW_, BUFFER_NUM, W_IN_TILE_NUM_ELEMENTS * sizeof(W_T));
@@ -340,7 +340,8 @@ private:
 #define SGMV_EXPAND_TYPE_DECLARE(TYPE) \
     extern "C" __global__ __aicore__ void sgmv_expand_##TYPE(__gm__ void* x, __gm__ void* weight, \
-                                                             __gm__ void* loraIndices, __gm__ void* seqLen, \
+                                                             __gm__ void* loraIndices, uint32_t loraIndicesSize, \
+                                                             __gm__ void* seqLen, uint32_t seqLenSize, \
                                                              __gm__ void* yIn, __gm__ void* yOut, \
                                                              uint32_t batchSize, uint32_t numTokensPerCore, \
                                                              uint32_t maxLoRARank, uint32_t outputHiddenDim, \
@@ -348,7 +349,8 @@ private:
     { \
         AscendC::TPipe pipe; \
         SGMVExpand<TYPE> op(&pipe); \
-        op.Init(x, weight, loraIndices, seqLen, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, \
+        op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, \
+                yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, \
                 outputHiddenDim, sliceOffset, outputFullDim); \
         op.Process(); \
     }
@@ -360,18 +362,22 @@ SGMV_EXPAND_TYPE_DECLARE(half)
 #endif
 
 namespace vllm_ascend {
-extern void sgmv_expand_impl(AscendType type, void* stream, void* x, void* weight, void* loraIndices, void* seqLen,
+extern void sgmv_expand_impl(AscendType type, void* stream, void* x, void* weight,
+                             void* loraIndices, uint32_t loraIndicesSize,
+                             void* seqLen, uint32_t seqLenSize,
                              void* yIn, void* yOut, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank,
                              uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim)
 {
     uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
     if (type == AscendType::FP16) {
-        sgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, seqLen, yIn, yOut, batchSize,
+        sgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize,
+                                                        yIn, yOut, batchSize,
                                                         numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset,
                                                         outputFullDim);
     } else if (type == AscendType::BF16) {
 #if (__CCE_AICORE__ >= 220)
-        sgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, seqLen, yIn, yOut, batchSize,
+        sgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize,
+                                                              seqLen, seqLenSize, yIn, yOut, batchSize,
                                                               numTokensPerCore, maxLoRARank, outputHiddenDim,
                                                               sliceOffset, outputFullDim);
 #endif
```
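
For context, a self-contained sketch of the caller-side contract implied by the new `sgmv_expand_impl` signature. The stub body and the values in `main` are hypothetical, for illustration only; the real implementation launches the AscendC kernel. The point is that callers derive `loraIndicesSize`/`seqLenSize` from the host tensors they already own and forward them alongside the pointers:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-ins for illustration; the real declarations live in
// vllm-ascend's C++ sources.
enum class AscendType { FP16, BF16 };

// Stub with the post-fix signature: each pointer that backs a GlobalTensor
// whose GetSize() is consulted now travels with its element count.
void sgmv_expand_impl(AscendType type, void* stream, void* x, void* weight,
                      void* loraIndices, uint32_t loraIndicesSize,
                      void* seqLen, uint32_t seqLenSize,
                      void* yIn, void* yOut, uint32_t batchSize,
                      uint32_t numTokensPerCore, uint32_t maxLoRARank,
                      uint32_t outputHiddenDim, uint32_t sliceOffset,
                      uint32_t outputFullDim) {
    // Real code would launch sgmv_expand_half / sgmv_expand_bfloat16_t here.
    std::printf("loraIndicesSize=%u seqLenSize=%u\n", loraIndicesSize, seqLenSize);
}

int main() {
    std::vector<int64_t> loraIndices(8, 0);  // one LoRA id per token
    std::vector<int64_t> seqLen(4, 2);       // one length per sequence
    // Sizes come from the host tensors themselves (e.g. tensor.numel() in a
    // torch extension), so the kernel never has to guess them.
    sgmv_expand_impl(AscendType::FP16, /*stream=*/nullptr,
                     /*x=*/nullptr, /*weight=*/nullptr,
                     loraIndices.data(), static_cast<uint32_t>(loraIndices.size()),
                     seqLen.data(), static_cast<uint32_t>(seqLen.size()),
                     /*yIn=*/nullptr, /*yOut=*/nullptr,
                     /*batchSize=*/8, /*numTokensPerCore=*/1, /*maxLoRARank=*/16,
                     /*outputHiddenDim=*/0, /*sliceOffset=*/0, /*outputFullDim=*/0);
    return 0;
}
```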