[feature] Add Custom Op grouped_matmul_swiglu_quant (#4431)
This PR introduces the `EXEC_NPU_CMD` macro, serving as an adapter layer to simplify the invocation of `aclnn` operators on Ascend NPUs. **Key Changes:** * **Adapter Layer:** Added `EXEC_NPU_CMD` macro and related dependencies to standardize `aclnn` calls. * **Operator Support:** Integrated `grouped_matmul_swiglu_quant` as a reference implementation to demonstrate the usage of the new macro. --- - vLLM version: v0.11.2 --------- Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
@@ -27,12 +27,14 @@
|
||||
#include "ops.h"
|
||||
#include "utils.h"
|
||||
#include "mla_preprocess/op_host/mla_preprocess.h"
|
||||
#include "aclnn_torch_adapter/op_api_common.h"
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
namespace vllm_ascend {
|
||||
const int64_t INT4_NUMS_IN_INT32 = 8;
|
||||
void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst,
|
||||
const torch::Tensor& block_mapping, aclrtStream stream) {
|
||||
torch::Device src_device = src.device();
|
||||
@@ -520,6 +522,36 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
|
||||
cmd.Run();
|
||||
return y_out;
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
|
||||
const at::Tensor &x, const at::Tensor &weight, const at::Tensor &weight_scale, const at::Tensor &x_scale,
|
||||
const at::Tensor &group_list, const c10::optional<at::Tensor> &bias, const c10::optional<at::Tensor> &offset)
|
||||
{
|
||||
int m = x.sizes()[0];
|
||||
int n = weight.sizes()[2];
|
||||
bool is_a8w4 = x.dtype() == at::kChar && weight.dtype() == at::kInt;
|
||||
if (is_a8w4) {
|
||||
n *= INT4_NUMS_IN_INT32;
|
||||
}
|
||||
|
||||
at::Tensor output = at::empty({m, n/2}, x.options().dtype(c10::ScalarType::Char));
|
||||
at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float));
|
||||
at::Tensor output_offset = at::empty({}, x.options().dtype(c10::ScalarType::Float));
|
||||
|
||||
EXEC_NPU_CMD(
|
||||
aclnnGroupedMatmulSwigluQuantWeightNZ,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
offset,
|
||||
weight_scale,
|
||||
x_scale,
|
||||
group_list,
|
||||
output,
|
||||
output_scale,
|
||||
output_offset);
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
|
||||
}
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -576,4 +608,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
|
||||
ops.def("swap_blocks(Tensor! x, Tensor! y, Tensor z) -> ()");
|
||||
ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks);
|
||||
|
||||
ops.def(
|
||||
"grouped_matmul_swiglu_quant(Tensor x, Tensor weight, Tensor weight_scale, Tensor x_scale,"
|
||||
" Tensor group_list, *, Tensor? bias=None,"
|
||||
" Tensor? offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)");
|
||||
ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user