diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 162af5c..a704833 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -201,7 +201,7 @@ jobs: pytest -sv tests/e2e/singlecard/test_embedding.py pytest -sv tests/e2e/singlecard/test_guided_decoding.py # TODO: Fix lora accuracy error - # pytest -sv tests/e2e/singlecard/test_ilama_lora.py + pytest -sv tests/e2e/singlecard/test_ilama_lora.py pytest -sv tests/e2e/singlecard/test_profile_execute_duration.py pytest -sv tests/e2e/singlecard/test_quantization.py pytest -sv tests/e2e/singlecard/test_sampler.py @@ -280,7 +280,7 @@ jobs: # external_launcher test is not stable enough. Fix it later # pytest -sv tests/e2e/multicard/test_external_launcher.py pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py - # pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py + pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py # To avoid oom, we need to run the test in a single process. 
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ diff --git a/csrc/kernels/bgmv_expand.cpp b/csrc/kernels/bgmv_expand.cpp index 84a4f09..c910005 100644 --- a/csrc/kernels/bgmv_expand.cpp +++ b/csrc/kernels/bgmv_expand.cpp @@ -54,7 +54,7 @@ public: __aicore__ inline BGMVExpand(AscendC::TPipe* pipe) : pipe_(pipe) {} __aicore__ inline void Init(__gm__ void* x, __gm__ void* weight, __gm__ void* indices, - __gm__ void* yIn, __gm__ void* yOut, + uint32_t indicesSize, __gm__ void* yIn, __gm__ void* yOut, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank, uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim) { @@ -70,7 +70,7 @@ public: wGm_.SetGlobalBuffer((__gm__ W_T *)weight); yInGm_.SetGlobalBuffer((__gm__ Y_T *)yIn); yOutGm_.SetGlobalBuffer((__gm__ Y_T *)yOut); - indicesGm_.SetGlobalBuffer((__gm__ int64_t *)indices); + indicesGm_.SetGlobalBuffer((__gm__ int64_t *)indices, indicesSize); pipe_->InitBuffer(inQueueX_, 1, NUM_ELEMENTS_PER_REPEAT * sizeof(X_T)); pipe_->InitBuffer(inQueueW_, BUFFER_NUM, W_IN_TILE_NUM_ELEMENTS * sizeof(W_T)); @@ -328,14 +328,14 @@ private: #define BGMV_EXPAND_TYPE_DECLARE(TYPE) \ extern "C" __global__ __aicore__ void bgmv_expand_##TYPE(__gm__ void* x, __gm__ void* weight, __gm__ void* indices,\ - __gm__ void* yIn, __gm__ void* yOut, \ + uint32_t indicesSize, __gm__ void* yIn, __gm__ void* yOut,\ uint32_t batchSize, uint32_t numTokensPerCore, \ uint32_t maxLoRARank, uint32_t outputHiddenDim, \ uint32_t sliceOffset, uint32_t outputFullDim) \ { \ AscendC::TPipe pipe; \ BGMVExpand op(&pipe); \ - op.Init(x, weight, indices, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, \ + op.Init(x, weight, indices, indicesSize, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, \ outputHiddenDim, sliceOffset, outputFullDim); \ op.Process(); \ } @@ -347,17 +347,17 @@ BGMV_EXPAND_TYPE_DECLARE(half) #endif namespace vllm_ascend { -extern void bgmv_expand_impl(AscendType 
type, void* stream, void* x, void* weight, void* indices, +extern void bgmv_expand_impl(AscendType type, void* stream, void* x, void* weight, void* indices, uint32_t indicesSize, void* yIn, void* yOut, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank, uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim) { uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore; if (type == AscendType::FP16) { - bgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, indices, yIn, yOut, batchSize, numTokensPerCore, + bgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim); } else if (type == AscendType::BF16) { #if (__CCE_AICORE__ >= 220) - bgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, yIn, yOut, batchSize, + bgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim); #endif diff --git a/csrc/kernels/bgmv_shrink.cpp b/csrc/kernels/bgmv_shrink.cpp index ae73eb7..b5a2d15 100644 --- a/csrc/kernels/bgmv_shrink.cpp +++ b/csrc/kernels/bgmv_shrink.cpp @@ -29,7 +29,7 @@ public: public: __aicore__ inline BGMVShrink(AscendC::TPipe *pipe) : pipe_(pipe) {} - __aicore__ inline void Init(__gm__ void *x, __gm__ void *weight, __gm__ void *indices, __gm__ void *y, + __aicore__ inline void Init(__gm__ void *x, __gm__ void *weight, __gm__ void *indices, uint32_t indicesSize, __gm__ void *y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim, uint32_t maxLoRARank, float scale) { @@ -44,7 +44,7 @@ public: xGm_.SetGlobalBuffer((__gm__ X_T *)x); yOutGm_.SetGlobalBuffer((__gm__ Y_T *)y); wGm_.SetGlobalBuffer((__gm__ W_T *)weight); - indicesGm_.SetGlobalBuffer((__gm__ int64_t *)indices); + indicesGm_.SetGlobalBuffer((__gm__ int64_t *)indices, indicesSize); pipe_->InitBuffer(inQueueX_, BUFFER_NUM, TILE_LENGTH * sizeof(X_T)); pipe_->InitBuffer(inQueueW_, BUFFER_NUM, 
TILE_LENGTH * sizeof(W_T)); @@ -214,13 +214,13 @@ private: #define BGMV_SHRINK_TYPE_DECLARE(TYPE) \ extern "C" __global__ __aicore__ void bgmv_shrink_##TYPE(__gm__ void* x, __gm__ void* weight, __gm__ void* indices,\ - __gm__ void* y, uint32_t batchSize, \ + uint32_t indicesSize, __gm__ void* y, uint32_t batchSize, \ uint32_t numTokensPerCore, uint32_t inputHiddenDim, \ uint32_t maxLoRARank, float scale) \ { \ AscendC::TPipe pipe; \ BGMVShrink op(&pipe); \ - op.Init(x, weight, indices, y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); \ + op.Init(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); \ op.Process(); \ } @@ -231,17 +231,17 @@ BGMV_SHRINK_TYPE_DECLARE(half) #endif namespace vllm_ascend { -extern void bgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight, void* indices, +extern void bgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight, void* indices, uint32_t indicesSize, void* y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim, uint32_t maxLoRARank, float scale) { uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore; if (type == AscendType::FP16) { - bgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, indices, y, batchSize, numTokensPerCore, + bgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); } else if (type == AscendType::BF16) { #if (__CCE_AICORE__ >= 220) - bgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, y, batchSize, numTokensPerCore, + bgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); #endif } else { diff --git a/csrc/kernels/sgmv_expand.cpp b/csrc/kernels/sgmv_expand.cpp index 95e7bb8..5466bd6 100644 --- a/csrc/kernels/sgmv_expand.cpp +++ b/csrc/kernels/sgmv_expand.cpp @@ -53,8 +53,8 @@ public: public: __aicore__ inline SGMVExpand(AscendC::TPipe* pipe) : pipe_(pipe) {} - 
__aicore__ inline void Init(__gm__ void* x, __gm__ void* weight, __gm__ void* loraIndices, - __gm__ void* seqLen, __gm__ void* yIn, __gm__ void* yOut, + __aicore__ inline void Init(__gm__ void* x, __gm__ void* weight, __gm__ void* loraIndices, uint32_t loraIndicesSize, + __gm__ void* seqLen, uint32_t seqLenSize, __gm__ void* yIn, __gm__ void* yOut, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank, uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim) { @@ -70,8 +70,8 @@ public: wGm_.SetGlobalBuffer((__gm__ W_T *)weight); yInGm_.SetGlobalBuffer((__gm__ Y_T *)yIn); yOutGm_.SetGlobalBuffer((__gm__ Y_T *)yOut); - loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices); - seqLenGm_.SetGlobalBuffer((__gm__ int64_t *)seqLen); + loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices, loraIndicesSize); + seqLenGm_.SetGlobalBuffer((__gm__ int64_t *)seqLen, seqLenSize); pipe_->InitBuffer(inQueueX_, 1, NUM_ELEMENTS_PER_REPEAT * sizeof(X_T)); pipe_->InitBuffer(inQueueW_, BUFFER_NUM, W_IN_TILE_NUM_ELEMENTS * sizeof(W_T)); @@ -340,7 +340,8 @@ private: #define SGMV_EXPAND_TYPE_DECLARE(TYPE) \ extern "C" __global__ __aicore__ void sgmv_expand_##TYPE(__gm__ void* x, __gm__ void* weight, \ - __gm__ void* loraIndices, __gm__ void* seqLen, \ + __gm__ void* loraIndices, uint32_t loraIndicesSize, \ + __gm__ void* seqLen, uint32_t seqLenSize, \ __gm__ void* yIn, __gm__ void* yOut, \ uint32_t batchSize, uint32_t numTokensPerCore, \ uint32_t maxLoRARank, uint32_t outputHiddenDim, \ @@ -348,7 +349,8 @@ private: { \ AscendC::TPipe pipe; \ SGMVExpand op(&pipe); \ - op.Init(x, weight, loraIndices, seqLen, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, \ + op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, \ + yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, \ outputHiddenDim, sliceOffset, outputFullDim); \ op.Process(); \ } @@ -360,18 +362,22 @@ SGMV_EXPAND_TYPE_DECLARE(half) #endif namespace vllm_ascend { -extern void 
sgmv_expand_impl(AscendType type, void* stream, void* x, void* weight, void* loraIndices, void* seqLen, +extern void sgmv_expand_impl(AscendType type, void* stream, void* x, void* weight, + void* loraIndices, uint32_t loraIndicesSize, + void* seqLen, uint32_t seqLenSize, void* yIn, void* yOut, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank, uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim) { uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore; if (type == AscendType::FP16) { - sgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, seqLen, yIn, yOut, batchSize, + sgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, + yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim); } else if (type == AscendType::BF16) { #if (__CCE_AICORE__ >= 220) - sgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, seqLen, yIn, yOut, batchSize, + sgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize, + seqLen, seqLenSize, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim); #endif diff --git a/csrc/kernels/sgmv_shrink.cpp b/csrc/kernels/sgmv_shrink.cpp index 357fb68..a72e592 100644 --- a/csrc/kernels/sgmv_shrink.cpp +++ b/csrc/kernels/sgmv_shrink.cpp @@ -29,7 +29,8 @@ public: public: __aicore__ inline SGMVShrink(AscendC::TPipe *pipe) : pipe_(pipe) {} - __aicore__ inline void Init(__gm__ void *x, __gm__ void *weight, __gm__ void *loraIndices, __gm__ void *seqLen, + __aicore__ inline void Init(__gm__ void *x, __gm__ void *weight, __gm__ void *loraIndices, uint32_t loraIndicesSize, + __gm__ void *seqLen, uint32_t seqLenSize, __gm__ void *y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim, uint32_t maxLoRARank, float scale) { @@ -44,8 +45,8 @@ public: xGm_.SetGlobalBuffer((__gm__ X_T *)x); yOutGm_.SetGlobalBuffer((__gm__ Y_T *)y); wGm_.SetGlobalBuffer((__gm__ W_T *)weight); - 
loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices); - seqLenGm_.SetGlobalBuffer((__gm__ int64_t *)seqLen); + loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices, loraIndicesSize); + seqLenGm_.SetGlobalBuffer((__gm__ int64_t *)seqLen, seqLenSize); pipe_->InitBuffer(inQueueX_, BUFFER_NUM, TILE_LENGTH * sizeof(X_T)); pipe_->InitBuffer(inQueueW_, BUFFER_NUM, TILE_LENGTH * sizeof(W_T)); @@ -226,14 +227,16 @@ private: #define SGMV_SHRINK_TYPE_DECLARE(TYPE) \ extern "C" __global__ __aicore__ void sgmv_shrink_##TYPE(__gm__ void* x, __gm__ void* weight, \ - __gm__ void* loraIndices, __gm__ void* seqLen, \ + __gm__ void* loraIndices, uint32_t loraIndicesSize, \ + __gm__ void* seqLen, uint32_t seqLenSize, \ __gm__ void* y, uint32_t batchSize, \ uint32_t numTokensPerCore, uint32_t inputHiddenDim, \ uint32_t maxLoRARank, float scale) \ { \ AscendC::TPipe pipe; \ SGMVShrink op(&pipe); \ - op.Init(x, weight, loraIndices, seqLen,y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); \ + op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, \ + y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); \ op.Process(); \ } @@ -244,18 +247,23 @@ SGMV_SHRINK_TYPE_DECLARE(half) #endif namespace vllm_ascend { -extern void sgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight, void* loraIndices, void* seqLen, +extern void sgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight, + void* loraIndices, uint32_t loraIndicesSize, + void* seqLen, uint32_t seqLenSize, void* y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim, uint32_t maxLoRARank, float scale) { uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore; if (type == AscendType::FP16) { - sgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, seqLen, y, batchSize, + sgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, + y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); 
} else if (type == AscendType::BF16) { #if (__CCE_AICORE__ >= 220) - sgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, seqLen, y, batchSize, + sgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize, + seqLen, seqLenSize, + y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); #endif diff --git a/csrc/ops.h b/csrc/ops.h index 40a9fae..4773992 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -67,6 +67,7 @@ namespace vllm_ascend { void *x, void *weight, void *indices, + uint32_t indicesSize, void *y, uint32_t batch_size, uint32_t num_tokens_per_core, @@ -80,6 +81,7 @@ namespace vllm_ascend { void *x, void *weight, void *indices, + uint32_t indicesSize, void *y, void *y_out, uint32_t batch_size, @@ -95,7 +97,9 @@ namespace vllm_ascend { void *x, void *weight, void *loraIndices, + uint32_t loraIndicesSize, void *seqLen, + uint32_t seqLenSize, void *y, uint32_t batch_size, uint32_t num_tokens_per_core, @@ -109,7 +113,9 @@ namespace vllm_ascend { void *x, void *weight, void *loraIndices, + uint32_t loraIndicesSize, void *seqLen, + uint32_t seqLenSize, void *y, void *y_out, uint32_t batch_size, diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 82cea77..375ef59 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -226,6 +226,7 @@ void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Ten void* x_ptr = x.data_ptr(); void* weight_ptr = weight.data_ptr(); void* indices_ptr = indices.data_ptr(); + int indices_size = indices.size(0); void* y_ptr = y.data_ptr(); int batch_size = x.size(0); int input_hidden_token = x.size(1); @@ -234,7 +235,7 @@ void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Ten aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); at_npu::native::OpCommand cmd; cmd.Name("bgmv_shrink"); - cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, input_hidden_token, + cmd.SetCustomHandler([scalar_type, stream, x_ptr, 
weight_ptr, indices_ptr, indices_size, y_ptr, batch_size, input_hidden_token, lora_rank, scale_f]() -> int { auto dtype = get_dtype_from_torch(scalar_type); int device_id = 0; @@ -242,7 +243,7 @@ void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Ten TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS); int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num; TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0"); - bgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, num_tokens_per_core, + bgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, batch_size, num_tokens_per_core, input_hidden_token, lora_rank, scale_f); return 0; }); @@ -271,6 +272,7 @@ at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, a void* x_ptr = x.data_ptr(); void* weight_ptr = weight.data_ptr(); void* indices_ptr = indices.data_ptr(); + int indices_size = indices.size(0); void* y_ptr = y.data_ptr(); void* y_out_ptr = y_out.data_ptr(); int batch_size = x.size(0); @@ -279,7 +281,7 @@ at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, a aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); at_npu::native::OpCommand cmd; cmd.Name("bgmv_expand"); - cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size, lora_rank, + cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, y_out_ptr, batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int { auto dtype = get_dtype_from_torch(scalar_type); int device_id = 0; @@ -287,7 +289,7 @@ at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, a TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS); int num_tokens_per_core = (batch_size + aiv_num - 1) / 
aiv_num; TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0"); - bgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size, + bgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, y_out_ptr, batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim); return 0; }); @@ -309,6 +311,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at void* weight_ptr = weight.data_ptr(); void* lora_indices_ptr = lora_indices.data_ptr(); void* seq_len_ptr = seq_len.data_ptr(); + int lora_indices_size = lora_indices.size(0); + int seq_len_size = seq_len.size(0); void* y_ptr = y.data_ptr(); int batch_size = x.size(0); int input_hidden_token = x.size(1); @@ -317,7 +321,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); at_npu::native::OpCommand cmd; cmd.Name("sgmv_shrink"); - cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, + cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, + seq_len_ptr, seq_len_size, y_ptr, batch_size, input_hidden_token, lora_rank, scale_f]() -> int { auto dtype = get_dtype_from_torch(scalar_type); int device_id = 0; @@ -325,7 +330,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS); int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num; TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0"); - sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, batch_size, + sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, + y_ptr, batch_size, num_tokens_per_core, 
input_hidden_token, lora_rank, scale_f); return 0; }); @@ -352,6 +358,8 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic void* weight_ptr = weight.data_ptr(); void* lora_indices_ptr = lora_indices.data_ptr(); void* seq_len_ptr = seq_len.data_ptr(); + int lora_indices_size = lora_indices.size(0); + int seq_len_size = seq_len.size(0); void* y_ptr = y.data_ptr(); void* y_out_ptr = y_out.data_ptr(); int batch_size = x.size(0); @@ -360,7 +368,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); at_npu::native::OpCommand cmd; cmd.Name("sgmv_expand"); - cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, y_out_ptr, + cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int { auto dtype = get_dtype_from_torch(scalar_type); int device_id = 0; @@ -368,7 +376,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS); int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num; TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0"); - sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, y_out_ptr, + sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim); return 0; }); diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index deb8fe7..8d206bf 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ 
b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -18,4 +18,5 @@ import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa import vllm_ascend.patch.worker.patch_common.patch_linear # noqa import vllm_ascend.patch.worker.patch_common.patch_logits # noqa +import vllm_ascend.patch.worker.patch_common.patch_lora_embedding # noqa import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py b/vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py new file mode 100644 index 0000000..02d5804 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py @@ -0,0 +1,29 @@ +from typing import Optional + +import vllm +from torch import nn +from transformers import PretrainedConfig +from vllm.config import LoRAConfig +from vllm.lora.layers import VocabParallelEmbeddingWithLoRA +from vllm.lora.utils import _all_lora_classes + +from vllm_ascend.ops.vocab_parallel_embedding import \ + AscendVocabParallelEmbedding + + +class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA): + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is AscendVocabParallelEmbedding + + +# Patch for lora register_model issue after overriding VocabParallelEmbedding class (#2515) +_all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA) +vllm.lora.utils._all_lora_classes = _all_lora_classes