### What this PR does / why we need it?
This fixes a bug that occurred when running `test_camem.py` in the
triton-ascend environment: `NPU function error:
aclrtGetMemInfo(ACL_HBM_MEM, &device_free, &device_total)`.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
---------
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
118 lines
4.2 KiB
Python
118 lines
4.2 KiB
Python
import torch
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
|
|
|
|
|
@triton.jit
def _swiglu_quant_kernel(
    x_ptr,  # (rows, TOTAL_COLS) input activations
    group_list_ptr,  # per-expert row counts, or their cumulative sum
    out_ptr,  # (rows, HALF_COLS) output buffer
    scale_ptr,  # (rows,) per-row scales; written only when SCALE is true
    TOTAL_COLS: tl.constexpr,  # full width h (gate half + value half)
    HALF_COLS: tl.constexpr,  # h // 2
    COL_BLOCK_SIZE: tl.constexpr,  # column tile size for the quantized store
    NUM_EXPERTS: tl.constexpr,  # number of entries in group_list
    NUM_EXPERTS_ALGIN: tl.constexpr,  # NUM_EXPERTS padded for aligned loads
    GROUP_LIST_TYPE: tl.constexpr,  # 0 = cumulative sum, 1 = per-expert counts
    NUM_CORES: tl.constexpr,  # launch grid size; rows split across cores
    DTYPE_MAX: tl.constexpr,  # max value of the quant dtype (127 for int8)
    SCALE: tl.constexpr,  # whether to quantize and emit per-row scales
):
    """SwiGLU (x1 * sigmoid(x1) * x2) with optional per-row int8 quantization.

    The valid row count is derived from group_list at runtime, so rows past
    the last expert's tokens are never touched. Valid rows are split evenly
    across NUM_CORES programs; each program handles one contiguous slice
    [row_begin, row_end).
    """
    # calc real total_rows
    if GROUP_LIST_TYPE == 0:  # cusum
        # NOTE(review): this loads group_list_ptr[NUM_EXPERTS], i.e. one
        # element past a length-NUM_EXPERTS list — confirm the cumulative-sum
        # layout the caller passes really carries a trailing total here.
        total_rows = tl.load(group_list_ptr + NUM_EXPERTS).to(tl.int32)
    else:
        # Count layout: sum the padded per-expert counts; lanes beyond
        # NUM_EXPERTS load 0 via the mask so padding does not affect the sum.
        gl_offsets = tl.arange(0, NUM_EXPERTS_ALGIN)
        gl_mask = gl_offsets < NUM_EXPERTS
        group_list = tl.load(group_list_ptr + gl_offsets, gl_mask,
                             other=0).to(tl.int32)
        total_rows = tl.sum(group_list)

    # Even partition of rows: ceil(total_rows / NUM_CORES) rows per program.
    block_size = (total_rows - 1) // NUM_CORES + 1
    pid = tl.program_id(0)
    row_begin = pid * block_size
    if row_begin >= total_rows:
        return
    row_end = tl.minimum((pid + 1) * block_size, total_rows)

    for row_idx in range(row_begin, row_end):
        # swiglu: the first HALF_COLS columns are the gate (x1), the second
        # HALF_COLS columns are the value (x2).
        x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
        cur_x = tl.load(x_ptr + x_offsets)
        x1 = tl.extract_slice(cur_x,
                              offsets=(0, ),
                              sizes=(HALF_COLS, ),
                              strides=(1, ))
        x2 = tl.extract_slice(cur_x,
                              offsets=(HALF_COLS, ),
                              sizes=(HALF_COLS, ),
                              strides=(1, ))
        out = x1 * tl.sigmoid(x1) * x2

        # quant
        if SCALE:
            # Symmetric per-row scale: max |out| maps to DTYPE_MAX.
            # NOTE(review): an all-zero row gives scale == 0 and the division
            # below then yields NaN/inf — confirm callers cannot hit this.
            scale = tl.max(tl.abs(out)).to(tl.float32) / DTYPE_MAX
            # store scale
            tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
            # Quantize and store in COL_BLOCK_SIZE tiles; the final (partial)
            # tile is handled by the store mask.
            for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
                tmp_out = tl.extract_slice(out,
                                           offsets=(col_blk_idx, ),
                                           sizes=(COL_BLOCK_SIZE, ),
                                           strides=(1, ))
                # Divide in fp32, round-trip through the input dtype, then
                # saturating-cast to int8 so overflow clips instead of wrapping.
                tmp_out = (tmp_out.to(tl.float32) / scale).to(
                    x_ptr.dtype.element_ty)
                tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")

                o_offsets = (row_idx * HALF_COLS + col_blk_idx +
                             tl.arange(0, COL_BLOCK_SIZE))
                mask = (col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)) < HALF_COLS
                tl.store(out_ptr + o_offsets,
                         tmp_out.to(out_ptr.dtype.element_ty),
                         mask=mask)
        else:
            # store out unquantized: one full-row store, no mask needed.
            o_offsets = row_idx * HALF_COLS + tl.arange(0, HALF_COLS)
            tl.store(out_ptr + o_offsets, out.to(out_ptr.dtype.element_ty))
|
|
|
|
|
|
def swiglu_quant(x, group_list, group_list_type, need_quant=True):
    """Run the fused SwiGLU (+ optional int8 quantization) Triton kernel.

    Args:
        x: input tensor of shape (s, h); columns [0, h//2) are the gate half
            and [h//2, h) the value half of the SwiGLU input.
        group_list: 1-D tensor of per-expert row counts, or their cumulative
            sum, dtype torch.int32 or torch.int64.
        group_list_type: 0 for cumulative-sum layout, 1 for per-expert counts.
        need_quant: when True the output is int8 with per-row fp32 scales;
            otherwise the output keeps x's dtype.

    Returns:
        (out, scale): out has shape (s, h // 2); scale has shape (s,).
        When need_quant is False, scale is allocated but never written.

    Raises:
        ValueError: if group_list_type or group_list.dtype is unsupported.
    """
    if group_list_type not in (0, 1):
        raise ValueError(
            f"group_list_type must be 0 or 1, but got {group_list_type}")

    num_rows, num_cols = x.shape
    out = torch.empty((num_rows, num_cols // 2),
                      dtype=torch.int8 if need_quant else x.dtype,
                      device=x.device)
    scale = torch.empty((num_rows, ), dtype=torch.float32, device=x.device)

    num_experts = group_list.shape[0]
    # Pad the group-list length so its on-chip (UB) load stays aligned on the
    # NPU: round up to 8 int64 / 16 int32 elements. (NOTE(review): both are
    # 64-byte multiples, though the original comment said 32-byte — verify.)
    alignment = {torch.int64: 8, torch.int32: 16}.get(group_list.dtype)
    if alignment is None:
        raise ValueError(
            f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}"
        )
    num_experts_algin = (num_experts + alignment - 1) // alignment * alignment

    # One program per vector core; the kernel splits valid rows among them.
    num_vectorcore = get_vectorcore_num()
    _swiglu_quant_kernel[(num_vectorcore, )](
        x,
        group_list,
        out,
        scale,
        TOTAL_COLS=num_cols,
        HALF_COLS=num_cols // 2,
        COL_BLOCK_SIZE=1536,
        NUM_EXPERTS=num_experts,
        NUM_EXPERTS_ALGIN=num_experts_algin,
        GROUP_LIST_TYPE=group_list_type,
        NUM_CORES=num_vectorcore,
        DTYPE_MAX=127,
        SCALE=need_quant,
        multibuffer=True,
    )
    return out, scale
|