xc-llm-ascend/tests/e2e/nightly/ops/test_grouped_matmul_swiglu_quant.py
Slightwind 9fdabb7b60 [feature] Add Custom Op grouped_matmul_swiglu_quant (#4431)
This PR introduces the `EXEC_NPU_CMD` macro, serving as an adapter layer
to simplify the invocation of `aclnn` operators on Ascend NPUs.

**Key Changes:**
* **Adapter Layer:** Added the `EXEC_NPU_CMD` macro and its supporting
dependencies to standardize `aclnn` calls.
* **Operator Support:** Integrated `grouped_matmul_swiglu_quant` as a
reference implementation demonstrating how the new macro is used (see the
sketch below).
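
For reference, the Python-side entry point exercised by the nightly e2e test below looks roughly like this (variable names, shapes, and dtypes are taken from that test; the op is registered under `torch.ops._C_ascend` and dispatched to the `aclnn` kernel through the new adapter):

```python
# Sketch of the Python-side call, mirroring the e2e test below.
output, output_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant(
    x=x,                               # (M, K) int8 activations on the NPU
    weight=pack_weight,                # (E, K, N) weights packed as quint4x2
    bias=bias,                         # (E, N) float32 per-expert bias
    group_list=group_list,             # cumulative token counts per expert
    weight_scale=scale_uint64_tensor,  # per-channel scales as int64 bit patterns
    x_scale=pertoken_scale)            # (M,) float32 per-token scales
```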

---


- vLLM version: v0.11.2

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
2025-11-27 21:56:18 +08:00


import gc

import numpy as np
import torch
import torch_npu

from vllm_ascend.utils import enable_custom_op

enable_custom_op()


def x_int8_to_x_int4(x: torch.Tensor):
    m, k = x.shape
    x_high_4bit = torch.floor(x.to(torch.float16) // 16).to(torch.int8)
    x_low_4bit = (
        torch.bitwise_and(x.view(torch.int16), 0x0f0f).view(torch.int8) - 8)
    x_int4 = torch.empty((2 * m, k), dtype=torch.int8)
    x_int4[::2, :] = x_high_4bit
    x_int4[1::2, :] = x_low_4bit
    return x_int4
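
# Illustrative example of the nibble split above: for the int8 value 35
# (0b0010_0011) the "high" row gets floor(35 / 16) = 2 and the "low" row gets
# (35 & 0x0F) - 8 = -5, i.e. every int8 row of `x` expands into two
# interleaved int4-valued rows (even rows hold high nibbles, odd rows low
# nibbles).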


def custom_mm(x: torch.Tensor, weight: torch.Tensor,
              weight_scale: torch.Tensor, m: int):
    """
    Perform a quantized GMM (general matrix multiplication) operation.

    Parameters:
        x (torch.Tensor): Input tensor with shape (m, k).
        weight (torch.Tensor): Weight tensor with shape (k, n).
        weight_scale (torch.Tensor): Scaling factor for each channel.
            - perGroup scenario: shape (k_group_num, n). Note: when
              k_group_num == 1, this degenerates to the perChannel scenario.
            - perChannel scenario: shape (n, ).
        m (int): Number of tokens (number of rows in x).

    Returns:
        mm_out (torch.Tensor, fp32): Result of the matmul plus perGroup or
            perChannel dequantization.
    """
    # Perform the matrix multiplication with int32 precision
    k, n = weight.shape
    mm_out = torch.zeros((m, n), dtype=torch.float16)
    # perGroup scenario
    if len(weight_scale.shape) == 2 and weight_scale.shape[0] != 1:
        k_group = weight_scale.shape[0]
        per_group_ele = k // k_group
        x_grouped = x.view(-1, k_group, per_group_ele).transpose(0, 1)
        weight_grouped = weight.view(k_group, per_group_ele, n)
        c_temp = torch.bmm(x_grouped.to(torch.int32),
                           weight_grouped.to(torch.int32)).to(torch.float16)
        for k_idx in range(k_group):
            mm_out += (c_temp[k_idx] *
                       weight_scale[k_idx].view(1, -1).to(torch.float16)).to(
                           torch.float16)
    # perChannel scenario
    elif len(weight_scale.shape) == 1 or (len(weight_scale.shape) == 2
                                          and weight_scale.shape[0] == 1):
        c_temp = torch.matmul(x.to(torch.int32),
                              weight.to(torch.int32)).to(torch.float32)
        mm_out = c_temp * weight_scale.view(1, -1).to(torch.float16)
    return mm_out.to(torch.float32)
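
# In the perChannel branch (the one exercised by the test below, since each
# weight_scale[i] has shape (1, N)), the reference computes
#     mm_out[m, n] = (sum_k x[m, k] * weight[k, n]) * weight_scale[n]
# i.e. an int32 matmul followed by per-output-channel dequantization.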


def gmm_swiglu_quant_golden_a8_w4(x: torch.Tensor, weight: torch.Tensor,
                                  weight_scale: torch.Tensor,
                                  per_token_scale: torch.Tensor,
                                  bias: torch.Tensor,
                                  group_list: torch.Tensor):
    """
    Golden (CPU) reference: process the input group by group and compute the
    GMM + SwiGLU + quantization result.

    Parameters:
        x (torch.Tensor): Input tensor with shape (M, K), dtype INT8.
        weight (torch.Tensor): Weight tensor with shape (E, K, N), dtype INT8
            but restricted to the INT4 value range.
        weight_scale (torch.Tensor): Scaling factor for each channel.
            - perGroup scenario: shape (E, k_group_num, N).
            - perChannel scenario: shape (E, N).
        per_token_scale (torch.Tensor): Scaling factor for each token,
            shape (M, ).
        bias (torch.Tensor): Per-expert bias with shape (E, N).
        group_list (torch.Tensor): Cumulative (prefix-sum) token counts per
            group.

    Returns:
        quant_output (torch.Tensor): Quantized output tensor with shape
            (M, N // 2).
        quant_scale_output (torch.Tensor): Quantization scaling factor,
            shape (M, ).
    """
    M, N = x.shape[0], weight.shape[2]
    quant_output = torch.zeros(M, N // 2).to(torch.int8)
    quant_scale_output = torch.zeros(M).to(torch.float32)
    # Preprocessing: X_INT8 -> X_INT4
    x_int4 = x_int8_to_x_int4(x)
    start_idx = 0
    # Cumulative token count of the previous group
    pre_v = 0
    group_list = group_list.tolist()
    # Traverse group_list and process the data group by group
    for i, v in enumerate(group_list):
        curr_v = v
        # Number of int4 rows in the current group ("* 2" because one int8
        # row expands into two int4 rows)
        temp_v = int((curr_v - pre_v) * 2)
        # Update the cumulative token count of the previous group
        pre_v = curr_v
        if temp_v > 0:
            mm_out = custom_mm(x_int4[int(start_idx):int(start_idx + temp_v)],
                               weight[i], weight_scale[i], temp_v)
            mm_num_concat = ((mm_out[::2] * 16 + mm_out[1::2]) +
                             bias[i].view(1, -1))
            per_token_quant = mm_num_concat * per_token_scale[start_idx // 2:(
                start_idx + temp_v) // 2].view(-1, 1)
            swiglu, gate = per_token_quant.chunk(2, dim=-1)
            temp = swiglu * torch.sigmoid(swiglu)
            temp = temp * gate
            max_value = torch.max(torch.abs(temp), dim=-1).values
            quant_scale_output_temp = 127 / max_value
            quant_output[start_idx // 2:(start_idx + temp_v) //
                         2] = torch.round(temp *
                                          quant_scale_output_temp.reshape(
                                              temp_v // 2, 1)).to(torch.int8)
            quant_scale_output[start_idx // 2:(start_idx + temp_v) //
                               2] = 1 / quant_scale_output_temp
            start_idx += temp_v
    return quant_output, quant_scale_output
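
# Summary of the golden reference above: per expert group, dequantized matmul
# (custom_mm), bias add, per-token scaling, SwiGLU (s * sigmoid(s) * gate),
# then symmetric per-row int8 quantization with scale 127 / max(|row|); the
# reciprocal of that scale is returned as quant_scale_output.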


def generate_non_decreasing_sequence(length, upper_limit):
    # Generate a random non-decreasing sequence of cumulative sums
    random_increments = torch.randint(0, 128, (length, ))
    sequence = torch.cumsum(random_increments, dim=0)
    # Rescale so the last value does not exceed the upper limit
    if sequence[-1] >= upper_limit:
        scale_factor = upper_limit / sequence[-1]
        sequence = (sequence * scale_factor).to(torch.int64)
    return sequence
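
# The returned prefix sums are consumed as `group_list` in the test below:
# each entry is the cumulative token count up to and including that expert,
# which is why the golden reference takes differences between consecutive
# entries.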


@torch.inference_mode()
def test_grouped_matmul_swiglu_quant_kernel():
    E = 16
    M = 512
    K = 7168
    N = 4096
    torch.npu.config.allow_internal_format = True
    x = torch.randint(-5, 5, (M, K), dtype=torch.int8).npu()
    weight_ori = torch.randint(-5, 5, (E, K, N), dtype=torch.int8)
    weight_nz = torch_npu.npu_format_cast(weight_ori.npu().to(torch.float32),
                                          29)
    pack_weight = torch_npu.npu_quantize(weight_nz,
                                         torch.tensor([1.], device='npu'),
                                         None, torch.quint4x2, -1, False)
    weight_scale = torch.randn(E, 1, N)
    scale_np = weight_scale.cpu().numpy()
    scale_np.dtype = np.uint32
    scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()
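    # Note: the `scale_np.dtype = np.uint32` assignment above reinterprets the
    # fp32 scale buffer as raw uint32 bit patterns (a bitcast, not a value
    # conversion) before widening to int64; the custom op presumably consumes
    # the scales as bit patterns rather than as float values.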
    pertoken_scale = torch.randn(M).to(torch.float32).npu()
    group_list = generate_non_decreasing_sequence(E, M).npu()
    bias = torch.zeros((E, N), dtype=torch.float32,
                       device="npu").uniform_(-5, 5)
    output_golden, output_scale_golden = gmm_swiglu_quant_golden_a8_w4(
        x.cpu(), weight_ori, weight_scale, pertoken_scale.cpu(), bias.cpu(),
        group_list.cpu())
    output, output_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant(
        x=x,
        weight=pack_weight,
        bias=bias,
        group_list=group_list,
        weight_scale=scale_uint64_tensor,
        x_scale=pertoken_scale)
    torch.testing.assert_close(output_golden, output.cpu(), atol=1, rtol=0.005)
    torch.testing.assert_close(output_scale_golden,
                               output_scale.cpu(),
                               atol=1,
                               rtol=0.005)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()