[Triton]support swiglu_quant triton in w4a8 (#5161)

### What this PR does / why we need it?
support swiglu_quant triton in w4a8
### Does this PR introduce _any_ user-facing change?
No

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: GDzhu01 <809721801@qq.com>
This commit is contained in:
Zhu Yi Lin
2025-12-22 16:01:58 +08:00
committed by GitHub
parent 60d9398f6d
commit 12d581605b
3 changed files with 132 additions and 3 deletions

View File

@@ -20,6 +20,7 @@ import torch
import torch_npu
from torch.nn.functional import pad
from vllm.forward_context import get_forward_context
from vllm.triton_utils import HAS_TRITON
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
@@ -243,9 +244,17 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states)
if HAS_TRITON:
from vllm_ascend.ops.triton.activation.swiglu_quant import \
swiglu_quant
hidden_states, swiglu_out_scale = swiglu_quant(
hidden_states,
group_list=group_list,
group_list_type=group_list_type)
else:
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],

View File

@@ -0,0 +1,120 @@
import torch
from vllm.triton_utils import HAS_TRITON, tl, triton
if HAS_TRITON:
import torch_npu._inductor # noqa: F401
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@triton.jit
def _swiglu_quant_kernel(
x_ptr,
group_list_ptr,
out_ptr,
scale_ptr,
TOTAL_COLS: tl.constexpr,
HALF_COLS: tl.constexpr,
COL_BLOCK_SIZE: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
NUM_EXPERTS_ALGIN: tl.constexpr,
GROUP_LIST_TYPE: tl.constexpr,
NUM_CORES: tl.constexpr,
DTYPE_MAX: tl.constexpr,
SCALE: tl.constexpr,
):
# calc real total_rows
if GROUP_LIST_TYPE == 0: # cusum
total_rows = tl.load(group_list_ptr + NUM_EXPERTS).to(tl.int32)
else:
gl_offsets = tl.arange(0, NUM_EXPERTS_ALGIN)
gl_mask = gl_offsets < NUM_EXPERTS
group_list = tl.load(group_list_ptr + gl_offsets, gl_mask,
other=0).to(tl.int32)
total_rows = tl.sum(group_list)
block_size = (total_rows - 1) // NUM_CORES + 1
pid = tl.program_id(0)
row_begin = pid * block_size
if row_begin >= total_rows:
return
row_end = tl.minimum((pid + 1) * block_size, total_rows)
for row_idx in range(row_begin, row_end):
# swiglu
x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
cur_x = tl.load(x_ptr + x_offsets)
x1 = tl.extract_slice(cur_x,
offsets=(0, ),
sizes=(HALF_COLS, ),
strides=(1, ))
x2 = tl.extract_slice(cur_x,
offsets=(HALF_COLS, ),
sizes=(HALF_COLS, ),
strides=(1, ))
out = x1 * tl.sigmoid(x1) * x2
# quant
if SCALE:
scale = tl.max(tl.abs(out)).to(tl.float32) / DTYPE_MAX
# store scale
tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
tmp_out = tl.extract_slice(out,
offsets=(col_blk_idx, ),
sizes=(COL_BLOCK_SIZE, ),
strides=(1, ))
tmp_out = (tmp_out.to(tl.float32) / scale).to(
x_ptr.dtype.element_ty)
tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")
o_offsets = (row_idx * HALF_COLS + col_blk_idx +
tl.arange(0, COL_BLOCK_SIZE))
mask = (col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)) < HALF_COLS
tl.store(out_ptr + o_offsets,
tmp_out.to(out_ptr.dtype.element_ty),
mask=mask)
else:
# store out
o_offsets = row_idx * HALF_COLS + tl.arange(0, HALF_COLS)
tl.store(out_ptr + o_offsets, out.to(out_ptr.dtype.element_ty))
def swiglu_quant(x, group_list, group_list_type, need_quant=True):
# group_list_type must be 0 cusum or 1 count
if group_list_type not in [0, 1]:
raise ValueError(
f"group_list_type must be 0 or 1, but got {group_list_type}")
s, h = x.shape
out_dtype = torch.int8 if need_quant else x.dtype
out = torch.empty((s, h // 2), dtype=out_dtype, device=x.device)
scale = torch.empty((s, ), dtype=torch.float32, device=x.device)
num_experts = group_list.shape[0]
# ub must be 32-byte aligned on npu
if group_list.dtype == torch.int64:
num_experts_algin = (num_experts + 7) // 8 * 8
elif group_list.dtype == torch.int32:
num_experts_algin = (num_experts + 15) // 16 * 16
else:
raise ValueError(
f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}"
)
num_vectorcore = get_vectorcore_num()
_swiglu_quant_kernel[(num_vectorcore, )](
x,
group_list,
out,
scale,
TOTAL_COLS=h,
HALF_COLS=h // 2,
COL_BLOCK_SIZE=1536,
NUM_EXPERTS=num_experts,
NUM_EXPERTS_ALGIN=num_experts_algin,
GROUP_LIST_TYPE=group_list_type,
NUM_CORES=num_vectorcore,
DTYPE_MAX=127,
SCALE=need_quant,
multibuffer=True,
)
return out, scale