[Kernel] add custom op GmmSwigluQuantWeightNzTensorList (#3804)

### What this PR does / why we need it?

This PR introduces support for adding custom CANN `aclnn` ops to
`vllm-ascend`, allowing users to define and use their own custom
operators.

Key changes include:
- Building and installing custom ops into the `vllm-ascend`-specified
directory
- Binding the `aclnn` op interface to the `torch.ops._C_ascend` module
- Enabling invocation of these ops within `vllm-ascend`

This PR includes a sample custom op:
`aclnnGroupedMatmulSwigluQuantWeightNzTensorList`, which is adapted from
the CANN operator
[`aclnnGroupedMatmulSwigluQuantWeightNZ`](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/API/aolapi/context/aclnnGroupedMatmulSwigluQuantWeightNZ.md).
Its input parameters `weight` and `weight_scale` now accept
`list[torch.Tensor]` (i.e., `at::TensorList`).

### Does this PR introduce _any_ user-facing change?

No.


- vLLM version: v0.11.2

---------

Signed-off-by: QianChenxi <chenxi.qian.cq@outlook.com>
This commit is contained in:
Chenxi Qian
2025-11-28 18:06:39 +08:00
committed by GitHub
parent 3199fe8350
commit 554f16ae1f
50 changed files with 6934 additions and 7 deletions

View File

@@ -552,6 +552,41 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
output_offset);
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weight_nz_tensor_list(
const at::Tensor & x,
const at::TensorList & weight,
const at::TensorList & weight_scale,
const at::Tensor & x_scale,
const at::Tensor & group_list,
const c10::optional<at::Tensor> & bias,
const c10::optional<at::Tensor> & offset)
{
auto x_size = x.sizes();
int n = weight[0].sizes()[1];
int m = x_size[0];
int k = x_size[1];
at::Tensor output = at::zeros({m, n/2}, x.options().dtype(at::kChar));
at::Tensor output_scale = at::zeros({m}, x.options().dtype(at::kFloat));
at::Tensor output_offset = at::zeros({m}, x.options().dtype(at::kFloat));
EXEC_NPU_CMD(
aclnnGroupedMatmulSwigluQuantWeightNzTensorList,
x,
weight,
bias,
offset,
weight_scale,
x_scale,
group_list,
output,
output_scale,
output_offset);
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -614,4 +649,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
" Tensor group_list, *, Tensor? bias=None,"
" Tensor? offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)");
ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant);
ops.def(
"grouped_matmul_swiglu_quant_weight_nz_tensor_list(Tensor x, Tensor[] weight, Tensor[] weight_scale, Tensor x_scale,"
" Tensor group_list, *,"
" Tensor? bias=None, Tensor? offset=None) ->"
" (Tensor output, Tensor output_scale, Tensor output_offset)"
);
ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant_weight_nz_tensor_list);
}