From ea8f544ce73708aef8c9d48a2916b64fa0e09806 Mon Sep 17 00:00:00 2001
From: hukongyi
Date: Fri, 19 Dec 2025 14:22:06 +0800
Subject: [PATCH] [BugFix]Fix precision issue for LoRA feature (#4141)

vLLM version: v0.11.0
vLLM main: vllm-project/vllm

### What this PR does / why we need it?
Fix the precision issue of the LoRA feature in vllm-ascend.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
```bash
pytest tests/lora/test_llama_tp.py::test_llama_lora -s
```
lora_test

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: hukongyi
---
 CMakeLists.txt                 | 4 ++++
 csrc/kernels/bgmv_expand.cpp   | 6 +++---
 csrc/kernels/bgmv_shrink.cpp   | 6 +++---
 csrc/kernels/sgmv_expand.cpp   | 6 +++---
 csrc/kernels/sgmv_shrink.cpp   | 6 +++---
 vllm_ascend/lora/punica_npu.py | 1 +
 6 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b72a7eb0..6c55d93c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -62,6 +62,10 @@ set(VLLM_ASCEND_CUSTOM_OP
 )
 
 set(VLLM_ASCEND_CUSTOM_OP_EXCLUDE
+    ${KERNEL_FILES}/bgmv_expand.cpp
+    ${KERNEL_FILES}/bgmv_shrink.cpp
+    ${KERNEL_FILES}/sgmv_expand.cpp
+    ${KERNEL_FILES}/sgmv_shrink.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp
 )
 
diff --git a/csrc/kernels/bgmv_expand.cpp b/csrc/kernels/bgmv_expand.cpp
index c910005c..8afcd92d 100644
--- a/csrc/kernels/bgmv_expand.cpp
+++ b/csrc/kernels/bgmv_expand.cpp
@@ -342,7 +342,7 @@ private:
 
 // declare all dtype kernel
 BGMV_EXPAND_TYPE_DECLARE(half)
-#if (__CCE_AICORE__ >= 220)
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
 BGMV_EXPAND_TYPE_DECLARE(bfloat16_t)
 #endif
 
@@ -356,8 +356,8 @@ extern void bgmv_expand_impl(AscendType type, void* stream, void* x, void* weigh
         bgmv_expand_half<<>>(x, weight, indices, indicesSize, yIn, yOut, batchSize,
                              numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim);
     } else if (type == AscendType::BF16) {
-        #if (__CCE_AICORE__ >= 220)
-        bgmv_expand_bfloat16_t<<>>(x, weight, indices, indicesSize, yIn, yOut, batchSize,
+        #if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
+        bgmv_expand_bfloat16_t<<>>(x, weight, indices, indicesSize, yIn, yOut, batchSize,
                              numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim);
         #endif
diff --git a/csrc/kernels/bgmv_shrink.cpp b/csrc/kernels/bgmv_shrink.cpp
index b5a2d15d..0dea13e8 100644
--- a/csrc/kernels/bgmv_shrink.cpp
+++ b/csrc/kernels/bgmv_shrink.cpp
@@ -226,7 +226,7 @@ private:
 
 // declare all dtype kernel
 BGMV_SHRINK_TYPE_DECLARE(half)
-#if (__CCE_AICORE__ >= 220)
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
 BGMV_SHRINK_TYPE_DECLARE(bfloat16_t)
 #endif
 
@@ -240,8 +240,8 @@ extern void bgmv_shrink_impl(AscendType type, void* stream, void* x, void* weigh
         bgmv_shrink_half<<>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore,
                              inputHiddenDim, maxLoRARank, scale);
     } else if (type == AscendType::BF16) {
-        #if (__CCE_AICORE__ >= 220)
-        bgmv_shrink_bfloat16_t<<>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore,
+        #if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
+        bgmv_shrink_bfloat16_t<<>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore,
                              inputHiddenDim, maxLoRARank, scale);
         #endif
     } else {
diff --git a/csrc/kernels/sgmv_expand.cpp b/csrc/kernels/sgmv_expand.cpp
index 5466bd69..65c32f96 100644
--- a/csrc/kernels/sgmv_expand.cpp
+++ b/csrc/kernels/sgmv_expand.cpp
@@ -357,7 +357,7 @@ private:
 
 // declare all dtype kernel
 SGMV_EXPAND_TYPE_DECLARE(half)
-#if (__CCE_AICORE__ >= 220)
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
 SGMV_EXPAND_TYPE_DECLARE(bfloat16_t)
 #endif
 
@@ -375,8 +375,8 @@ extern void sgmv_expand_impl(AscendType type, void* stream, void* x, void* weigh
                              numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim);
     } else if (type == AscendType::BF16) {
-        #if (__CCE_AICORE__ >= 220)
-        sgmv_expand_bfloat16_t<<>>(x, weight, loraIndices, loraIndicesSize,
+        #if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
+        sgmv_expand_bfloat16_t<<>>(x, weight, loraIndices, loraIndicesSize,
                              seqLen, seqLenSize, yIn, yOut, batchSize,
                              numTokensPerCore, maxLoRARank, outputHiddenDim,
                              sliceOffset, outputFullDim);
diff --git a/csrc/kernels/sgmv_shrink.cpp b/csrc/kernels/sgmv_shrink.cpp
index a72e592e..49f21e0b 100644
--- a/csrc/kernels/sgmv_shrink.cpp
+++ b/csrc/kernels/sgmv_shrink.cpp
@@ -242,7 +242,7 @@ private:
 
 // declare all dtype kernel
 SGMV_SHRINK_TYPE_DECLARE(half)
-#if (__CCE_AICORE__ >= 220)
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
 SGMV_SHRINK_TYPE_DECLARE(bfloat16_t)
 #endif
 
@@ -260,8 +260,8 @@ extern void sgmv_shrink_impl(AscendType type, void* stream, void* x, void* weigh
                              numTokensPerCore, inputHiddenDim, maxLoRARank, scale);
     } else if (type == AscendType::BF16) {
-        #if (__CCE_AICORE__ >= 220)
-        sgmv_shrink_bfloat16_t<<>>(x, weight, loraIndices, loraIndicesSize,
+        #if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
+        sgmv_shrink_bfloat16_t<<>>(x, weight, loraIndices, loraIndicesSize,
                              seqLen, seqLenSize, y, batchSize,
                              numTokensPerCore, inputHiddenDim, maxLoRARank,
diff --git a/vllm_ascend/lora/punica_npu.py b/vllm_ascend/lora/punica_npu.py
index 3dba7ee9..ddc98e32 100644
--- a/vllm_ascend/lora/punica_npu.py
+++ b/vllm_ascend/lora/punica_npu.py
@@ -255,6 +255,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
         # Embedding layer only need expand op
         expand_fun: Callable = (self._expand_prefill
                                 if self.is_prefill else self._expand_decode)
+        x = x.to(torch.float32)
         expand_fun(y, x, lora_b_stacked, add_inputs)
 
     def add_lora_linear(self,
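
For readers of this patch, below is a minimal, self-contained sketch (not part of the change itself) of what the `x = x.to(torch.float32)` upcast in `vllm_ascend/lora/punica_npu.py` is protecting against: keeping the LoRA shrink output in bfloat16 through the expand matmul loses precision, while upcasting the intermediate to float32 keeps the expand step close to a full-precision reference. The tensor shapes, the plain `torch` matmuls, and the float32 copy of the expand weight are illustrative assumptions only; the real path runs through the bgmv/sgmv NPU kernels patched above.

```python
# Illustrative sketch only: plain torch stands in for the bgmv/sgmv NPU kernels.
import torch

torch.manual_seed(0)
hidden, rank, out_dim = 4096, 16, 4096

h = torch.randn(1, hidden)                          # token activations (float32 reference)
lora_a = torch.randn(hidden, rank) / hidden ** 0.5  # shrink weight
lora_b = torch.randn(rank, out_dim) / rank ** 0.5   # expand weight

# Full float32 LoRA path as the reference.
ref = (h @ lora_a) @ lora_b

# Pre-fix behaviour (sketched): the shrink output stays bfloat16 through the expand op.
x_bf16 = h.bfloat16() @ lora_a.bfloat16()
out_bf16 = (x_bf16 @ lora_b.bfloat16()).float()

# Post-fix behaviour (sketched): the intermediate is upcast to float32 before expand,
# mirroring the added `x = x.to(torch.float32)` line in punica_npu.py.
x_fp32 = x_bf16.to(torch.float32)
out_fp32 = x_fp32 @ lora_b

print("max abs error, bf16 intermediate:", (out_bf16 - ref).abs().max().item())
print("max abs error, fp32 intermediate:", (out_fp32 - ref).abs().max().item())
```

Run as-is, the float32-intermediate path should track the reference noticeably more closely; the exact numbers depend on the seed and shapes chosen here.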