[dev] support AWQ/GPTQ quantization for dense models

This commit is contained in:
Li Wei
2025-12-24 13:45:55 +08:00
parent 75d0bdae2f
commit 6546323c71
5 changed files with 412 additions and 2 deletions

View File

@@ -1149,3 +1149,175 @@ def fake_moe_post(
return None
moe_post.register_fake(fake_moe_post)
##################################################
# --------------- awq_dequantize -----------------
##################################################
@custom_op("_C::awq_dequantize", mutates_args=())
def awq_dequantize(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
group_m = int(qweight.shape[0] / scales.shape[0])
xtorch_ops.awq_dequantize(
qweight=qweight,
scales=scales,
zeros=zeros,
weight=weight,
group_m=group_m,
quant_type=quant_type,
align_type=align_type,
)
return weight
@impl("_C::awq_dequantize", "CUDA")
def awq_dequantize_cuda(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
group_m = int(qweight.shape[0] / scales.shape[0])
out = xtorch_ops.awq_dequantize(
qweight=qweight,
scales=scales,
zeros=zeros,
weight=weight,
group_m=group_m,
quant_type=quant_type,
align_type=align_type,
)
return weight
def _fake_awq_dequantize(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
return weight
awq_dequantize.register_fake(_fake_awq_dequantize)
##################################################
# ------------------ awq_gemm -------------------
##################################################
@custom_op("_C::awq_gemm", mutates_args=())
def awq_gemm(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
group_size = int(qweight.shape[0] / scale.shape[0])
xtorch_ops.awq_gemm(
x=x,
w=qweight,
scale=scale,
zeros=zeros,
out=out,
align_type=align_type,
group_size=group_size,
)
return out
@impl("_C::awq_gemm", "CUDA")
def awq_gemm_cuda(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
group_size = int(qweight.shape[0] / scale.shape[0])
xtorch_ops.awq_gemm(
x=x,
w=qweight,
scale=scale,
zeros=zeros,
out=out,
align_type=align_type,
group_size=group_size,
)
return out
def _fake_awq_gemm(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
return out
awq_gemm.register_fake(_fake_awq_gemm)
##################################################
# ---------------- gptq_shuffle ------------------
##################################################
@custom_op("_C::gptq_shuffle", mutates_args=())
def gptq_shuffle(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
@impl("_C::gptq_shuffle", "CUDA")
def gptq_shuffle_cuda(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
def _fake_gptq_shuffle(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
return None
gptq_shuffle.register_fake(_fake_gptq_shuffle)