[Kernel] add custom op GmmSwigluQuantWeightNzTensorList (#3804)

### What this PR does / why we need it?

This PR introduces support for adding custom CANN `aclnn` ops to
`vllm-ascend`, allowing users to define and use their own custom
operators.

Key changes include:
- Building and installing custom ops into the `vllm-ascend`-specified
directory
- Binding the `aclnn` op interface to the `torch.ops._C_ascend` module
- Enabling invocation of these ops within `vllm-ascend`

This PR includes a sample custom op:
`aclnnGroupedMatmulSwigluQuantWeightNzTensorList`, which is adapted from
the CANN operator
[`aclnnGroupedMatmulSwigluQuantWeightNZ`](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/API/aolapi/context/aclnnGroupedMatmulSwigluQuantWeightNZ.md).
Its input parameters `weight` and `weight_scale` now accept
`list[torch.Tensor]` (i.e., `at::TensorList`).

### Does this PR introduce _any_ user-facing change?

No.


- vLLM version: v0.11.2

---------

Signed-off-by: QianChenxi <chenxi.qian.cq@outlook.com>
This commit is contained in:
Chenxi Qian
2025-11-28 18:06:39 +08:00
committed by GitHub
parent 3199fe8350
commit 554f16ae1f
50 changed files with 6934 additions and 7 deletions

View File

@@ -0,0 +1,47 @@
/**
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file data_copy_transpose_tiling.h
* \brief
*/
#pragma once
#include <vector>
#include <graph/tensor.h>
#include "data_copy_transpose_tiling_def.h"
namespace optiling {
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
optiling::CopyTransposeTiling &tiling)
{
std::vector<int64_t> dstShapeInfo = dstShape.GetDims();
std::vector<int64_t> srcShapeInfo = srcShape.GetDims();
tiling.set_dstShapeB(dstShapeInfo[0]);
tiling.set_dstShapeN(dstShapeInfo[1]);
tiling.set_dstShapeS(dstShapeInfo[2]);
tiling.set_dstShapeH(dstShapeInfo[3]);
tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());
tiling.set_srcShapeB(srcShapeInfo[0]);
tiling.set_srcShapeN(srcShapeInfo[1]);
tiling.set_srcShapeS(srcShapeInfo[2]);
tiling.set_srcShapeHN(srcShapeInfo[3]);
tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH());
}
} // namespace optiling

View File

@@ -0,0 +1,43 @@
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file data_copy_transpose_tiling_def.h
* \brief
*/
#pragma once
#include <cstdint>
#include <register/tilingdata_base.h>
namespace optiling {
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling);
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
} // namespace optiling

View File

