[Kernel] add custom op GmmSwigluQuantWeightNzTensorList (#3804)

### What this PR does / why we need it?

This PR introduces support for adding custom CANN `aclnn` ops to
`vllm-ascend`, so users can define and invoke their own operators.

Key changes include:
- Building and installing custom ops into the `vllm-ascend`-specified
directory
- Binding the `aclnn` op interface to the `torch.ops._C_ascend` module
- Enabling invocation of these ops within `vllm-ascend` (see the sketch below)
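
A minimal sketch of that flow, assuming only what the test file below exercises (the module path and op name come from this PR; nothing else is implied):

```python
import torch
import torch_npu

from vllm_ascend.utils import enable_custom_op

# Loads the vllm-ascend custom-op library, which registers the aclnn
# bindings under the torch.ops._C_ascend namespace.
enable_custom_op()

# The new op is then reachable like any other torch.ops entry:
op = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list
```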

This PR includes a sample custom op:
`aclnnGroupedMatmulSwigluQuantWeightNzTensorList`, which is adapted from
the CANN operator
[`aclnnGroupedMatmulSwigluQuantWeightNZ`](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/API/aolapi/context/aclnnGroupedMatmulSwigluQuantWeightNZ.md).
Its input parameters `weight` and `weight_scale` now accept
`list[torch.Tensor]` (i.e., `at::TensorList`).
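
Continuing the sketch above, the call shape (taken from the test added in this PR; a sketch, not the authoritative signature) looks like:

```python
# weight_list / weight_scale_list: Python lists of per-expert NPU tensors
# (at::TensorList on the C++ side); weights cast to FRACTAL_NZ beforehand.
output, output_scale, _ = \
    torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
        x,                  # (M, K) int8 activations
        weight_list,        # E tensors, each (K, N) int8, NZ format
        weight_scale_list,  # E tensors, each (N,) float32
        x_scale,            # (M,) float32 per-token scales
        group_list)         # (E,) int64 cumulative token counts
```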

### Does this PR introduce _any_ user-facing change?

No.


- vLLM version: v0.11.2

---------

Signed-off-by: QianChenxi <chenxi.qian.cq@outlook.com>
Commit 554f16ae1f by Chenxi Qian, 2025-11-28 18:06:39 +08:00, committed by GitHub (parent 3199fe8350).
50 changed files with 6934 additions and 7 deletions


@@ -0,0 +1,148 @@
import gc
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
# enable internal format
torch_npu.npu.config.allow_internal_format = True
# enable vllm-ascend custom ops
enable_custom_op()


def gmm_swiglu_quant(x: torch.Tensor, weight: torch.Tensor,
perChannelScale: torch.Tensor,
perTokenScale: torch.Tensor, m: int):
"""
    CPU reference for one group: quantized matmul followed by SwiGLU activation and int8 re-quantization.
Parameters:
x (torch.Tensor): Input tensor with shape (m, k).
weight (torch.Tensor): Weight tensor with shape (k, n).
perChannelScale (torch.Tensor): Per-channel scaling factor with shape (n,).
perTokenScale (torch.Tensor): Per-token scaling factor with shape (m,).
m (int): Number of tokens (rows of x).
Returns:
        quantOutput (torch.Tensor): Quantized output tensor with shape (m, n // 2).
quantScaleOutput (torch.Tensor): Quantization scaling factor with shape (m,).
"""
# Perform matrix multiplication with int32 precision
c_temp1 = torch.matmul(x.to(torch.int32), weight.to(torch.int32))
c_temp1 = c_temp1.to(torch.float32) # Convert back to float32 for scaling
# Apply per-channel and per-token scaling
c_temp2 = torch.mul(c_temp1, perChannelScale)
c_temp3 = torch.mul(c_temp2, perTokenScale.reshape(m, 1))
# Split the result into two parts to apply SwiGLU activation function
c_temp4, gate = c_temp3.chunk(2, dim=-1)
c_temp5 = c_temp4 * torch.sigmoid(c_temp4) # SwiGLU activation
c_temp6 = c_temp5 * gate # Element-wise multiplication with gating values
# Quantize the output
    # Per-row maximum absolute value, used to derive the quantization scale
    max_abs = torch.max(torch.abs(c_temp6), dim=-1).values
    quantScaleOutput = 127 / max_abs  # Quantization scale factor
quantOutput = torch.round(c_temp6 * quantScaleOutput.reshape(m, 1)).to(
torch.int8) # Quantize to int8
quantScaleOutput = 1 / quantScaleOutput # Inverse quantization scaling factor for subsequent dequantization
    return quantOutput, quantScaleOutput


def process_groups(x: torch.Tensor, weight: torch.Tensor,
perChannelScale: torch.Tensor, perTokenScale: torch.Tensor,
groupList: torch.Tensor):
"""
Process input data by groups and call GMM_Swiglu_quant function for quantized computation.
Parameters:
x (torch.Tensor): Input tensor with shape (M, K).
weight (torch.Tensor): List of weight tensors, each with shape (E, K, N).
perChannelScale (torch.Tensor): List of per-channel scaling factors, each with shape (E, N).
perTokenScale (torch.Tensor): Per-token scaling factor with shape (M,).
groupList (list): List defining the number of tokens in each group.
Returns:
quantOutput (torch.Tensor): Quantized output tensor with shape (M, N // 2).
quantScaleOutput (torch.Tensor): Quantization scaling factor with shape (M,).
"""
M, N = x.shape[0], weight.shape[2] # Get the shape of the input tensor
quantOutput = torch.zeros(M, N // 2).to(
torch.int8) # Initialize quantized output tensor
quantScaleOutput = torch.zeros(M).to(
torch.float32) # Initialize quantization scaling factor tensor
start_idx = 0 # Starting index
preV = 0 # Number of tokens in the previous group
groupList = groupList.tolist()
# Iterate through groupList to process data by groups
for i, v in enumerate(groupList):
currV = v
tempV = currV - preV # Calculate number of tokens in the current group
preV = currV # Update number of tokens in the previous group
if tempV > 0:
# Call GMM_Swiglu_quant to process the current group
quantOutput[start_idx:start_idx + tempV], quantScaleOutput[start_idx:start_idx + tempV] = \
gmm_swiglu_quant(x[start_idx:start_idx + tempV],
weight[i],
perChannelScale[i],
perTokenScale[start_idx:start_idx + tempV],
tempV)
start_idx += tempV # Update starting index to process the next group
    return quantOutput, quantScaleOutput


@torch.inference_mode()
def test_gmm_swiglu_quant_weight_nz_tensor_list():
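    """Compare the NPU custom op against the CPU reference implementation above."""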
M, K, E, N = 8192, 7168, 4, 4096
# x (M, K) - int8
x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
    # weight (E, K, N) - int8
weight = torch.randint(-128, 127, size=(E, K, N), dtype=torch.int8)
# weight_scale (E, N) - float32
weight_scale = torch.rand(E, N) * 0.9 + 0.1 # uniform(0.1, 1.0)
weight_scale = weight_scale.to(torch.float32)
weight_nz_npu = []
weight_scale_npu = []
    for i in range(E):
        # 29 = ACL_FORMAT_FRACTAL_NZ: cast each expert's weight to the NZ layout
        weight_nz_npu.append(torch_npu.npu_format_cast(weight[i].npu(), 29))
        weight_scale_npu.append(weight_scale[i].npu())
# x_scale (M,) - float32
x_scale = torch.rand(M) * 0.9 + 0.1 # uniform(0.1, 1.0)
x_scale = x_scale.to(torch.float32)
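    # Cumulative (prefix-sum) token counts: 2048 tokens routed to each of the 4 experts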
group_list = torch.tensor([2048, 4096, 6144, 8192], dtype=torch.int64)
output_cpu, output_scale_cpu = process_groups(x, weight, weight_scale,
x_scale, group_list)
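    # The custom op also returns a third tensor, which this test does not check.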
output_npu, output_scale_npu, _ = \
torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(x.npu(),
weight_nz_npu,
weight_scale_npu,
x_scale.npu(),
group_list.npu())
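    # Compare only the rows covered by group_list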
output_npu_valid = output_npu[:group_list[-1], :]
output_scale_npu_valid = output_scale_npu[:group_list[-1]]
torch.testing.assert_close(output_npu_valid.cpu(),
output_cpu,
atol=1,
rtol=2**-13)
torch.testing.assert_close(output_scale_npu_valid.cpu(),
output_scale_cpu,
atol=1e-9,
rtol=1e-6)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()