[feature] Add Custom Op grouped_matmul_swiglu_quant (#4431)
This PR introduces the `EXEC_NPU_CMD` macro, serving as an adapter layer to simplify the invocation of `aclnn` operators on Ascend NPUs. **Key Changes:** * **Adapter Layer:** Added `EXEC_NPU_CMD` macro and related dependencies to standardize `aclnn` calls. * **Operator Support:** Integrated `grouped_matmul_swiglu_quant` as a reference implementation to demonstrate the usage of the new macro. --- - vLLM version: v0.11.2 --------- Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
1
.github/workflows/release_whl.yml
vendored
1
.github/workflows/release_whl.yml
vendored
@@ -98,6 +98,7 @@ jobs:
|
||||
--exclude libc_sec.so \
|
||||
--exclude "libascend*.so" \
|
||||
--exclude "libtorch*.so" \
|
||||
--exclude "libopapi.so" \
|
||||
--exclude "liberror_manager.so"
|
||||
done
|
||||
rm -f dist/*.whl
|
||||
|
||||
@@ -63,7 +63,8 @@ ascendc_library(vllm_ascend_kernels SHARED
|
||||
message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
|
||||
|
||||
file(GLOB VLLM_ASCEND_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp)
|
||||
|
||||
include_directories(
|
||||
${pybind11_INCLUDE_DIRS}
|
||||
@@ -88,6 +89,7 @@ pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC})
|
||||
target_link_directories(
|
||||
vllm_ascend_C
|
||||
PRIVATE
|
||||
${TORCH_LIBRARY_DIRS}
|
||||
${TORCH_NPU_PATH}/lib/
|
||||
${ASCEND_HOME_PATH}/lib64
|
||||
)
|
||||
@@ -96,7 +98,7 @@ target_link_libraries(
|
||||
vllm_ascend_C
|
||||
PUBLIC
|
||||
${TORCH_LIBRARIES}
|
||||
libtorch_npu.so
|
||||
torch_npu
|
||||
vllm_ascend_kernels
|
||||
ascendcl
|
||||
tiling_api
|
||||
@@ -104,6 +106,7 @@ target_link_libraries(
|
||||
platform
|
||||
ascendalog
|
||||
dl
|
||||
opapi
|
||||
)
|
||||
|
||||
target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")
|
||||
|
||||
30
csrc/aclnn_torch_adapter/NPUBridge.cpp
Normal file
30
csrc/aclnn_torch_adapter/NPUBridge.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) 2020, Huawei Technologies Co., Ltd
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "NPUBridge.h"

namespace vllm_ascend
{
// All accessors below perform an UNCHECKED static_cast from the generic c10
// storage type down to NPUStorageImpl. They are only valid for storages that
// were actually created as NPUStorageImpl (e.g. via make_npu_storage_impl);
// passing any other storage is undefined behaviour — no dynamic check is done.

// c10::StorageImpl* -> NPUStorageImpl*
NPUStorageImpl *NPUBridge::GetNpuStorageImpl(c10::StorageImpl *storageImpl)
{
    return static_cast<NPUStorageImpl *>(storageImpl);
}

// c10::Storage -> NPUStorageImpl*.
// NOTE(review): the rvalue-ref parameter is only read, never consumed;
// ownership of the storage is not transferred here.
NPUStorageImpl *NPUBridge::GetNpuStorageImpl(c10::Storage &&storage)
{
    return static_cast<NPUStorageImpl *>(storage.unsafeGetStorageImpl());
}

// at::Tensor -> NPUStorageImpl* of the tensor's backing storage.
NPUStorageImpl *NPUBridge::GetNpuStorageImpl(const at::Tensor &tensor)
{
    return static_cast<NPUStorageImpl *>(tensor.storage().unsafeGetStorageImpl());
}

// at::Tensor -> mutable reference to the NPU format/shape descriptor stored
// alongside the tensor's storage.
NPUStorageDesc &NPUBridge::GetNpuStorageImplDesc(const at::Tensor &tensor)
{
    return static_cast<NPUStorageImpl *>(tensor.storage().unsafeGetStorageImpl())->npu_desc_;
}
}
|
||||
29
csrc/aclnn_torch_adapter/NPUBridge.h
Normal file
29
csrc/aclnn_torch_adapter/NPUBridge.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2020, Huawei Technologies Co., Ltd
|
||||
// All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#pragma once
|
||||
#include <c10/core/StorageImpl.h>
|
||||
#include "NPUStorageImpl.h"
|
||||
|
||||
namespace vllm_ascend
|
||||
{
|
||||
|
||||
class NPUBridge
|
||||
{
|
||||
public:
|
||||
// at::tensor to NPUStorageImpl
|
||||
static NPUStorageImpl *GetNpuStorageImpl(const at::Tensor &tensor);
|
||||
|
||||
// c10::StorageImpl to NPUStorageImpl
|
||||
static NPUStorageImpl *GetNpuStorageImpl(c10::StorageImpl *storageImpl);
|
||||
|
||||
// c10::Storage to NPUStorageImpl
|
||||
static NPUStorageImpl *GetNpuStorageImpl(c10::Storage &&storage);
|
||||
|
||||
// tensor to NPUStorageDesc
|
||||
static NPUStorageDesc &GetNpuStorageImplDesc(const at::Tensor &tensor);
|
||||
};
|
||||
}
|
||||
52
csrc/aclnn_torch_adapter/NPUStorageImpl.cpp
Normal file
52
csrc/aclnn_torch_adapter/NPUStorageImpl.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) 2020, Huawei Technologies Co., Ltd
|
||||
// All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#include "NPUStorageImpl.h"
|
||||
|
||||
namespace vllm_ascend
|
||||
{
|
||||
|
||||
NPUStorageImpl::NPUStorageImpl(
|
||||
use_byte_size_t use_byte_size,
|
||||
size_t size_bytes,
|
||||
at::DataPtr data_ptr,
|
||||
at::Allocator *allocator,
|
||||
bool resizable) : c10::StorageImpl(use_byte_size,
|
||||
size_bytes,
|
||||
at::DataPtr(std::move(data_ptr)),
|
||||
allocator,
|
||||
resizable)
|
||||
{
|
||||
}
|
||||
|
||||
void NPUStorageImpl::release_resources()
|
||||
{
|
||||
StorageImpl::release_resources();
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<c10::StorageImpl> make_npu_storage_impl(
|
||||
c10::StorageImpl::use_byte_size_t,
|
||||
c10::SymInt size_bytes,
|
||||
c10::DataPtr data_ptr,
|
||||
c10::Allocator *allocator,
|
||||
bool resizable)
|
||||
{
|
||||
if (data_ptr == nullptr)
|
||||
{
|
||||
data_ptr = allocator->allocate(size_bytes.as_int_unchecked());
|
||||
}
|
||||
// Correctly create NPUStorageImpl object.
|
||||
c10::intrusive_ptr<c10::StorageImpl> npu_storage_impl = c10::make_intrusive<NPUStorageImpl>(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
size_bytes.as_int_unchecked(),
|
||||
std::move(data_ptr),
|
||||
allocator,
|
||||
resizable);
|
||||
// There is no need to consider the NPUStorageDesc information, it will be carried out in the subsequent processing.
|
||||
return npu_storage_impl;
|
||||
}
|
||||
|
||||
}
|
||||
67
csrc/aclnn_torch_adapter/NPUStorageImpl.h
Normal file
67
csrc/aclnn_torch_adapter/NPUStorageImpl.h
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) 2020, Huawei Technologies Co., Ltd
|
||||
// All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <c10/core/StorageImpl.h>
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/util/order_preserving_flat_hash_map.h>
|
||||
|
||||
#include "acl/acl_rt.h"
|
||||
#include "acl/acl_base.h"
|
||||
|
||||
namespace vllm_ascend
|
||||
{
|
||||
|
||||
struct NPUStorageDesc
|
||||
{
|
||||
public:
|
||||
struct use_byte_size_t
|
||||
{
|
||||
};
|
||||
|
||||
c10::SmallVector<int64_t, 5> base_sizes_;
|
||||
c10::SmallVector<int64_t, 5> base_strides_;
|
||||
c10::SmallVector<int64_t, 5> storage_sizes_;
|
||||
int64_t base_offset_ = 0;
|
||||
use_byte_size_t base_dtype_ = {};
|
||||
aclFormat origin_format_ = ACL_FORMAT_UNDEFINED;
|
||||
aclFormat npu_format_ = ACL_FORMAT_ND;
|
||||
// used to make CANN GE tensor from storagImpl
|
||||
caffe2::TypeMeta data_type_ = caffe2::TypeMeta::Make<uint8_t>();
|
||||
};
|
||||
|
||||
struct NPUStorageImpl : public c10::StorageImpl
|
||||
{
|
||||
explicit NPUStorageImpl(
|
||||
use_byte_size_t use_byte_size,
|
||||
size_t size_bytes,
|
||||
at::DataPtr data_ptr,
|
||||
at::Allocator *allocator,
|
||||
bool resizable);
|
||||
~NPUStorageImpl() override = default;
|
||||
|
||||
void release_resources() override;
|
||||
|
||||
NPUStorageDesc npu_desc_;
|
||||
|
||||
NPUStorageDesc get_npu_desc() const
|
||||
{
|
||||
return npu_desc_;
|
||||
}
|
||||
};
|
||||
|
||||
c10::intrusive_ptr<c10::StorageImpl> make_npu_storage_impl(
|
||||
c10::StorageImpl::use_byte_size_t,
|
||||
c10::SymInt size_bytes,
|
||||
c10::DataPtr data_ptr,
|
||||
c10::Allocator *allocator,
|
||||
bool resizable);
|
||||
|
||||
}
|
||||
591
csrc/aclnn_torch_adapter/op_api_common.h
Normal file
591
csrc/aclnn_torch_adapter/op_api_common.h
Normal file
@@ -0,0 +1,591 @@
|
||||
// Copyright (c) 2023 Huawei Technologies Co., Ltd
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the BSD 3-Clause License (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://opensource.org/licenses/BSD-3-Clause
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OP_API_COMMON_ADAPTER
|
||||
#define OP_API_COMMON_ADAPTER
|
||||
|
||||
#include <torch/types.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <acl/acl_base.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <dlfcn.h>
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include <torch_npu/csrc/framework/utils/CalcuOpUtil.h>
|
||||
#include <torch_npu/csrc/framework/utils/OpAdapter.h>
|
||||
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
|
||||
#include "torch_npu/csrc/core/npu/NPUStream.h"
|
||||
#include "torch_npu/csrc/framework/OpCommand.h"
|
||||
#include "torch_npu/csrc/framework/interface/EnvVariables.h"
|
||||
#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
|
||||
#include "torch_npu/csrc/framework/utils/OpPreparation.h"
|
||||
#include "NPUBridge.h"
|
||||
#include "NPUStorageImpl.h"
|
||||
|
||||
#define NPU_NAME_SPACE at_npu::native
|
||||
using namespace at;
|
||||
|
||||
typedef struct aclOpExecutor aclOpExecutor;
|
||||
typedef struct aclTensor aclTensor;
|
||||
typedef struct aclScalar aclScalar;
|
||||
typedef struct aclIntArray aclIntArray;
|
||||
typedef struct aclFloatArray aclFloatArray;
|
||||
typedef struct aclBoolArray aclBoolArray;
|
||||
typedef struct aclTensorList aclTensorList;
|
||||
|
||||
typedef aclTensor *(*_aclCreateTensor)(
|
||||
const int64_t *view_dims, uint64_t view_dims_num, aclDataType data_type,
|
||||
const int64_t *stride, int64_t offset, aclFormat format,
|
||||
const int64_t *storage_dims, uint64_t storage_dims_num, void *tensor_data);
|
||||
typedef aclScalar *(*_aclCreateScalar)(void *value, aclDataType data_type);
|
||||
typedef aclIntArray *(*_aclCreateIntArray)(const int64_t *value, uint64_t size);
|
||||
typedef aclFloatArray *(*_aclCreateFloatArray)(const float *value,
|
||||
uint64_t size);
|
||||
typedef aclBoolArray *(*_aclCreateBoolArray)(const bool *value, uint64_t size);
|
||||
typedef aclTensorList *(*_aclCreateTensorList)(const aclTensor *const *value,
|
||||
uint64_t size);
|
||||
|
||||
typedef int (*_aclDestroyTensor)(const aclTensor *tensor);
|
||||
typedef int (*_aclDestroyScalar)(const aclScalar *scalar);
|
||||
typedef int (*_aclDestroyIntArray)(const aclIntArray *array);
|
||||
typedef int (*_aclDestroyFloatArray)(const aclFloatArray *array);
|
||||
typedef int (*_aclDestroyBoolArray)(const aclBoolArray *array);
|
||||
typedef int (*_aclDestroyTensorList)(const aclTensorList *array);
|
||||
|
||||
// Thread-local scratch buffer into which AddParamToBuf serializes op
// parameters for hashing (see CalcHashId). An overflow is signalled by
// pushing g_hashOffset past the buffer end to kHashBufMaxSize.
constexpr int kHashBufSize = 8192;
constexpr int kHashBufMaxSize = kHashBufSize + 1024;
extern thread_local char g_hashBuf[kHashBufSize];
extern thread_local int g_hashOffset;

// Device type used when copying host scalars to the NPU. MMCV_WITH_XLA looks
// inherited from the original MMCV source of this adapter — TODO confirm it
// is ever defined in this build; otherwise PrivateUse1 (torch_npu) is used.
#ifdef MMCV_WITH_XLA
#define DEVICE_TYPE at_npu::key::NativeDeviceType
#else
#define DEVICE_TYPE c10::DeviceType::PrivateUse1
#endif

// X-macro listing every at::ScalarType with its ACL equivalent
// (ACL_DT_UNDEFINED for types CANN does not support). The order MUST match
// the declaration order of at::ScalarType, because the table below is
// indexed directly by the enum's integer value — confirm when bumping torch.
#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \
    _(at::ScalarType::Byte, ACL_UINT8)              \
    _(at::ScalarType::Char, ACL_INT8)               \
    _(at::ScalarType::Short, ACL_INT16)             \
    _(at::ScalarType::Int, ACL_INT32)               \
    _(at::ScalarType::Long, ACL_INT64)              \
    _(at::ScalarType::Half, ACL_FLOAT16)            \
    _(at::ScalarType::Float, ACL_FLOAT)             \
    _(at::ScalarType::Double, ACL_DOUBLE)           \
    _(at::ScalarType::ComplexHalf, ACL_DT_UNDEFINED) \
    _(at::ScalarType::ComplexFloat, ACL_COMPLEX64)  \
    _(at::ScalarType::ComplexDouble, ACL_COMPLEX128) \
    _(at::ScalarType::Bool, ACL_BOOL)               \
    _(at::ScalarType::QInt8, ACL_DT_UNDEFINED)      \
    _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED)     \
    _(at::ScalarType::QInt32, ACL_DT_UNDEFINED)     \
    _(at::ScalarType::BFloat16, ACL_BF16)           \
    _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED)   \
    _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED)   \
    _(at::ScalarType::Undefined, ACL_DT_UNDEFINED)  \
    _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED)

// Dense lookup table: kATenScalarTypeToAclDataTypeTable[(int)scalar_type].
constexpr aclDataType kATenScalarTypeToAclDataTypeTable
    [static_cast<int64_t>(at::ScalarType::NumOptions) + 1] = {
#define DEFINE_ENUM(_1, n) n,
        AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM)
#undef DEFINE_ENUM
};
|
||||
|
||||
// Resolve the address of an aclnn op-api symbol by name and cast it to the
// matching `_<apiName>` function-pointer typedef declared above. Yields
// nullptr when the symbol is absent from both op-api libraries.
#define GET_OP_API_FUNC(apiName) \
    reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName))
|
||||
|
||||
// Append `size_expression` bytes at `data_expression` to the thread-local
// hash buffer used for op-signature hashing. On overflow, g_hashOffset is
// pushed past the buffer end (kHashBufMaxSize) as an "invalid hash" marker
// and the calling AddParamToBuf overload returns immediately (the `return;`
// exits the CALLER, which is why this must stay a macro).
// Wrapped in do/while(0) so the expansion is a single statement (safe inside
// unbraced if/else); arguments are parenthesized against operator-precedence
// surprises when expressions like `a + b` are passed in.
#define MEMCPY_TO_BUF(data_expression, size_expression)                        \
    do {                                                                       \
        if (g_hashOffset + (size_expression) > kHashBufSize) {                 \
            g_hashOffset = kHashBufMaxSize;                                    \
            return;                                                            \
        }                                                                      \
        memcpy(g_hashBuf + g_hashOffset, (data_expression), (size_expression)); \
        g_hashOffset += (size_expression);                                     \
    } while (0)
|
||||
|
||||
// Returns true when the tensor is in one of CANN's "base" (unpadded) formats
// and can therefore be described to aclCreateTensor directly. Non-NPU tensors
// are always treated as base format; NPU tensors are checked against the
// npu_format_ recorded in their storage descriptor.
// `inline` is required: this function is DEFINED in a header, and without it
// every translation unit including this header emits its own external
// definition — a multiple-definition (ODR) link error waiting to happen.
inline bool IsOpInputBaseFormat(const at::Tensor &tensor)
{
    if (!tensor.is_privateuseone()) {
        return true;
    }
    const auto format = vllm_ascend::NPUBridge::GetNpuStorageImplDesc(tensor).npu_format_;
    return (format == ACL_FORMAT_ND) || (format == ACL_FORMAT_NCHW) || (format == ACL_FORMAT_NHWC) ||
           (format == ACL_FORMAT_NCDHW);
}
|
||||
|
||||
// Name of the shared library providing the standard aclnn op-api symbols.
inline const char *GetOpApiLibName(void)
{
    return "libopapi.so";
}
|
||||
|
||||
// Name of the optional customer-override op-api library; symbols found here
// take precedence over libopapi.so.
inline const char *GetCustOpApiLibName(void)
{
    return "libcust_opapi.so";
}
|
||||
|
||||
// Look up `apiName` in an already-opened shared object; nullptr when absent.
// `libName` is kept in the signature for call-site symmetry/diagnostics but
// is not needed for the lookup itself — marked [[maybe_unused]] instead of
// leaving a silent unused-parameter warning.
inline void *GetOpApiFuncAddrInLib(void *handler, [[maybe_unused]] const char *libName,
                                   const char *apiName) {
    return dlsym(handler, apiName);
}
|
||||
|
||||
// Lazily open a shared library; returns nullptr when it cannot be loaded
// (callers treat a missing library as "no symbols available", not an error).
inline void *GetOpApiLibHandler(const char *libName) {
    return dlopen(libName, RTLD_LAZY);
}
|
||||
|
||||
inline void *GetOpApiFuncAddr(const char *apiName) {
|
||||
static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName());
|
||||
if (custOpApiHandler != nullptr) {
|
||||
auto funcAddr =
|
||||
GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName);
|
||||
if (funcAddr != nullptr) {
|
||||
return funcAddr;
|
||||
}
|
||||
}
|
||||
|
||||
static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName());
|
||||
if (opApiHandler == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName);
|
||||
}
|
||||
|
||||
// Read the single element of a (CPU, wrapped-number) tensor and wrap it in a
// c10::Scalar, dispatching on dtype. Dtypes without a case below fall through
// to the default-constructed scalar, matching the original if/else chain.
inline c10::Scalar ConvertTensorToScalar(const at::Tensor &tensor) {
    c10::Scalar result;
    switch (tensor.scalar_type()) {
        case at::ScalarType::Double:
            result = c10::Scalar(*static_cast<double *>(tensor.data_ptr()));
            break;
        case at::ScalarType::Long:
            result = c10::Scalar(*static_cast<int64_t *>(tensor.data_ptr()));
            break;
        case at::ScalarType::Float:
            result = c10::Scalar(*static_cast<float *>(tensor.data_ptr()));
            break;
        case at::ScalarType::Int:
            result = c10::Scalar(*static_cast<int *>(tensor.data_ptr()));
            break;
        case at::ScalarType::Half:
            result = c10::Scalar(*static_cast<c10::Half *>(tensor.data_ptr()));
            break;
        case at::ScalarType::Bool:
            // Bool storage is one byte; read it as int8_t exactly as before.
            result = c10::Scalar(*static_cast<int8_t *>(tensor.data_ptr()));
            break;
        case at::ScalarType::ComplexDouble:
            result = c10::Scalar(*static_cast<c10::complex<double> *>(tensor.data_ptr()));
            break;
        case at::ScalarType::ComplexFloat:
            result = c10::Scalar(*static_cast<c10::complex<float> *>(tensor.data_ptr()));
            break;
        case at::ScalarType::BFloat16:
            result = c10::Scalar(*static_cast<c10::BFloat16 *>(tensor.data_ptr()));
            break;
        default:
            // Unsupported dtype: return the default-constructed scalar.
            break;
    }
    return result;
}
|
||||
|
||||
// Copy a host tensor to the NPU via pinned host memory (non-blocking copy).
// NOTE(review): the device index is hard-coded to 0 — assumes this is only
// used for wrapped-number scalars where any/the current device is fine;
// confirm for multi-card setups.
inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor) {
  at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory();
  int deviceIndex = 0;
  return cpuPinMemTensor.to(c10::Device(DEVICE_TYPE, deviceIndex),
                            cpuPinMemTensor.scalar_type(), true, true);
}

// Materialize a host scalar as a device tensor of the requested dtype.
// (scalar_to_tensor is found unqualified via the `using namespace at` above.)
inline at::Tensor CopyScalarToDevice(const c10::Scalar &cpu_scalar,
                                     at::ScalarType scalar_data_type) {
  return CopyTensorHostToDevice(
      scalar_to_tensor(cpu_scalar).to(scalar_data_type));
}
|
||||
|
||||
// Convert an at::Tensor into an aclTensor descriptor for aclnn.
// Returns nullptr when the tensor is undefined or the aclCreateTensor symbol
// cannot be resolved. The aclTensor views the SAME device memory as the
// input tensor — no data is copied (except for wrapped numbers, see below).
inline aclTensor *ConvertType(const at::Tensor &at_tensor) {
  static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
  if (aclCreateTensor == nullptr) {
    return nullptr;
  }

  if (!at_tensor.defined()) {
    return nullptr;
  }
  // Map the torch dtype to its ACL equivalent; unsupported dtypes abort.
  at::ScalarType scalar_data_type = at_tensor.scalar_type();
  aclDataType acl_data_type =
      kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalar_data_type)];
  TORCH_CHECK(
      acl_data_type != ACL_DT_UNDEFINED,
      std::string(c10::toString(scalar_data_type)) + " has not been supported")
  c10::SmallVector<int64_t, 5> storageDims;
  // if acl_data_type is ACL_STRING, storageDims is empty.
  auto itemsize = at_tensor.itemsize();
  TORCH_CHECK(itemsize != 0, "When ConvertType, tensor item size cannot be zero.");

  // Pick the ACL format. Non-base-format NPU tensors carry their real device
  // format and physical shape in the storage descriptor; base-format tensors
  // get a conventional format inferred from rank, with the storage described
  // as one flat dimension (total elements).
  const auto dimNum = at_tensor.sizes().size();
  aclFormat format = ACL_FORMAT_ND;
  if (!IsOpInputBaseFormat(at_tensor)) {
    format = vllm_ascend::NPUBridge::GetNpuStorageImpl(at_tensor)->npu_desc_.npu_format_;
    if (acl_data_type != ACL_STRING) {
      storageDims = vllm_ascend::NPUBridge::GetNpuStorageImpl(at_tensor)->npu_desc_.storage_sizes_;
    }
  } else {
    switch (dimNum) {
      case 3:
        format = ACL_FORMAT_NCL;
        break;
      case 4:
        format = ACL_FORMAT_NCHW;
        break;
      case 5:
        format = ACL_FORMAT_NCDHW;
        break;
      default:
        format = ACL_FORMAT_ND;
    }
    if (acl_data_type != ACL_STRING) {
      // Flat element count of the whole storage (not just this view).
      storageDims.push_back(at_tensor.storage().nbytes() / itemsize);
    }
  }

  // Wrapped numbers (python scalars promoted to tensors) live on the host;
  // materialize them on the device first, then describe the device copy.
  if (at_tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    c10::Scalar expScalar = ConvertTensorToScalar(at_tensor);
    at::Tensor aclInput = CopyScalarToDevice(expScalar, scalar_data_type);
    return aclCreateTensor(aclInput.sizes().data(), aclInput.sizes().size(),
                           acl_data_type, aclInput.strides().data(),
                           aclInput.storage_offset(), format,
                           storageDims.data(), storageDims.size(),
                           const_cast<void *>(aclInput.storage().data()));
  }

  auto acl_tensor = aclCreateTensor(
      at_tensor.sizes().data(), at_tensor.sizes().size(), acl_data_type,
      at_tensor.strides().data(), at_tensor.storage_offset(), format,
      storageDims.data(), storageDims.size(),
      const_cast<void *>(at_tensor.storage().data()));
  return acl_tensor;
}
|
||||
|
||||
// Convert an at::Scalar into an aclScalar. Only the four canonical scalar
// payload types (Double/Long/Bool/ComplexDouble — what c10::Scalar actually
// stores) are handled; anything else yields nullptr.
inline aclScalar *ConvertType(const at::Scalar &at_scalar) {
  static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
  if (aclCreateScalar == nullptr) {
    return nullptr;
  }

  at::ScalarType scalar_data_type = at_scalar.type();
  aclDataType acl_data_type =
      kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalar_data_type)];
  TORCH_CHECK(
      acl_data_type != ACL_DT_UNDEFINED,
      std::string(c10::toString(scalar_data_type)) + " has not been supported")
  aclScalar *acl_scalar = nullptr;
  switch (scalar_data_type) {
    case at::ScalarType::Double: {
      double value = at_scalar.toDouble();
      acl_scalar = aclCreateScalar(&value, acl_data_type);
      break;
    }
    case at::ScalarType::Long: {
      int64_t value = at_scalar.toLong();
      acl_scalar = aclCreateScalar(&value, acl_data_type);
      break;
    }
    case at::ScalarType::Bool: {
      bool value = at_scalar.toBool();
      acl_scalar = aclCreateScalar(&value, acl_data_type);
      break;
    }
    case at::ScalarType::ComplexDouble: {
      auto value = at_scalar.toComplexDouble();
      acl_scalar = aclCreateScalar(&value, acl_data_type);
      break;
    }
    default:
      // Unsupported scalar payload type.
      acl_scalar = nullptr;
      break;
  }
  return acl_scalar;
}
|
||||
|
||||
// Convert an integer array ref into an aclIntArray (copied by ACL).
inline aclIntArray *ConvertType(const at::IntArrayRef &at_array) {
  static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
  if (aclCreateIntArray == nullptr) {
    return nullptr;
  }
  auto array = aclCreateIntArray(at_array.data(), at_array.size());
  return array;
}

// Convert a fixed-size bool array into an aclBoolArray.
template <std::size_t N>
inline aclBoolArray *ConvertType(const std::array<bool, N> &value) {
  static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
  if (aclCreateBoolArray == nullptr) {
    return nullptr;
  }

  auto array = aclCreateBoolArray(value.data(), value.size());
  return array;
}

// Convert a bool array ref into an aclBoolArray.
inline aclBoolArray *ConvertType(const at::ArrayRef<bool> &value) {
  static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
  if (aclCreateBoolArray == nullptr) {
    return nullptr;
  }

  auto array = aclCreateBoolArray(value.data(), value.size());
  return array;
}

// Convert a list of tensors into an aclTensorList. Each element goes through
// the tensor ConvertType above; the per-element aclTensor handles are owned
// by the list and released via Release(aclTensorList*).
inline aclTensorList *ConvertType(const at::TensorList &at_tensor_list) {
  static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
  if (aclCreateTensorList == nullptr) {
    return nullptr;
  }

  std::vector<const aclTensor *> tensor_list(at_tensor_list.size());
  for (size_t i = 0; i < at_tensor_list.size(); i++) {
    tensor_list[i] = ConvertType(at_tensor_list[i]);
  }
  auto acl_tensor_list =
      aclCreateTensorList(tensor_list.data(), tensor_list.size());
  return acl_tensor_list;
}

// Optional variants: an absent/undefined optional maps to nullptr, which the
// aclnn APIs treat as "argument not provided".
inline aclTensor *ConvertType(const c10::optional<at::Tensor> &opt_tensor) {
  if (opt_tensor.has_value() && opt_tensor.value().defined()) {
    return ConvertType(opt_tensor.value());
  }
  return nullptr;
}

inline aclIntArray *ConvertType(
    const c10::optional<at::IntArrayRef> &opt_array) {
  if (opt_array.has_value()) {
    return ConvertType(opt_array.value());
  }
  return nullptr;
}

inline aclScalar *ConvertType(const c10::optional<at::Scalar> &opt_scalar) {
  if (opt_scalar.has_value()) {
    return ConvertType(opt_scalar.value());
  }
  return nullptr;
}

// Dtype arguments are passed to aclnn as plain aclDataType values.
inline aclDataType ConvertType(const at::ScalarType scalarType) {
  return kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalarType)];
}

// Fallback: anything without a dedicated overload (ints, floats, pointers,
// e.g. the workspace-size/executor out-params) is forwarded unchanged.
template <typename T>
T ConvertType(T value) {
  return value;
}
|
||||
|
||||
// Build a function-pointer type whose parameter list mirrors the (already
// ConvertType'd) argument tuple, and cast the dlsym address to it. This is
// how a symbol looked up by name acquires a callable C signature.
template <typename Tuple, size_t... I>
auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr,
                        std::index_sequence<I...>) {
  typedef int (*OpApiFunc)(
      typename std::decay<decltype(std::get<I>(params))>::type...);
  auto func = reinterpret_cast<OpApiFunc>(opApiAddr);
  return func;
}

// Convenience wrapper that derives the index sequence from the tuple size.
template <typename Tuple>
auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr) {
  static constexpr auto size = std::tuple_size<Tuple>::value;
  return ConvertToOpApiFunc(params, opApiAddr,
                            std::make_index_sequence<size>{});
}
|
||||
|
||||
// Release() overloads destroy the acl wrapper objects created by
// ConvertType. Each resolves its destroy symbol lazily and silently no-ops
// when the symbol is unavailable (matching the nullptr-tolerant creation
// path). The acl* destroy functions accept nullptr, so no null checks on `p`.

inline void Release(aclTensor *p) {
  static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
  if (aclDestroyTensor == nullptr) {
    return;
  }
  aclDestroyTensor(p);
}

inline void Release(aclScalar *p) {
  static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar);
  if (aclDestroyScalar == nullptr) {
    return;
  }
  aclDestroyScalar(p);
}

inline void Release(aclIntArray *p) {
  static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray);
  if (aclDestroyIntArray == nullptr) {
    return;
  }

  aclDestroyIntArray(p);
}

inline void Release(aclBoolArray *p) {
  static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray);
  if (aclDestroyBoolArray == nullptr) {
    return;
  }

  aclDestroyBoolArray(p);
}

inline void Release(aclTensorList *p) {
  static const auto aclDestroyTensorList =
      GET_OP_API_FUNC(aclDestroyTensorList);
  if (aclDestroyTensorList == nullptr) {
    return;
  }

  aclDestroyTensorList(p);
}

// Fallback for pass-through values (ints, pointers, ...): nothing to free.
template <typename T>
void Release(T value) {
  (void)value;
}
|
||||
|
||||
// Invoke Release() on every element of the converted-parameter tuple.
template <typename Tuple, size_t... I>
void CallRelease(Tuple t, std::index_sequence<I...>) {
  (void)std::initializer_list<int>{(Release(std::get<I>(t)), 0)...};
}

// Free all acl wrapper objects produced by ConvertTypes for one call.
template <typename Tuple>
void ReleaseConvertTypes(Tuple &t) {
  static constexpr auto size = std::tuple_size<Tuple>::value;
  CallRelease(t, std::make_index_sequence<size>{});
}

// Apply ConvertType to every argument and collect the results in a tuple —
// the argument pack that will be handed to the resolved aclnn function.
template <typename... Ts>
constexpr auto ConvertTypes(Ts &... args) {
  return std::make_tuple(ConvertType(args)...);
}

// std::apply equivalent (this file predates reliance on C++17 std::apply):
// unpack the tuple into the function call.
template <typename Function, typename Tuple, size_t... I>
auto call(Function f, Tuple t, std::index_sequence<I...>) {
  return f(std::get<I>(t)...);
}

template <typename Function, typename Tuple>
auto call(Function f, Tuple t) {
  static constexpr auto size = std::tuple_size<Tuple>::value;
  return call(f, t, std::make_index_sequence<size>{});
}
|
||||
|
||||
// AddParamToBuf serializes op parameters into the thread-local g_hashBuf so
// CalcHashId can fingerprint a call signature. Each overload appends the raw
// bytes of its argument via MEMCPY_TO_BUF (which early-returns on overflow).

template <std::size_t N>
void AddParamToBuf(const std::array<bool, N> &value) {
  MEMCPY_TO_BUF(value.data(), value.size() * sizeof(bool));
}

// Generic trivially-copyable fallback: hash the object representation.
template <typename T>
void AddParamToBuf(const T &value) {
  MEMCPY_TO_BUF(&value, sizeof(T));
}

// Non-trivial types get dedicated out-of-line overloads (defined in the
// adapter's .cpp, not visible in this header).
void AddParamToBuf(const at::Tensor &);
void AddParamToBuf(const at::Scalar &);
void AddParamToBuf(const at::IntArrayRef &);
void AddParamToBuf(const at::ArrayRef<bool> &);
void AddParamToBuf(const at::TensorList &);
void AddParamToBuf(const c10::optional<at::Tensor> &);
void AddParamToBuf(const c10::optional<at::IntArrayRef> &);
void AddParamToBuf(const c10::optional<at::Scalar> &);
void AddParamToBuf(const at::ScalarType);
void AddParamToBuf(const string &);  // `string` resolves via the using-directives above
void AddParamToBuf();                // base case for the variadic recursion

// Variadic driver: serialize each argument left to right.
template <typename T, typename... Args>
void AddParamToBuf(const T &arg, Args &... args) {
  AddParamToBuf(arg);
  AddParamToBuf(args...);
}
|
||||
|
||||
uint64_t CalcHashId();
|
||||
typedef int (*InitHugeMemThreadLocal)(void *, bool);
|
||||
typedef void (*UnInitHugeMemThreadLocal)(void *, bool);
|
||||
typedef void (*ReleaseHugeMem)(void *, bool);
|
||||
|
||||
// EXEC_NPU_CMD(aclnn_api, ...) — invoke a two-phase aclnn operator by name.
//
// Flow:
//   1. dlsym-resolve `<aclnn_api>GetWorkspaceSize` and `<aclnn_api>` from
//      libcust_opapi.so / libopapi.so (resolved once per expansion site and
//      cached in function-local statics).
//   2. ConvertTypes() turns the variadic torch arguments into acl wrappers,
//      with the workspace-size and executor out-params appended; the
//      GetWorkspaceSize call then sizes the scratch workspace.
//   3. If needed, the workspace is allocated as a byte tensor on the NPU
//      device; the actual kernel launch is queued as a custom handler on an
//      at_npu OpCommand so it runs through torch_npu's dispatch pipeline on
//      the current NPU stream.
//   4. The handler lambda releases the converted acl wrappers after launch
//      and calls ReleaseHugeMem; Init/UnInitHugeMemThreadLocal bracket the
//      whole expansion (each is optional — skipped when the symbol is absent).
//
// NOTE(review): `getWorkspaceSizeFunc` is a function-local static, so its
// signature is fixed by the first execution of each expansion site; since
// every expansion is its own statement this is per-call-site and consistent.
// NOTE(review): the workspace tensor's lifetime ends at the end of the
// enclosing scope while the launch runs asynchronously on the NPU stream —
// presumably torch_npu's caching allocator keeps the memory valid until the
// stream work completes; confirm against torch_npu allocator semantics.
// Comments must stay OUTSIDE the macro: a `//` on a continuation line would
// swallow the backslash and truncate the macro.
#define EXEC_NPU_CMD(aclnn_api, ...)                                          \
  do {                                                                        \
    static const auto getWorkspaceSizeFuncAddr =                              \
        GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize");                      \
    static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api);           \
    static const auto initMemAddr =                                           \
        GetOpApiFuncAddr("InitHugeMemThreadLocal");                           \
    static const auto unInitMemAddr =                                         \
        GetOpApiFuncAddr("UnInitHugeMemThreadLocal");                         \
    static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem");    \
    TORCH_CHECK(                                                              \
        getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr,      \
        #aclnn_api, " or ", #aclnn_api "GetWorkspaceSize", " not in ",        \
        GetOpApiLibName(), ", or ", GetOpApiLibName(), "not found.");         \
    auto acl_stream = c10_npu::getCurrentNPUStream().stream(false);           \
    uint64_t workspace_size = 0;                                              \
    uint64_t *workspace_size_addr = &workspace_size;                          \
    aclOpExecutor *executor = nullptr;                                        \
    aclOpExecutor **executor_addr = &executor;                                \
    InitHugeMemThreadLocal initMemFunc =                                      \
        reinterpret_cast<InitHugeMemThreadLocal>(initMemAddr);                \
    UnInitHugeMemThreadLocal unInitMemFunc =                                  \
        reinterpret_cast<UnInitHugeMemThreadLocal>(unInitMemAddr);            \
    if (initMemFunc) {                                                        \
      initMemFunc(nullptr, false);                                            \
    }                                                                         \
    auto converted_params =                                                   \
        ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr);        \
    static auto getWorkspaceSizeFunc =                                        \
        ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr);       \
    auto workspace_status = call(getWorkspaceSizeFunc, converted_params);     \
    TORCH_CHECK(workspace_status == 0,                                        \
                "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \
    void *workspace_addr = nullptr;                                           \
    if (workspace_size != 0) {                                                \
      at::TensorOptions options =                                             \
          at::TensorOptions(torch_npu::utils::get_npu_device_type());         \
      auto workspace_tensor =                                                 \
          at::empty({workspace_size}, options.dtype(kByte));                  \
      workspace_addr = const_cast<void *>(workspace_tensor.storage().data()); \
    }                                                                         \
    auto acl_call = [converted_params, workspace_addr, workspace_size,        \
                     acl_stream, executor]() -> int {                         \
      typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *,             \
                               const aclrtStream);                            \
      OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr);       \
      auto api_ret =                                                          \
          opApiFunc(workspace_addr, workspace_size, executor, acl_stream);    \
      TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:",        \
                  aclGetRecentErrMsg());                                      \
      ReleaseConvertTypes(converted_params);                                  \
      ReleaseHugeMem releaseMemFunc =                                         \
          reinterpret_cast<ReleaseHugeMem>(releaseMemAddr);                   \
      if (releaseMemFunc) {                                                   \
        releaseMemFunc(nullptr, false);                                       \
      }                                                                       \
      return api_ret;                                                         \
    };                                                                        \
    at_npu::native::OpCommand cmd;                                            \
    cmd.Name(#aclnn_api);                                                     \
    cmd.SetCustomHandler(acl_call);                                           \
    cmd.Run();                                                                \
    if (unInitMemFunc) {                                                      \
      unInitMemFunc(nullptr, false);                                          \
    }                                                                         \
  } while (false)
|
||||
|
||||
#endif
|
||||
@@ -27,12 +27,14 @@
|
||||
#include "ops.h"
|
||||
#include "utils.h"
|
||||
#include "mla_preprocess/op_host/mla_preprocess.h"
|
||||
#include "aclnn_torch_adapter/op_api_common.h"
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
namespace vllm_ascend {
|
||||
const int64_t INT4_NUMS_IN_INT32 = 8;
|
||||
void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst,
|
||||
const torch::Tensor& block_mapping, aclrtStream stream) {
|
||||
torch::Device src_device = src.device();
|
||||
@@ -520,6 +522,36 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
|
||||
cmd.Run();
|
||||
return y_out;
|
||||
}
|
||||
|
||||
/**
 * Fused grouped matmul + SwiGLU activation + per-token quantization,
 * dispatched to the aclnnGroupedMatmulSwigluQuantWeightNZ operator.
 *
 * @param x            Input activations, shape (m, k).
 * @param weight       Expert weights, shape (e, k, n). In the A8W4 path the
 *                     dtype is int32 and each element packs 8 int4 values.
 * @param weight_scale Dequantization scale for the weights.
 * @param x_scale      Per-token dequantization scale for the activations.
 * @param group_list   Cumulative token counts per expert group.
 * @param bias         Optional per-expert bias.
 * @param offset       Optional quantization offset.
 * @return (output, output_scale, output_offset):
 *         output       int8 tensor of shape (m, n/2) — SwiGLU halves the
 *                      hidden dimension;
 *         output_scale fp32 tensor of shape (m,) — per-token requant scale;
 *         output_offset fp32 scalar tensor (placeholder output).
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
    const at::Tensor &x, const at::Tensor &weight, const at::Tensor &weight_scale, const at::Tensor &x_scale,
    const at::Tensor &group_list, const c10::optional<at::Tensor> &bias, const c10::optional<at::Tensor> &offset)
{
    // Tensor dimensions are int64_t; keep them that way to avoid narrowing
    // (m * n can exceed INT_MAX for large workloads).
    const int64_t m = x.size(0);
    int64_t n = weight.size(2);
    // A8W4 path: int8 activations with int32-packed weights, where every
    // int32 element holds 8 int4 values, so the logical n is 8x larger.
    const bool is_a8w4 = x.dtype() == at::kChar && weight.dtype() == at::kInt;
    if (is_a8w4) {
        n *= INT4_NUMS_IN_INT32;
    }

    // SwiGLU consumes channel pairs, so the quantized output has n/2 columns.
    at::Tensor output = at::empty({m, n / 2}, x.options().dtype(c10::ScalarType::Char));
    at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float));
    at::Tensor output_offset = at::empty({}, x.options().dtype(c10::ScalarType::Float));

    EXEC_NPU_CMD(
        aclnnGroupedMatmulSwigluQuantWeightNZ,
        x,
        weight,
        bias,
        offset,
        weight_scale,
        x_scale,
        group_list,
        output,
        output_scale,
        output_offset);
    return {output, output_scale, output_offset};
}
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -576,4 +608,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
|
||||
ops.def("swap_blocks(Tensor! x, Tensor! y, Tensor z) -> ()");
|
||||
ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks);
|
||||
|
||||
ops.def(
|
||||
"grouped_matmul_swiglu_quant(Tensor x, Tensor weight, Tensor weight_scale, Tensor x_scale,"
|
||||
" Tensor group_list, *, Tensor? bias=None,"
|
||||
" Tensor? offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)");
|
||||
ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant);
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
|
||||
namespace vllm_ascend {
|
||||
namespace meta {
|
||||
|
||||
const int64_t INT4_NUMS_IN_INT32 = 8;
|
||||
std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
|
||||
at::Tensor &positions,
|
||||
at::Tensor &query,
|
||||
@@ -114,6 +114,22 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
|
||||
return {q_out0, kv_cache_out0, q_out1, kv_cache_out1};
|
||||
}
|
||||
|
||||
/**
 * Meta (shape-inference) implementation of grouped_matmul_swiglu_quant:
 * allocates outputs with the same shapes/dtypes as the NPU kernel without
 * launching any computation. Must stay in sync with the real implementation.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
    const at::Tensor &x, const at::Tensor &weight, const at::Tensor &weight_scale, const at::Tensor &x_scale,
    const at::Tensor &group_list, const c10::optional<at::Tensor> &bias, const c10::optional<at::Tensor> &offset)
{
    // Tensor dimensions are int64_t; avoid narrowing to int.
    const int64_t m = x.size(0);
    int64_t n = weight.size(2);
    // A8W4 path: each int32 weight element packs 8 int4 values.
    const bool is_a8w4 = x.dtype() == at::kChar && weight.dtype() == at::kInt;
    if (is_a8w4) {
        n *= INT4_NUMS_IN_INT32;
    }
    // Shapes mirror the kernel: int8 (m, n/2) output, fp32 (m,) scale,
    // fp32 scalar offset placeholder.
    at::Tensor output = at::empty({m, n / 2}, x.options().dtype(c10::ScalarType::Char));
    at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float));
    at::Tensor output_offset = at::empty({}, x.options().dtype(c10::ScalarType::Float));
    return {output, output_scale, output_offset};
}
|
||||
|
||||
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
@@ -132,5 +148,7 @@ namespace {
|
||||
ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
|
||||
// MLA preprocess
|
||||
ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess);
|
||||
// grouped_matmul_swiglu_quant meta implementation
|
||||
ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant);
|
||||
}
|
||||
}
|
||||
|
||||
175
tests/e2e/nightly/ops/test_grouped_matmul_swiglu_quant.py
Normal file
175
tests/e2e/nightly/ops/test_grouped_matmul_swiglu_quant.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
def x_int8_to_x_int4(x: torch.Tensor):
|
||||
m, k = x.shape
|
||||
x_high_4bit = torch.floor(x.to(torch.float16) // 16).to(torch.int8)
|
||||
x_low_4bit = (
|
||||
torch.bitwise_and(x.view(torch.int16), 0x0f0f).view(torch.int8) - 8)
|
||||
x_int4 = torch.empty((2 * m, k), dtype=torch.int8)
|
||||
x_int4[::2, :] = x_high_4bit
|
||||
x_int4[1::2, :] = x_low_4bit
|
||||
return x_int4
|
||||
|
||||
|
||||
def custom_mm(x: torch.Tensor, weight: torch.Tensor,
              weight_scale: torch.Tensor, m: int):
    """Quantized GMM (general matrix multiply) with dequantization.

    Args:
        x: integer activation tensor of shape (m, k).
        weight: integer weight tensor of shape (k, n).
        weight_scale: dequantization scale per channel;
            perGroup: shape (k_group_num, n) with k_group_num > 1
            (k_group_num == 1 degenerates to perChannel);
            perChannel: shape (n,) or (1, n).
        m: number of tokens (rows of ``x``).

    Returns:
        fp32 tensor of shape (m, n): the matmul result dequantized per group
        or per channel. An unrecognized scale layout yields zeros.
    """
    k, n = weight.shape
    mm_out = torch.zeros((m, n), dtype=torch.float16)
    scale_is_2d = weight_scale.dim() == 2
    if scale_is_2d and weight_scale.shape[0] != 1:
        # perGroup: split K into equal chunks, dequantize each partial
        # product with its own scale row and accumulate in fp16.
        num_groups = weight_scale.shape[0]
        group_len = k // num_groups
        lhs = x.view(-1, num_groups, group_len).transpose(0, 1)
        rhs = weight.view(num_groups, group_len, n)

        partial = torch.bmm(lhs.to(torch.int32),
                            rhs.to(torch.int32)).to(torch.float16)
        for g in range(num_groups):
            scaled = partial[g] * weight_scale[g].view(1, -1).to(torch.float16)
            mm_out += scaled.to(torch.float16)
    elif weight_scale.dim() == 1 or (scale_is_2d
                                     and weight_scale.shape[0] == 1):
        # perChannel: single scale per output column; matmul result kept in
        # fp32 before scaling.
        full = torch.matmul(x.to(torch.int32),
                            weight.to(torch.int32)).to(torch.float32)
        mm_out = full * weight_scale.view(1, -1).to(torch.float16)
    return mm_out.to(torch.float32)
|
||||
|
||||
|
||||
def gmm_swiglu_quant_golden_a8_w4(x: torch.Tensor, weight: torch.Tensor,
                                  weight_scale: torch.Tensor,
                                  per_token_scale: torch.Tensor,
                                  bias: torch.Tensor,
                                  group_list: torch.Tensor):
    """
    CPU golden reference for grouped_matmul_swiglu_quant in the A8W4 path:
    per-group quantized GMM, SwiGLU activation, then per-token int8
    quantization.

    Parameters:
        x (torch.Tensor): Input tensor with shape (M, K), dtype int8.
        weight (torch.Tensor): Weight tensor with shape (E, K, N); dtype is
            int8 but values lie in the int4 range (one int4 value per byte).
        weight_scale (torch.Tensor): Scaling factor for each channel.
            - perGroup scenario: shape (E, k_group_num, N).
            - perChannel scenario: shape (E, N).
        per_token_scale (torch.Tensor): Scaling factor per token, shape (M,).
        bias (torch.Tensor): Per-expert bias, shape (E, N).
        group_list (torch.Tensor): Non-decreasing prefix sums of token counts
            per expert group (last value assumed <= M).

    Returns:
        quant_output (torch.Tensor): Quantized int8 output, shape (M, N // 2).
        quant_scale_output (torch.Tensor): fp32 per-token dequantization
            scale, shape (M,).
    """
    M, N = x.shape[0], weight.shape[2]
    quant_output = torch.zeros(M, N // 2).to(torch.int8)
    quant_scale_output = torch.zeros(M).to(torch.float32)
    # Preprocess: unpack each int8 row of x into two int4-valued rows
    # (high nibble / low nibble), doubling the row count.
    x_int4 = x_int8_to_x_int4(x)
    start_idx = 0
    # Prefix-sum value of the previous group (group_list is cumulative).
    pre_v = 0
    group_list = group_list.tolist()
    # Walk the groups; group i applies expert weight[i] to its token slice.
    for i, v in enumerate(group_list):
        curr_v = v
        # Rows of x_int4 belonging to this group: token delta * 2, because
        # one int8 row became two int4 rows.
        temp_v = int((curr_v - pre_v) * 2)
        # Advance the prefix-sum cursor.
        pre_v = curr_v
        if (temp_v > 0):
            # Dequantized matmul over the group's int4 rows.
            mm_out = custom_mm(x_int4[int(start_idx):int(start_idx + temp_v)],
                               weight[i], weight_scale[i], temp_v)
            # Recombine nibble rows (high * 16 + low) into per-token rows,
            # then add the expert bias.
            mm_num_concat = ((mm_out[::2] * 16 + mm_out[1::2]) +
                             bias[i].view(1, -1))
            # Apply the per-token activation scale (token index = row // 2).
            per_token_quant = mm_num_concat * per_token_scale[start_idx // 2:(
                start_idx + temp_v) // 2].view(-1, 1)
            # SwiGLU: silu(first half) * second half.
            swiglu, gate = per_token_quant.chunk(2, dim=-1)
            temp = swiglu * torch.sigmoid(swiglu)
            temp = temp * gate
            # Per-token symmetric int8 quantization: map max |value| to 127.
            max_value = torch.max(torch.abs(temp), dim=-1).values
            quant_scale_output_temp = 127 / max_value
            quant_output[start_idx // 2:(start_idx + temp_v) //
                         2] = torch.round(temp *
                                          quant_scale_output_temp.reshape(
                                              temp_v // 2, 1)).to(torch.int8)
            # Store the inverse scale needed to dequantize the int8 output.
            quant_scale_output[start_idx // 2:(start_idx + temp_v) //
                               2] = 1 / quant_scale_output_temp
        start_idx += temp_v
    return quant_output, quant_scale_output
|
||||
|
||||
|
||||
def generate_non_decreasing_sequence(length, upper_limit):
    """Draw a random non-decreasing int64 sequence capped at ``upper_limit``.

    Builds ``length`` cumulative sums of random steps drawn from [0, 128);
    if the final value would reach ``upper_limit``, the whole sequence is
    rescaled proportionally (and truncated to int64) so it fits.
    """
    steps = torch.randint(0, 128, (length, ))
    seq = torch.cumsum(steps, dim=0)

    last = seq[-1]
    if last >= upper_limit:
        # Proportional shrink; truncation keeps the order non-decreasing.
        seq = (seq * (upper_limit / last)).to(torch.int64)
    return seq
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_grouped_matmul_swiglu_quant_kernel():
    """E2E check: run the NPU grouped_matmul_swiglu_quant op on random A8W4
    inputs and compare against the CPU golden model within atol=1/rtol=0.005.
    Requires an Ascend NPU device and torch_npu.
    """
    # Problem sizes: E experts, M tokens, K input dim, N hidden dim.
    E = 16
    M = 512
    K = 7168
    N = 4096
    # Allow NPU-private (NZ) tensor formats for the weight cast below.
    torch.npu.config.allow_internal_format = True
    x = torch.randint(-5, 5, (M, K), dtype=torch.int8).npu()
    weight_ori = torch.randint(-5, 5, (E, K, N), dtype=torch.int8)
    # Cast to the NZ internal format (format id 29), then pack to int4
    # (quint4x2) — the weight layout the aclnn kernel expects.
    weight_nz = torch_npu.npu_format_cast(weight_ori.npu().to(torch.float32),
                                          29)
    pack_weight = torch_npu.npu_quantize(weight_nz,
                                         torch.tensor([1.], device='npu'),
                                         None, torch.quint4x2, -1, False)

    weight_scale = torch.randn(E, 1, N)
    # Reinterpret the fp32 scale bits as uint32, then widen to int64 —
    # presumably the bit layout the aclnn op expects for weight_scale;
    # TODO(review): confirm against the operator documentation.
    scale_np = weight_scale.cpu().numpy()
    scale_np.dtype = np.uint32
    scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()
    pertoken_scale = torch.randn(M).to(torch.float32).npu()
    # Cumulative token counts per expert, capped at M.
    group_list = generate_non_decreasing_sequence(E, M).npu()
    bias = torch.zeros((E, N), dtype=torch.float32,
                       device="npu").uniform_(-5, 5)

    # Golden reference runs fully on CPU with the unpacked int8 weights.
    output_golden, output_scale_golden = gmm_swiglu_quant_golden_a8_w4(
        x.cpu(), weight_ori, weight_scale, pertoken_scale.cpu(), bias.cpu(),
        group_list.cpu())

    # third output (output_offset) is a placeholder and not compared.
    output, output_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant(
        x=x,
        weight=pack_weight,
        bias=bias,
        group_list=group_list,
        weight_scale=scale_uint64_tensor,
        x_scale=pertoken_scale)
    torch.testing.assert_close(output_golden, output.cpu(), atol=1, rtol=0.005)
    torch.testing.assert_close(output_scale_golden,
                               output_scale.cpu(),
                               atol=1,
                               rtol=0.005)

    # Release NPU memory so subsequent nightly tests start from a clean slate.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
||||
Reference in New Issue
Block a user