diff --git a/.github/workflows/release_whl.yml b/.github/workflows/release_whl.yml
index d23d427e..741ef3d3 100644
--- a/.github/workflows/release_whl.yml
+++ b/.github/workflows/release_whl.yml
@@ -98,6 +98,7 @@ jobs:
             --exclude libc_sec.so \
             --exclude "libascend*.so" \
             --exclude "libtorch*.so" \
+            --exclude "libopapi.so" \
             --exclude "liberror_manager.so"
           done
           rm -f dist/*.whl
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 272bdb13..3e810fa8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -63,7 +63,8 @@ ascendc_library(vllm_ascend_kernels SHARED
 message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
 
 file(GLOB VLLM_ASCEND_SRC
-${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)
+${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp)
 
 include_directories(
     ${pybind11_INCLUDE_DIRS}
@@ -88,6 +89,7 @@ pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC})
 target_link_directories(
     vllm_ascend_C
     PRIVATE
+    ${TORCH_LIBRARY_DIRS}
     ${TORCH_NPU_PATH}/lib/
     ${ASCEND_HOME_PATH}/lib64
 )
@@ -96,7 +98,7 @@ target_link_libraries(
     vllm_ascend_C
     PUBLIC
     ${TORCH_LIBRARIES}
-    libtorch_npu.so
+    torch_npu
     vllm_ascend_kernels
     ascendcl
     tiling_api
@@ -104,6 +106,7 @@ target_link_libraries(
     platform
     ascendalog
     dl
+    opapi
 )
 
 target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")
diff --git a/csrc/aclnn_torch_adapter/NPUBridge.cpp b/csrc/aclnn_torch_adapter/NPUBridge.cpp
new file mode 100644
index 00000000..dc335cb4
--- /dev/null
+++ b/csrc/aclnn_torch_adapter/NPUBridge.cpp
@@ -0,0 +1,30 @@
+// Copyright (c) 2020, Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "NPUBridge.h"
+
+namespace vllm_ascend
+{
+    NPUStorageImpl *NPUBridge::GetNpuStorageImpl(c10::StorageImpl *storageImpl)
+    {
+        return static_cast<NPUStorageImpl *>(storageImpl);
+    }
+
+    NPUStorageImpl *NPUBridge::GetNpuStorageImpl(c10::Storage &&storage)
+    {
+        return static_cast<NPUStorageImpl *>(storage.unsafeGetStorageImpl());
+    }
+
+    NPUStorageImpl *NPUBridge::GetNpuStorageImpl(const at::Tensor &tensor)
+    {
+        return static_cast<NPUStorageImpl *>(tensor.storage().unsafeGetStorageImpl());
+    }
+
+    NPUStorageDesc &NPUBridge::GetNpuStorageImplDesc(const at::Tensor &tensor)
+    {
+        return static_cast<NPUStorageImpl *>(tensor.storage().unsafeGetStorageImpl())->npu_desc_;
+    }
+}
diff --git a/csrc/aclnn_torch_adapter/NPUBridge.h b/csrc/aclnn_torch_adapter/NPUBridge.h
new file mode 100644
index 00000000..e93a1048
--- /dev/null
+++ b/csrc/aclnn_torch_adapter/NPUBridge.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2020, Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <ATen/ATen.h>
+#include "NPUStorageImpl.h"
+
+namespace vllm_ascend
+{
+
+    class NPUBridge
+    {
+    public:
+        // at::Tensor to NPUStorageImpl
+        static NPUStorageImpl *GetNpuStorageImpl(const at::Tensor &tensor);
+
+        // c10::StorageImpl to NPUStorageImpl
+        static NPUStorageImpl *GetNpuStorageImpl(c10::StorageImpl *storageImpl);
+
+        // c10::Storage to NPUStorageImpl
+        static NPUStorageImpl *GetNpuStorageImpl(c10::Storage &&storage);
+
+        // tensor to NPUStorageDesc
+        static NPUStorageDesc &GetNpuStorageImplDesc(const at::Tensor &tensor);
+    };
+}
diff --git a/csrc/aclnn_torch_adapter/NPUStorageImpl.cpp b/csrc/aclnn_torch_adapter/NPUStorageImpl.cpp
new file mode 100644
index 00000000..9dfe0c0c
--- /dev/null
+++ b/csrc/aclnn_torch_adapter/NPUStorageImpl.cpp
@@ -0,0 +1,52 @@
+// Copyright (c) 2020, Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "NPUStorageImpl.h"
+
+namespace vllm_ascend
+{
+
+    NPUStorageImpl::NPUStorageImpl(
+        use_byte_size_t use_byte_size,
+        size_t size_bytes,
+        at::DataPtr data_ptr,
+        at::Allocator *allocator,
+        bool resizable) : c10::StorageImpl(use_byte_size,
+                                           size_bytes,
+                                           at::DataPtr(std::move(data_ptr)),
+                                           allocator,
+                                           resizable)
+    {
+    }
+
+    void NPUStorageImpl::release_resources()
+    {
+        StorageImpl::release_resources();
+    }
+
+    c10::intrusive_ptr<c10::StorageImpl> make_npu_storage_impl(
+        c10::StorageImpl::use_byte_size_t,
+        c10::SymInt size_bytes,
+        c10::DataPtr data_ptr,
+        c10::Allocator *allocator,
+        bool resizable)
+    {
+        if (data_ptr == nullptr)
+        {
+            data_ptr = allocator->allocate(size_bytes.as_int_unchecked());
+        }
+        // Create the NPUStorageImpl object.
+        c10::intrusive_ptr<NPUStorageImpl> npu_storage_impl = c10::make_intrusive<NPUStorageImpl>(
+            c10::StorageImpl::use_byte_size_t(),
+            size_bytes.as_int_unchecked(),
+            std::move(data_ptr),
+            allocator,
+            resizable);
+        // The NPUStorageDesc is not filled in here; it is set up during subsequent processing.
+        return npu_storage_impl;
+    }
+
+}
diff --git a/csrc/aclnn_torch_adapter/NPUStorageImpl.h b/csrc/aclnn_torch_adapter/NPUStorageImpl.h
new file mode 100644
index 00000000..fcf293b1
--- /dev/null
+++ b/csrc/aclnn_torch_adapter/NPUStorageImpl.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2020, Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <ATen/ATen.h>
+#include <c10/core/Allocator.h>
+#include <c10/core/StorageImpl.h>
+#include <c10/core/SymInt.h>
+#include <c10/util/SmallVector.h>
+#include <c10/util/typeid.h>
+
+#include "acl/acl_rt.h"
+#include "acl/acl_base.h"
+
+namespace vllm_ascend
+{
+
+    struct NPUStorageDesc
+    {
+    public:
+        struct use_byte_size_t
+        {
+        };
+
+        c10::SmallVector<int64_t, 5> base_sizes_;
+        c10::SmallVector<int64_t, 5> base_strides_;
+        c10::SmallVector<int64_t, 5> storage_sizes_;
+        int64_t base_offset_ = 0;
+        use_byte_size_t base_dtype_ = {};
+        aclFormat origin_format_ = ACL_FORMAT_UNDEFINED;
+        aclFormat npu_format_ = ACL_FORMAT_ND;
+        // used to make the CANN GE tensor from the storageImpl
+        caffe2::TypeMeta data_type_ = caffe2::TypeMeta::Make<uint8_t>();
+    };
+
+    struct NPUStorageImpl : public c10::StorageImpl
+    {
+        explicit NPUStorageImpl(
+            use_byte_size_t use_byte_size,
+            size_t size_bytes,
+            at::DataPtr data_ptr,
+            at::Allocator *allocator,
+            bool resizable);
+        ~NPUStorageImpl() override = default;
+
+        void release_resources() override;
+
+        NPUStorageDesc npu_desc_;
+
+        NPUStorageDesc get_npu_desc() const
+        {
+            return npu_desc_;
+        }
+    };
+
+    c10::intrusive_ptr<c10::StorageImpl> make_npu_storage_impl(
+        c10::StorageImpl::use_byte_size_t,
+        c10::SymInt size_bytes,
+        c10::DataPtr data_ptr,
+        c10::Allocator *allocator,
+        bool resizable);
+
+}
diff --git a/csrc/aclnn_torch_adapter/op_api_common.h b/csrc/aclnn_torch_adapter/op_api_common.h
new file mode 100644
index 00000000..e4c8a517
--- /dev/null
+++ b/csrc/aclnn_torch_adapter/op_api_common.h
@@ -0,0 +1,591 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#ifndef OP_API_COMMON_ADAPTER +#define OP_API_COMMON_ADAPTER + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/interface/EnvVariables.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "NPUBridge.h" +#include "NPUStorageImpl.h" + +#define NPU_NAME_SPACE at_npu::native +using namespace at; + +typedef struct aclOpExecutor aclOpExecutor; +typedef struct aclTensor aclTensor; +typedef struct aclScalar aclScalar; +typedef struct aclIntArray aclIntArray; +typedef struct aclFloatArray aclFloatArray; +typedef struct aclBoolArray aclBoolArray; +typedef struct aclTensorList aclTensorList; + +typedef aclTensor *(*_aclCreateTensor)( + const int64_t *view_dims, uint64_t view_dims_num, aclDataType data_type, + const int64_t *stride, int64_t offset, aclFormat format, + const int64_t *storage_dims, uint64_t storage_dims_num, void *tensor_data); +typedef aclScalar *(*_aclCreateScalar)(void *value, aclDataType data_type); +typedef aclIntArray *(*_aclCreateIntArray)(const int64_t *value, uint64_t size); +typedef aclFloatArray *(*_aclCreateFloatArray)(const float *value, + uint64_t size); +typedef aclBoolArray *(*_aclCreateBoolArray)(const bool *value, uint64_t size); +typedef aclTensorList *(*_aclCreateTensorList)(const aclTensor *const *value, + uint64_t size); + +typedef int (*_aclDestroyTensor)(const aclTensor *tensor); +typedef int (*_aclDestroyScalar)(const aclScalar *scalar); +typedef int (*_aclDestroyIntArray)(const aclIntArray *array); +typedef int (*_aclDestroyFloatArray)(const aclFloatArray *array); +typedef int (*_aclDestroyBoolArray)(const aclBoolArray *array); +typedef int (*_aclDestroyTensorList)(const aclTensorList *array); + +constexpr int kHashBufSize = 8192; +constexpr int kHashBufMaxSize = kHashBufSize + 1024; +extern thread_local char g_hashBuf[kHashBufSize]; +extern thread_local int g_hashOffset; + +#ifdef MMCV_WITH_XLA +#define DEVICE_TYPE at_npu::key::NativeDeviceType +#else +#define DEVICE_TYPE c10::DeviceType::PrivateUse1 +#endif + +#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \ + _(at::ScalarType::Byte, ACL_UINT8) \ + _(at::ScalarType::Char, ACL_INT8) \ + _(at::ScalarType::Short, ACL_INT16) \ + _(at::ScalarType::Int, ACL_INT32) \ + _(at::ScalarType::Long, ACL_INT64) \ + _(at::ScalarType::Half, ACL_FLOAT16) \ + _(at::ScalarType::Float, ACL_FLOAT) \ + _(at::ScalarType::Double, ACL_DOUBLE) \ + _(at::ScalarType::ComplexHalf, ACL_DT_UNDEFINED) \ + _(at::ScalarType::ComplexFloat, ACL_COMPLEX64) \ + _(at::ScalarType::ComplexDouble, ACL_COMPLEX128) \ + _(at::ScalarType::Bool, ACL_BOOL) \ + _(at::ScalarType::QInt8, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QInt32, ACL_DT_UNDEFINED) \ + _(at::ScalarType::BFloat16, ACL_BF16) \ + _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED) \ + _(at::ScalarType::Undefined, ACL_DT_UNDEFINED) \ + _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED) + +constexpr aclDataType kATenScalarTypeToAclDataTypeTable + [static_cast(at::ScalarType::NumOptions) + 1] = { +#define DEFINE_ENUM(_1, n) n, + AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM) +#undef DEFINE_ENUM +}; + +#define GET_OP_API_FUNC(apiName) \ + 
reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName)) + +#define MEMCPY_TO_BUF(data_expression, size_expression) \ + if (g_hashOffset + (size_expression) > kHashBufSize) { \ + g_hashOffset = kHashBufMaxSize; \ + return; \ + } \ + memcpy(g_hashBuf + g_hashOffset, data_expression, size_expression); \ + g_hashOffset += size_expression; + +bool IsOpInputBaseFormat(const at::Tensor &tensor) +{ + if (!tensor.is_privateuseone()) { + return true; + } + const auto format = vllm_ascend::NPUBridge::GetNpuStorageImplDesc(tensor).npu_format_; + return (format == ACL_FORMAT_ND) || (format == ACL_FORMAT_NCHW) || (format == ACL_FORMAT_NHWC) || + (format == ACL_FORMAT_NCDHW); +} + +inline const char *GetOpApiLibName(void) { return "libopapi.so"; } + +inline const char *GetCustOpApiLibName(void) { return "libcust_opapi.so"; } + +inline void *GetOpApiFuncAddrInLib(void *handler, const char *libName, + const char *apiName) { + auto funcAddr = dlsym(handler, apiName); + return funcAddr; +} + +inline void *GetOpApiLibHandler(const char *libName) { + auto handler = dlopen(libName, RTLD_LAZY); + return handler; +} + +inline void *GetOpApiFuncAddr(const char *apiName) { + static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName()); + if (custOpApiHandler != nullptr) { + auto funcAddr = + GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName); + if (funcAddr != nullptr) { + return funcAddr; + } + } + + static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName()); + if (opApiHandler == nullptr) { + return nullptr; + } + return GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName); +} + +inline c10::Scalar ConvertTensorToScalar(const at::Tensor &tensor) { + c10::Scalar expScalar; + const at::Tensor *aclInput = &tensor; + if (aclInput->scalar_type() == at::ScalarType::Double) { + double value = *(double *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Long) { + int64_t value = *(int64_t *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Float) { + float value = *(float *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Int) { + int value = *(int *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Half) { + c10::Half value = *(c10::Half *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Bool) { + int8_t value = *(int8_t *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::ComplexDouble) { + c10::complex value = *(c10::complex *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::ComplexFloat) { + c10::complex value = *(c10::complex *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::BFloat16) { + c10::BFloat16 value = *(c10::BFloat16 *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } + return expScalar; +} + +inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor) { + at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory(); + int deviceIndex = 0; + return cpuPinMemTensor.to(c10::Device(DEVICE_TYPE, 
deviceIndex), + cpuPinMemTensor.scalar_type(), true, true); +} + +inline at::Tensor CopyScalarToDevice(const c10::Scalar &cpu_scalar, + at::ScalarType scalar_data_type) { + return CopyTensorHostToDevice( + scalar_to_tensor(cpu_scalar).to(scalar_data_type)); +} + +inline aclTensor *ConvertType(const at::Tensor &at_tensor) { + static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor); + if (aclCreateTensor == nullptr) { + return nullptr; + } + + if (!at_tensor.defined()) { + return nullptr; + } + at::ScalarType scalar_data_type = at_tensor.scalar_type(); + aclDataType acl_data_type = + kATenScalarTypeToAclDataTypeTable[static_cast(scalar_data_type)]; + TORCH_CHECK( + acl_data_type != ACL_DT_UNDEFINED, + std::string(c10::toString(scalar_data_type)) + " has not been supported") + c10::SmallVector storageDims; + // if acl_data_type is ACL_STRING, storageDims is empty. + auto itemsize = at_tensor.itemsize(); + TORCH_CHECK(itemsize != 0, "When ConvertType, tensor item size cannot be zero."); + + const auto dimNum = at_tensor.sizes().size(); + aclFormat format = ACL_FORMAT_ND; + if (!IsOpInputBaseFormat(at_tensor)) { + format = vllm_ascend::NPUBridge::GetNpuStorageImpl(at_tensor)->npu_desc_.npu_format_; + if (acl_data_type != ACL_STRING) { + storageDims = vllm_ascend::NPUBridge::GetNpuStorageImpl(at_tensor)->npu_desc_.storage_sizes_; + } + } else { + switch (dimNum) { + case 3: + format = ACL_FORMAT_NCL; + break; + case 4: + format = ACL_FORMAT_NCHW; + break; + case 5: + format = ACL_FORMAT_NCDHW; + break; + default: + format = ACL_FORMAT_ND; + } + if (acl_data_type != ACL_STRING) { + storageDims.push_back(at_tensor.storage().nbytes() / itemsize); + } + } + + if (at_tensor.unsafeGetTensorImpl()->is_wrapped_number()) { + c10::Scalar expScalar = ConvertTensorToScalar(at_tensor); + at::Tensor aclInput = CopyScalarToDevice(expScalar, scalar_data_type); + return aclCreateTensor(aclInput.sizes().data(), aclInput.sizes().size(), + acl_data_type, aclInput.strides().data(), + aclInput.storage_offset(), format, + storageDims.data(), storageDims.size(), + const_cast(aclInput.storage().data())); + } + + auto acl_tensor = aclCreateTensor( + at_tensor.sizes().data(), at_tensor.sizes().size(), acl_data_type, + at_tensor.strides().data(), at_tensor.storage_offset(), format, + storageDims.data(), storageDims.size(), + const_cast(at_tensor.storage().data())); + return acl_tensor; +} + +inline aclScalar *ConvertType(const at::Scalar &at_scalar) { + static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar); + if (aclCreateScalar == nullptr) { + return nullptr; + } + + at::ScalarType scalar_data_type = at_scalar.type(); + aclDataType acl_data_type = + kATenScalarTypeToAclDataTypeTable[static_cast(scalar_data_type)]; + TORCH_CHECK( + acl_data_type != ACL_DT_UNDEFINED, + std::string(c10::toString(scalar_data_type)) + " has not been supported") + aclScalar *acl_scalar = nullptr; + switch (scalar_data_type) { + case at::ScalarType::Double: { + double value = at_scalar.toDouble(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + case at::ScalarType::Long: { + int64_t value = at_scalar.toLong(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + case at::ScalarType::Bool: { + bool value = at_scalar.toBool(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + case at::ScalarType::ComplexDouble: { + auto value = at_scalar.toComplexDouble(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + default: + acl_scalar = nullptr; + 
break; + } + return acl_scalar; +} + +inline aclIntArray *ConvertType(const at::IntArrayRef &at_array) { + static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray); + if (aclCreateIntArray == nullptr) { + return nullptr; + } + auto array = aclCreateIntArray(at_array.data(), at_array.size()); + return array; +} + +template +inline aclBoolArray *ConvertType(const std::array &value) { + static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray); + if (aclCreateBoolArray == nullptr) { + return nullptr; + } + + auto array = aclCreateBoolArray(value.data(), value.size()); + return array; +} + +inline aclBoolArray *ConvertType(const at::ArrayRef &value) { + static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray); + if (aclCreateBoolArray == nullptr) { + return nullptr; + } + + auto array = aclCreateBoolArray(value.data(), value.size()); + return array; +} + +inline aclTensorList *ConvertType(const at::TensorList &at_tensor_list) { + static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList); + if (aclCreateTensorList == nullptr) { + return nullptr; + } + + std::vector tensor_list(at_tensor_list.size()); + for (size_t i = 0; i < at_tensor_list.size(); i++) { + tensor_list[i] = ConvertType(at_tensor_list[i]); + } + auto acl_tensor_list = + aclCreateTensorList(tensor_list.data(), tensor_list.size()); + return acl_tensor_list; +} + +inline aclTensor *ConvertType(const c10::optional &opt_tensor) { + if (opt_tensor.has_value() && opt_tensor.value().defined()) { + return ConvertType(opt_tensor.value()); + } + return nullptr; +} + +inline aclIntArray *ConvertType( + const c10::optional &opt_array) { + if (opt_array.has_value()) { + return ConvertType(opt_array.value()); + } + return nullptr; +} + +inline aclScalar *ConvertType(const c10::optional &opt_scalar) { + if (opt_scalar.has_value()) { + return ConvertType(opt_scalar.value()); + } + return nullptr; +} + +inline aclDataType ConvertType(const at::ScalarType scalarType) { + return kATenScalarTypeToAclDataTypeTable[static_cast(scalarType)]; +} + +template +T ConvertType(T value) { + return value; +} + +template +auto ConvertToOpApiFunc(const Tuple ¶ms, void *opApiAddr, + std::index_sequence) { + typedef int (*OpApiFunc)( + typename std::decay(params))>::type...); + auto func = reinterpret_cast(opApiAddr); + return func; +} + +template +auto ConvertToOpApiFunc(const Tuple ¶ms, void *opApiAddr) { + static constexpr auto size = std::tuple_size::value; + return ConvertToOpApiFunc(params, opApiAddr, + std::make_index_sequence{}); +} + +inline void Release(aclTensor *p) { + static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor); + if (aclDestroyTensor == nullptr) { + return; + } + aclDestroyTensor(p); +} + +inline void Release(aclScalar *p) { + static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar); + if (aclDestroyScalar == nullptr) { + return; + } + aclDestroyScalar(p); +} + +inline void Release(aclIntArray *p) { + static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray); + if (aclDestroyIntArray == nullptr) { + return; + } + + aclDestroyIntArray(p); +} + +inline void Release(aclBoolArray *p) { + static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray); + if (aclDestroyBoolArray == nullptr) { + return; + } + + aclDestroyBoolArray(p); +} + +inline void Release(aclTensorList *p) { + static const auto aclDestroyTensorList = + GET_OP_API_FUNC(aclDestroyTensorList); + if (aclDestroyTensorList == nullptr) { + return; + } + + 
aclDestroyTensorList(p); +} + +template +void Release(T value) { + (void)value; +} + +template +void CallRelease(Tuple t, std::index_sequence) { + (void)std::initializer_list{(Release(std::get(t)), 0)...}; +} + +template +void ReleaseConvertTypes(Tuple &t) { + static constexpr auto size = std::tuple_size::value; + CallRelease(t, std::make_index_sequence{}); +} + +template +constexpr auto ConvertTypes(Ts &... args) { + return std::make_tuple(ConvertType(args)...); +} + +template +auto call(Function f, Tuple t, std::index_sequence) { + return f(std::get(t)...); +} + +template +auto call(Function f, Tuple t) { + static constexpr auto size = std::tuple_size::value; + return call(f, t, std::make_index_sequence{}); +} + +template +void AddParamToBuf(const std::array &value) { + MEMCPY_TO_BUF(value.data(), value.size() * sizeof(bool)); +} + +template +void AddParamToBuf(const T &value) { + MEMCPY_TO_BUF(&value, sizeof(T)); +} + +void AddParamToBuf(const at::Tensor &); +void AddParamToBuf(const at::Scalar &); +void AddParamToBuf(const at::IntArrayRef &); +void AddParamToBuf(const at::ArrayRef &); +void AddParamToBuf(const at::TensorList &); +void AddParamToBuf(const c10::optional &); +void AddParamToBuf(const c10::optional &); +void AddParamToBuf(const c10::optional &); +void AddParamToBuf(const at::ScalarType); +void AddParamToBuf(const string &); +void AddParamToBuf(); + +template +void AddParamToBuf(const T &arg, Args &... args) { + AddParamToBuf(arg); + AddParamToBuf(args...); +} + +uint64_t CalcHashId(); +typedef int (*InitHugeMemThreadLocal)(void *, bool); +typedef void (*UnInitHugeMemThreadLocal)(void *, bool); +typedef void (*ReleaseHugeMem)(void *, bool); + +#define EXEC_NPU_CMD(aclnn_api, ...) \ + do { \ + static const auto getWorkspaceSizeFuncAddr = \ + GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \ + static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \ + static const auto initMemAddr = \ + GetOpApiFuncAddr("InitHugeMemThreadLocal"); \ + static const auto unInitMemAddr = \ + GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \ + static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \ + TORCH_CHECK( \ + getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, \ + #aclnn_api, " or ", #aclnn_api "GetWorkspaceSize", " not in ", \ + GetOpApiLibName(), ", or ", GetOpApiLibName(), "not found."); \ + auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \ + uint64_t workspace_size = 0; \ + uint64_t *workspace_size_addr = &workspace_size; \ + aclOpExecutor *executor = nullptr; \ + aclOpExecutor **executor_addr = &executor; \ + InitHugeMemThreadLocal initMemFunc = \ + reinterpret_cast(initMemAddr); \ + UnInitHugeMemThreadLocal unInitMemFunc = \ + reinterpret_cast(unInitMemAddr); \ + if (initMemFunc) { \ + initMemFunc(nullptr, false); \ + } \ + auto converted_params = \ + ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \ + static auto getWorkspaceSizeFunc = \ + ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \ + auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \ + TORCH_CHECK(workspace_status == 0, \ + "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \ + void *workspace_addr = nullptr; \ + if (workspace_size != 0) { \ + at::TensorOptions options = \ + at::TensorOptions(torch_npu::utils::get_npu_device_type()); \ + auto workspace_tensor = \ + at::empty({workspace_size}, options.dtype(kByte)); \ + workspace_addr = const_cast(workspace_tensor.storage().data()); \ + } \ + auto acl_call = 
[converted_params, workspace_addr, workspace_size, \ + acl_stream, executor]() -> int { \ + typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *, \ + const aclrtStream); \ + OpApiFunc opApiFunc = reinterpret_cast(opApiFuncAddr); \ + auto api_ret = \ + opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \ + TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", \ + aclGetRecentErrMsg()); \ + ReleaseConvertTypes(converted_params); \ + ReleaseHugeMem releaseMemFunc = \ + reinterpret_cast(releaseMemAddr); \ + if (releaseMemFunc) { \ + releaseMemFunc(nullptr, false); \ + } \ + return api_ret; \ + }; \ + at_npu::native::OpCommand cmd; \ + cmd.Name(#aclnn_api); \ + cmd.SetCustomHandler(acl_call); \ + cmd.Run(); \ + if (unInitMemFunc) { \ + unInitMemFunc(nullptr, false); \ + } \ + } while (false) + +#endif diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 90e7f03a..d2a1f90e 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -27,12 +27,14 @@ #include "ops.h" #include "utils.h" #include "mla_preprocess/op_host/mla_preprocess.h" +#include "aclnn_torch_adapter/op_api_common.h" #include #include #include namespace vllm_ascend { +const int64_t INT4_NUMS_IN_INT32 = 8; void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping, aclrtStream stream) { torch::Device src_device = src.device(); @@ -520,6 +522,36 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic cmd.Run(); return y_out; } + +std::tuple grouped_matmul_swiglu_quant( + const at::Tensor &x, const at::Tensor &weight, const at::Tensor &weight_scale, const at::Tensor &x_scale, + const at::Tensor &group_list, const c10::optional &bias, const c10::optional &offset) +{ + int m = x.sizes()[0]; + int n = weight.sizes()[2]; + bool is_a8w4 = x.dtype() == at::kChar && weight.dtype() == at::kInt; + if (is_a8w4) { + n *= INT4_NUMS_IN_INT32; + } + + at::Tensor output = at::empty({m, n/2}, x.options().dtype(c10::ScalarType::Char)); + at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float)); + at::Tensor output_offset = at::empty({}, x.options().dtype(c10::ScalarType::Float)); + + EXEC_NPU_CMD( + aclnnGroupedMatmulSwigluQuantWeightNZ, + x, + weight, + bias, + offset, + weight_scale, + x_scale, + group_list, + output, + output_scale, + output_offset); + return std::tuple(output, output_scale, output_offset); +} } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) @@ -576,4 +608,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ops.def("swap_blocks(Tensor! x, Tensor! y, Tensor z) -> ()"); ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks); + + ops.def( + "grouped_matmul_swiglu_quant(Tensor x, Tensor weight, Tensor weight_scale, Tensor x_scale," + " Tensor group_list, *, Tensor? bias=None," + " Tensor? 
offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)"); + ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index dbb056be..e3b35b10 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -35,7 +35,7 @@ namespace vllm_ascend { namespace meta { - +const int64_t INT4_NUMS_IN_INT32 = 8; std::tuple rotary_embedding_meta( at::Tensor &positions, at::Tensor &query, @@ -114,6 +114,22 @@ std::tuple mla_preproces return {q_out0, kv_cache_out0, q_out1, kv_cache_out1}; } +std::tuple grouped_matmul_swiglu_quant( + const at::Tensor &x, const at::Tensor &weight, const at::Tensor &weight_scale, const at::Tensor &x_scale, + const at::Tensor &group_list, const c10::optional &bias, const c10::optional &offset) +{ + int m = x.sizes()[0]; + int n = weight.sizes()[2]; + bool is_a8w4 = x.dtype() == at::kChar && weight.dtype() == at::kInt; + if (is_a8w4) { + n *= INT4_NUMS_IN_INT32; + } + at::Tensor output = at::empty({m, n/2}, x.options().dtype(c10::ScalarType::Char)); + at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float)); + at::Tensor output_offset = at::empty({}, x.options().dtype(c10::ScalarType::Float)); + return {output, output_scale, output_offset}; +} + } // namespace meta } // namespace vllm_ascend @@ -132,5 +148,7 @@ namespace { ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta); // MLA preprocess ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess); + // grouped_matmul_swiglu_quant meta implementation + ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant); } } diff --git a/tests/e2e/nightly/ops/test_grouped_matmul_swiglu_quant.py b/tests/e2e/nightly/ops/test_grouped_matmul_swiglu_quant.py new file mode 100644 index 00000000..28e724bb --- /dev/null +++ b/tests/e2e/nightly/ops/test_grouped_matmul_swiglu_quant.py @@ -0,0 +1,175 @@ +import gc + +import numpy as np +import torch +import torch_npu + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + + +def x_int8_to_x_int4(x: torch.Tensor): + m, k = x.shape + x_high_4bit = torch.floor(x.to(torch.float16) // 16).to(torch.int8) + x_low_4bit = ( + torch.bitwise_and(x.view(torch.int16), 0x0f0f).view(torch.int8) - 8) + x_int4 = torch.empty((2 * m, k), dtype=torch.int8) + x_int4[::2, :] = x_high_4bit + x_int4[1::2, :] = x_low_4bit + return x_int4 + + +def custom_mm(x: torch.Tensor, weight: torch.Tensor, + weight_scale: torch.Tensor, m: int): + """ + Performing Quantized GMM (General Matrix Multiplication) Operation + Parameters: + x (torch.Tensor): Input tensor with shape (m, k). + weight (torch.Tensor): Weight tensor with shape (k, n). + weight_scale (torch.Tensor): Scaling factor for each channel. + - In perGroup scenario: Shape is (k_group_num, n). Note: When k_group_num == 1, it is a perChannel scenario. + - In perChannel scenario: Shape is (n). + m (int): Number of tokens (number of rows in x). + Returns: + mm_out(fp16): Result of MatMul + perGroup or perChannel dequantization. 
+ """ + # Perform matrix multiplication with int32 precision + k, n = weight.shape + mm_out = torch.zeros((m, n), dtype=torch.float16) + # perGroup scenario + if len(weight_scale.shape) == 2 and weight_scale.shape[0] != 1: + k_group = weight_scale.shape[0] + per_group_ele = k // k_group + x_grouped = x.view(-1, k_group, per_group_ele).transpose(0, 1) + weight_grouped = weight.view(k_group, per_group_ele, n) + + c_temp = torch.bmm(x_grouped.to(torch.int32), + weight_grouped.to(torch.int32)).to(torch.float16) + for k_idx in range(k_group): + mm_out += (c_temp[k_idx] * + weight_scale[k_idx].view(1, -1).to(torch.float16)).to( + torch.float16) + # perChannel scenario + elif len(weight_scale.shape) == 1 or (len(weight_scale.shape) == 2 + and weight_scale.shape[0] == 1): + c_temp = torch.matmul(x.to(torch.int32), + weight.to(torch.int32)).to(torch.float32) + mm_out = c_temp * weight_scale.view(1, -1).to(torch.float16) + return mm_out.to(torch.float32) + + +def gmm_swiglu_quant_golden_a8_w4(x: torch.Tensor, weight: torch.Tensor, + weight_scale: torch.Tensor, + per_token_scale: torch.Tensor, + bias: torch.Tensor, + group_list: torch.Tensor): + """ + Process the input data by group and call the GMM_Swiglu_quant function for quantization computation. + Parameters: + x (torch.Tensor): Input tensor with shape (M, K), type INT8. + weight (torch.Tensor): List of weight tensors, each with shape (E, K, N), data type INT8 but data range INT4, representing INT4 values. + weight_scale (torch.Tensor): Scaling factor for each channel. + - In perGroup scenario: shape (E, k_group_num, N). + - In perChannel scenario: shape (E, N). + per_token_scale (torch.Tensor): Scaling factor for each token, shape (M, ). + bias: torch.Tensor, + group_list (list): List defining the number of tokens in each group. + Returns: + quant_output (torch.Tensor): Quantized output tensor with shape (M, N // 2). + quant_scale_output (torch.Tensor): Quantization scaling factor, shape (M, ). 
+ """ + M, N = x.shape[0], weight.shape[2] + quant_output = torch.zeros(M, N // 2).to(torch.int8) + quant_scale_output = torch.zeros(M).to(torch.float32) + # Preprocessing X_INT8 -> X_INT4 + x_int4 = x_int8_to_x_int4(x) + start_idx = 0 + # Number of tokens in the previous group + pre_v = 0 + group_list = group_list.tolist() + # Traverse group_list and process data by group + for i, v in enumerate(group_list): + curr_v = v + # Calculate the number of tokens in the current group " * 2 " because 1 row of Int8--> 2 rows of Int4 + temp_v = int((curr_v - pre_v) * 2) + # Update the number of tokens in the previous group + pre_v = curr_v + if (temp_v > 0): + mm_out = custom_mm(x_int4[int(start_idx):int(start_idx + temp_v)], + weight[i], weight_scale[i], temp_v) + mm_num_concat = ((mm_out[::2] * 16 + mm_out[1::2]) + + bias[i].view(1, -1)) + per_token_quant = mm_num_concat * per_token_scale[start_idx // 2:( + start_idx + temp_v) // 2].view(-1, 1) + swiglu, gate = per_token_quant.chunk(2, dim=-1) + temp = swiglu * torch.sigmoid(swiglu) + temp = temp * gate + max_value = torch.max(torch.abs(temp), dim=-1).values + quant_scale_output_temp = 127 / max_value + quant_output[start_idx // 2:(start_idx + temp_v) // + 2] = torch.round(temp * + quant_scale_output_temp.reshape( + temp_v // 2, 1)).to(torch.int8) + quant_scale_output[start_idx // 2:(start_idx + temp_v) // + 2] = 1 / quant_scale_output_temp + start_idx += temp_v + return quant_output, quant_scale_output + + +def generate_non_decreasing_sequence(length, upper_limit): + # Generate random increasing sequence + random_increments = torch.randint(0, 128, (length, )) + sequence = torch.cumsum(random_increments, dim=0) + + # Make sure the last value is less than the upper limit + if sequence[-1] >= upper_limit: + scale_factor = upper_limit / sequence[-1] + sequence = (sequence * scale_factor).to(torch.int64) + return sequence + + +@torch.inference_mode() +def test_grouped_matmul_swiglu_quant_kernel(): + E = 16 + M = 512 + K = 7168 + N = 4096 + torch.npu.config.allow_internal_format = True + x = torch.randint(-5, 5, (M, K), dtype=torch.int8).npu() + weight_ori = torch.randint(-5, 5, (E, K, N), dtype=torch.int8) + weight_nz = torch_npu.npu_format_cast(weight_ori.npu().to(torch.float32), + 29) + pack_weight = torch_npu.npu_quantize(weight_nz, + torch.tensor([1.], device='npu'), + None, torch.quint4x2, -1, False) + + weight_scale = torch.randn(E, 1, N) + scale_np = weight_scale.cpu().numpy() + scale_np.dtype = np.uint32 + scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu() + pertoken_scale = torch.randn(M).to(torch.float32).npu() + group_list = generate_non_decreasing_sequence(E, M).npu() + bias = torch.zeros((E, N), dtype=torch.float32, + device="npu").uniform_(-5, 5) + + output_golden, output_scale_golden = gmm_swiglu_quant_golden_a8_w4( + x.cpu(), weight_ori, weight_scale, pertoken_scale.cpu(), bias.cpu(), + group_list.cpu()) + + output, output_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant( + x=x, + weight=pack_weight, + bias=bias, + group_list=group_list, + weight_scale=scale_uint64_tensor, + x_scale=pertoken_scale) + torch.testing.assert_close(output_golden, output.cpu(), atol=1, rtol=0.005) + torch.testing.assert_close(output_scale_golden, + output_scale.cpu(), + atol=1, + rtol=0.005) + + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats()