[3/n] chore: decouple AWQ implementation from vLLM dependency (#8113)
Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com>
This commit is contained in:
@@ -11,7 +11,7 @@ import numpy
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.layers.quantization.scalar_type import ScalarType
|
||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -247,6 +247,36 @@ def get_pack_factor(num_bits):
|
||||
return 32 // num_bits
|
||||
|
||||
|
||||
def permute_rows(
|
||||
q_w: torch.Tensor,
|
||||
w_ref: torch.Tensor,
|
||||
group_size: int,
|
||||
test_perm: Optional[torch.Tensor] = None,
|
||||
):
|
||||
assert q_w.shape == w_ref.shape
|
||||
|
||||
orig_device = q_w.device
|
||||
k_size, _ = q_w.shape
|
||||
|
||||
g_idx = torch.zeros((k_size,), dtype=torch.int32)
|
||||
for i in range(k_size):
|
||||
g_idx[i] = i // group_size
|
||||
|
||||
# Simulate act_order by doing a random permutation on K
|
||||
rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
|
||||
|
||||
g_idx = g_idx[rand_perm].contiguous()
|
||||
q_w = q_w[rand_perm, :].contiguous()
|
||||
w_ref = w_ref[rand_perm, :].contiguous()
|
||||
|
||||
return (
|
||||
w_ref.to(device=orig_device),
|
||||
q_w.to(device=orig_device),
|
||||
g_idx.to(device=orig_device),
|
||||
rand_perm.to(device=orig_device),
|
||||
)
|
||||
|
||||
|
||||
def pack_cols(
|
||||
q_w: torch.Tensor,
|
||||
num_bits: int,
|
||||
@@ -399,3 +429,56 @@ def quantize_weights(
|
||||
w_s if group_size is not None else None,
|
||||
maybe_w_zp,
|
||||
)
|
||||
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
def gptq_quantize_weights(
|
||||
w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
test_perm: Optional[torch.Tensor] = None,
|
||||
):
|
||||
size_k, _ = w.shape
|
||||
|
||||
assert w.is_floating_point(), "w must be float"
|
||||
assert (
|
||||
quant_type in SUPPORTED_GPTQ_QUANT_TYPES
|
||||
), f"Unsupported gptq type = {quant_type}"
|
||||
assert group_size in SUPPORTED_GROUP_SIZES + [
|
||||
size_k
|
||||
], f"Unsupported groupsize = {group_size}"
|
||||
|
||||
w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
|
||||
|
||||
# Apply act_order
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=w.device)
|
||||
rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
|
||||
if act_order:
|
||||
assert (
|
||||
group_size < size_k
|
||||
), "For act_order, groupsize = {} must be less than size_k = {}".format(
|
||||
group_size, size_k
|
||||
)
|
||||
|
||||
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
|
||||
|
||||
return w_ref, w_q, w_s, g_idx, rand_perm
|
||||
|
||||
|
||||
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
|
||||
orig_device = q_w.device
|
||||
|
||||
sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
|
||||
|
||||
g_idx = g_idx[sort_indices].contiguous()
|
||||
q_w = q_w[sort_indices, :].contiguous()
|
||||
|
||||
return (
|
||||
q_w.to(device=orig_device),
|
||||
g_idx.to(device=orig_device),
|
||||
sort_indices.to(device=orig_device),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user