[1/n] apply wna16marlin kernel in moe weight only quantization (#7683)
Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com> Co-authored-by: yych0745 <1398089567@qq.com> Co-authored-by: HandH1998 <1335248067@qq.com> Co-authored-by: 弋云 <yiyun.wyt@antgroup.com> Co-authored-by: walker-ai <2398833647@qq.com>
This commit is contained in:
166
python/sglang/srt/layers/quantization/quant_utils.py
Normal file
166
python/sglang/srt/layers/quantization/quant_utils.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from sgl_kernel.scalar_type import ScalarType
|
||||
|
||||
|
||||
def get_pack_factor(num_bits):
|
||||
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
|
||||
return 32 // num_bits
|
||||
|
||||
|
||||
def pack_cols(
|
||||
q_w: torch.Tensor,
|
||||
num_bits: int,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
):
|
||||
assert q_w.shape == (size_k, size_n)
|
||||
|
||||
pack_factor = get_pack_factor(num_bits)
|
||||
assert size_n % pack_factor == 0
|
||||
|
||||
orig_device = q_w.device
|
||||
|
||||
q_w = q_w.cpu().numpy().astype(numpy.uint32)
|
||||
|
||||
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
|
||||
|
||||
for i in range(pack_factor):
|
||||
q_res |= q_w[:, i::pack_factor] << num_bits * i
|
||||
|
||||
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
||||
q_res = q_res.contiguous()
|
||||
|
||||
return q_res
|
||||
|
||||
|
||||
def unpack_cols(
|
||||
packed_q_w: torch.Tensor,
|
||||
num_bits: int,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
):
|
||||
pack_factor = get_pack_factor(num_bits)
|
||||
assert size_n % pack_factor == 0
|
||||
assert packed_q_w.shape == (
|
||||
size_k,
|
||||
size_n // pack_factor,
|
||||
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
|
||||
packed_q_w.shape, size_k, size_n, pack_factor
|
||||
)
|
||||
|
||||
orig_device = packed_q_w.device
|
||||
|
||||
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
|
||||
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
|
||||
|
||||
mask = (1 << num_bits) - 1
|
||||
for i in range(pack_factor):
|
||||
vals = packed_q_w_cpu & mask
|
||||
packed_q_w_cpu >>= num_bits
|
||||
q_res[:, i::pack_factor] = vals
|
||||
|
||||
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
||||
q_res = q_res.contiguous()
|
||||
|
||||
return q_res
|
||||
|
||||
|
||||
def quantize_weights(
|
||||
w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False,
|
||||
ref_zero_points_after_scales: bool = False,
|
||||
):
|
||||
assert (
|
||||
quant_type.is_integer()
|
||||
), "Floating point quantization may work but has not been tested"
|
||||
assert not zero_points or group_size is not None, (
|
||||
"to have group zero points, group_size must be provided "
|
||||
"(-1 group_size is channelwise)"
|
||||
)
|
||||
|
||||
orig_device = w.device
|
||||
orig_type = w.dtype
|
||||
size_k, size_n = w.shape
|
||||
|
||||
assert w.is_floating_point(), "w must be float"
|
||||
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
|
||||
# Reshape to [groupsize, -1]
|
||||
if group_size is not None and group_size < size_k:
|
||||
w = w.reshape((-1, group_size, size_n))
|
||||
w = w.permute(1, 0, 2)
|
||||
w = w.reshape((group_size, -1))
|
||||
|
||||
# Compute scale for each group
|
||||
max_val = torch.max(w, 0, keepdim=True).values
|
||||
min_val = torch.min(w, 0, keepdim=True).values
|
||||
|
||||
max_q_val = quant_type.max()
|
||||
min_q_val = quant_type.min()
|
||||
|
||||
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
|
||||
maybe_w_zp = None
|
||||
if group_size is not None:
|
||||
if zero_points:
|
||||
assert not quant_type.is_signed() and quant_type.max() > 0
|
||||
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
|
||||
maybe_w_zp = (
|
||||
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
|
||||
)
|
||||
else:
|
||||
# If the bias is such that there are no possible negative/positive
|
||||
# values, set the max value to inf to avoid divide by 0
|
||||
w_s = torch.max(
|
||||
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
|
||||
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
|
||||
)
|
||||
|
||||
# Quantize
|
||||
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
|
||||
w_q = torch.clamp(w_q, min_q_val, max_q_val)
|
||||
|
||||
# Compute ref (dequantized)
|
||||
# For some kernels (namely Machete) the zero-points are applied after the
|
||||
# scales are applied, for this case computing the reference in similar way
|
||||
# allows us to use tighter error tolerances in our unit tests.
|
||||
if ref_zero_points_after_scales and maybe_w_zp is not None:
|
||||
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
|
||||
else:
|
||||
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
|
||||
|
||||
if quant_type.has_bias():
|
||||
w_q += quant_type.bias
|
||||
|
||||
# Restore original shapes
|
||||
if group_size is not None and group_size < size_k:
|
||||
|
||||
def reshape_w(w):
|
||||
w = w.reshape((group_size, -1, size_n))
|
||||
w = w.permute(1, 0, 2)
|
||||
w = w.reshape((size_k, size_n)).contiguous()
|
||||
return w
|
||||
|
||||
w_q = reshape_w(w_q)
|
||||
w_ref = reshape_w(w_ref)
|
||||
w_s = w_s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
if maybe_w_zp is not None:
|
||||
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
|
||||
maybe_w_zp = maybe_w_zp.to(device=orig_device)
|
||||
|
||||
return (
|
||||
w_ref.to(device=orig_device),
|
||||
w_q.to(device=orig_device),
|
||||
w_s if group_size is not None else None,
|
||||
maybe_w_zp,
|
||||
)
|
||||
Reference in New Issue
Block a user