@@ -0,0 +1,225 @@
/**
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_base.h
* \brief
*/
#pragma once
#include <sstream>
#include <exe_graph/runtime/tiling_context.h>
#include <graph/utils/type_utils.h>
#include <tiling/platform/platform_ascendc.h>
#include "log/ops_log.h"
#ifdef ASCENDC_OP_TEST
#define ASCENDC_EXTERN_C extern "C"
#else
#define ASCENDC_EXTERN_C
#endif
namespace optiling {
struct AiCoreParams {
uint64_t ubSize;
uint64_t blockDim;
uint64_t aicNum;
uint64_t l1Size;
uint64_t l0aSize;
uint64_t l0bSize;
uint64_t l0cSize;
};
struct FlashAttentionScoreGradCompileInfo {
uint32_t aivNum;
uint32_t aicNum;
uint64_t ubSize;
uint64_t l1Size;
uint64_t l0aSize;
uint64_t l0bSize;
uint64_t l0cSize;
uint64_t l2CacheSize;
int64_t coreNum;
};
class TilingBaseClass {
public:
TilingBaseClass() = default;
explicit TilingBaseClass(gert::TilingContext *context) : context_(context)
{
}
virtual ~TilingBaseClass() = default;
// Tiling执行框架
// 1、GRAPH_SUCCESS: 成功并且不需要继续执行后续Tiling类的实现
// 2、GRAPH_FAILED: 失败中止整个Tiling流程
// 3、GRAPH_PARAM_INVALID: 本类不支持需要继续往下执行其他Tiling类的实现
ge::graphStatus DoTiling()
{
auto ret = GetShapeAttrsInfo();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = GetPlatformInfo();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
if (!IsCapable()) {
return ge::GRAPH_PARAM_INVALID;
}
ret = DoOpTiling();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = DoLibApiTiling();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = GetWorkspaceSize();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = PostTiling();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
context_->SetTilingKey(GetTilingKey());
DumpTilingInfo();
return ge::GRAPH_SUCCESS;
}
// 更新 context
virtual void Reset(gert::TilingContext *context)
{
context_ = context;
}
protected:
virtual bool IsCapable() = 0;
// 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小
virtual ge::graphStatus GetPlatformInfo() = 0;
// 2、获取INPUT/OUTPUT/ATTR信息
virtual ge::graphStatus GetShapeAttrsInfo() = 0;
// 3、计算数据切分TilingData
virtual ge::graphStatus DoOpTiling() = 0;
// 4、计算高阶API的TilingData
virtual ge::graphStatus DoLibApiTiling() = 0;
// 5、计算TilingKey
[[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
// 6、计算Workspace 大小
virtual ge::graphStatus GetWorkspaceSize() = 0;
// 7、保存Tiling数据
virtual ge::graphStatus PostTiling() = 0;
// 8、Dump Tiling数据
virtual void DumpTilingInfo()
{
int32_t enable = AlogCheckDebugLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
if (enable != 1) {
return;
}
auto buf = (uint32_t *)context_->GetRawTilingData()->GetData();
auto bufLen = context_->GetRawTilingData()->GetDataSize();
std::ostringstream oss;
oss << "Start to dump tiling info. tilingkey:" << GetTilingKey() << ", tiling data size:" << bufLen
<< ", content:";
for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
oss << *(buf + i) << ",";
if (oss.str().length() > 640) { // Split according to 640 to avoid truncation
OPS_LOG_D(context_, "%s", oss.str().c_str());
oss.str("");
}
}
OPS_LOG_D(context_, "%s", oss.str().c_str());
}
static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
{
uint32_t ration;
if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
return sliceNum;
}
ration = aivCoreNum / aicCoreNum;
return (sliceNum + (ration - 1)) / ration;
}
template <typename T> [[nodiscard]] std::string GetShapeDebugStr(const T &shape) const
{
std::ostringstream oss;
oss << "[";
if (shape.GetDimNum() > 0) {
for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
oss << shape.GetDim(i) << ", ";
}
oss << shape.GetDim(shape.GetDimNum() - 1);
}
oss << "]";
return oss.str();
}
[[nodiscard]] std::string GetTensorDebugStr(const gert::StorageShape *shape,
const gert::CompileTimeTensorDesc *tensor)
{
if (shape == nullptr || tensor == nullptr) {
return "nil ";
}
std::ostringstream oss;
oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
oss << "(format: "
<< ge::TypeUtils::FormatToSerialString(
static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
<< "),";
oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
return oss.str();
}
[[nodiscard]] std::string GetTilingContextDebugStr()
{
std::ostringstream oss;
for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
oss << "input" << i << ": ";
oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
}
for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
oss << "output" << i << ": ";
oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
}
return oss.str();
}
[[nodiscard]] std::string GetTilingDataDebugStr() const
{
auto rawTilingData = context_->GetRawTilingData();
auto rawTilingDataSize = rawTilingData->GetDataSize();
auto data = reinterpret_cast<const int32_t *>(rawTilingData->GetData());
size_t len = rawTilingDataSize / sizeof(int32_t);
std::ostringstream oss;
for (size_t i = 0; i < len; i++) {
oss << data[i] << ", ";
}
return oss.str();
}
protected:
gert::TilingContext *context_ = nullptr;
std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr};
uint32_t blockDim_{0};
uint64_t workspaceSize_{0};
uint64_t tilingKey_{0};
AiCoreParams aicoreParams_{0, 0, 0, 0, 0, 0, 0};
};
} // namespace optiling

View File

@@ -0,0 +1,162 @@
/**
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_templates_registry.h
* \brief
*/
#pragma once
#include <map>
#include <string>
#include <memory>
#include <exe_graph/runtime/tiling_context.h>
#include "tiling/tiling_base.h"
#include "log/ops_log.h"
#include "error/ops_error.h"
namespace optiling {
template <typename T> std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext *context)
{
return std::unique_ptr<T>(new (std::nothrow) T(context));
}
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext *);
class TilingCases {
public:
explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
{
}
template <typename T> void AddTiling(int32_t priority)
{
OPS_ERR_IF(cases_.find(priority) != cases_.end(),
OPS_REPORT_VECTOR_INNER_ERR(op_type_, "There are duplicate registrations."), return);
cases_[priority] = TILING_CLASS<T>;
OPS_ERR_IF(
cases_[priority] == nullptr,
OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling func failed, please check the class name."),
return);
}
const std::map<int32_t, TilingClassCase> &GetTilingCases()
{
return cases_;
}
private:
std::map<int32_t, TilingClassCase> cases_;
const std::string op_type_;
};
class TilingRegistry {
public:
TilingRegistry() = default;
#ifdef ASCENDC_OP_TEST
static TilingRegistry &GetInstance();
#else
static TilingRegistry &GetInstance()
{
static TilingRegistry registry_impl_;
return registry_impl_;
}
#endif
std::shared_ptr<TilingCases> RegisterOp(const std::string &op_type)
{
if (registry_map_.find(op_type) == registry_map_.end()) {
registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
}
OPS_ERR_IF(registry_map_[op_type] == nullptr,
OPS_REPORT_VECTOR_INNER_ERR(op_type, "Register tiling func failed, please check the class name."),
return nullptr);
return registry_map_[op_type];
}
ge::graphStatus DoTilingImpl(gert::TilingContext *context)
{
const char *op_type = context->GetNodeType();
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
auto tilingTemplate = it->second(context);
if (tilingTemplate != nullptr) {
ge::graphStatus status = tilingTemplate->DoTiling();
if (status != ge::GRAPH_PARAM_INVALID) {
OPS_LOG_D(context, "Do general op tiling success priority=%d", it->first);
return status;
}
OPS_LOG_D(context, "Ignore general op tiling priority=%d", it->first);
}
}
OPS_REPORT_VECTOR_INNER_ERR(op_type, "Do op tiling failed, no valid template is found.");
return ge::GRAPH_FAILED;
}
ge::graphStatus DoTilingImpl(gert::TilingContext *context, const std::vector<int32_t> &priorities)
{
const char *op_type = context->GetNodeType();
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
for (auto priorityId : priorities) {
auto templateFunc = tilingTemplateRegistryMap[priorityId](context);
if (templateFunc != nullptr) {
ge::graphStatus status = templateFunc->DoTiling();
if (status == ge::GRAPH_SUCCESS) {
OPS_LOG_D(context, "Do general op tiling success priority=%d", priorityId);
return status;
}
OPS_LOG_D(context, "Ignore general op tiling priority=%d", priorityId);
}
}
return ge::GRAPH_FAILED;
}
const std::map<int32_t, TilingClassCase> &GetTilingTemplates(const std::string &op_type)
{
OPS_ERR_IF(registry_map_.find(op_type) == registry_map_.end(),
OPS_REPORT_VECTOR_INNER_ERR(op_type, "Get op tiling func failed, please check the op name."),
return empty_tiling_case_);
return registry_map_[op_type]->GetTilingCases();
}
private:
std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
const std::map<int32_t, TilingClassCase> empty_tiling_case_ {};
};
class Register {
public:
explicit Register(std::string op_type) : op_type_(std::move(op_type))
{
}
template <typename T> Register &tiling(int32_t priority)
{
auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
OPS_ERR_IF(tilingCases == nullptr,
OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling failed, please the op name."),
return *this);
tilingCases->AddTiling<T>(priority);
return *this;
}
private:
const std::string op_type_;
};
// op_type: 算子名称, class_name: 注册的 tiling 类,
// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \
static Register VAR_UNUSED##op_type_##class_name##priority_register = Register(op_type).tiling<class_name>(priority)
} // namespace optiling

View File

@@ -0,0 +1,136 @@
/**
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_type.h
* \brief
*/
#pragma once
#include <cstdint>
namespace optiling {
enum class AxisEnum {
B = 0,
N2 = 1,
G = 2,
S1 = 3,
S2 = 4,
D = 5,
NONE = 9,
};
enum class DtypeEnum {
FLOAT16 = 0,
FLOAT32 = 1,
BFLOAT16 = 2,
FLOAT16_PRECISION = 3,
};
enum class PerformanceOrientedEnum {
BIG_BUFFER = 1,
BIG_DOUBLE_BUFFER = 2,
};
enum class MatmulConfig {
NULL_CONFIG = 0,
NORMAL_CONFIG = 1,
MDL_CONFIG = 2
};
enum class PseConfig {
NO_PSE = 0,
EXIST_PSE = 1
};
enum class AttenMaskConfig {
NO_ATTEN_MASK = 0,
EXIST_ATTEN_MASK = 1
};
enum class DropOutConfig {
NO_DROP_OUT = 0,
EXIST_DROP_OUT = 1
};
enum class CubeFormatEnum {
ND = 0,
NZ = 1
};
enum class LayoutEnum {
BSND = 0,
SBND = 1,
BNSD = 2,
TND = 3
};
enum class CubeInputSourceEnum {
GM = 0,
L1 = 1
};
enum class OptionEnum {
DISABLE = 0,
ENABLE = 1
};
enum class SparseEnum {
ALL = 0,
NONE = 1,
ANY = 2,
CAUSAL = 3,
BAND = 4,
PREFIX = 5,
BAND_COMPRESS = 6,
RIGHT_DOWN_CAUSAL = 7,
RIGHT_DOWN_CAUSAL_BAND = 8,
BAND_LEFT_UP_CAUSAL = 9
};
constexpr uint64_t RecursiveSum()
{
return 0;
}
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
{
return static_cast<uint64_t>(templateId) + 10 * RecursiveSum(templateIds...);
}
// TilingKey 的生成规则:
// FlashAttentionScore/FlashAttentionScoreGrad 十进制位组装tiling key包含以下关键参数从低位到高位依次是Ub0, Ub1,
// Block, DataType, Format, Sparse, 特化模板 Ub0、Ub1:
// 表示Ub核内切分的轴使用枚举AxisEnum表示因为我们允许最多切分两根轴所以存在UB0和UB1如果没有UB核内切分
// 那么填AXIS_NONE。UB0和UB1各占一个十进制位;
// Block: 表示UB用来分核的轴使用枚举AxisEnum表示占一个十进制位;
// DataType: 表示当前tiling key支持的输入输出的数据类型使用枚举SupportedDtype来表示占一个十进制位
// Format: 表示当前tiling key支持的Format, 使用枚举InputLayout表示占一个十进制位
// Sparse: 表示当前tiling key是否支持Sparse使用枚举SparseCapability表示占一个十进制位
// 其余特化场景,定义自己的位域和值
// usage: get tilingKey from inputed types
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
return TILINGKEYOFFSET + RecursiveSum(templateIds...);
}
// usage: get tilingKey from inputed types
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
(GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
SparseEnum::sparse))
} // namespace optiling