[dev] support AWQ/GPTQ quantization for dense models
This commit is contained in:
@@ -1149,3 +1149,175 @@ def fake_moe_post(
|
||||
return None
|
||||
|
||||
moe_post.register_fake(fake_moe_post)
|
||||
|
||||
|
||||
##################################################
|
||||
# --------------- awq_dequantize -----------------
|
||||
##################################################
|
||||
@custom_op("_C::awq_dequantize", mutates_args=())
|
||||
def awq_dequantize(
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
quant_type: int = 0,
|
||||
align_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
weight = torch.empty(
|
||||
(qweight.shape[0], qweight.shape[1] * 8),
|
||||
dtype=torch.float16,
|
||||
device=qweight.device,
|
||||
)
|
||||
group_m = int(qweight.shape[0] / scales.shape[0])
|
||||
xtorch_ops.awq_dequantize(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
zeros=zeros,
|
||||
weight=weight,
|
||||
group_m=group_m,
|
||||
quant_type=quant_type,
|
||||
align_type=align_type,
|
||||
)
|
||||
return weight
|
||||
|
||||
|
||||
@impl("_C::awq_dequantize", "CUDA")
|
||||
def awq_dequantize_cuda(
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
quant_type: int = 0,
|
||||
align_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
weight = torch.empty(
|
||||
(qweight.shape[0], qweight.shape[1] * 8),
|
||||
dtype=torch.float16,
|
||||
device=qweight.device,
|
||||
)
|
||||
group_m = int(qweight.shape[0] / scales.shape[0])
|
||||
out = xtorch_ops.awq_dequantize(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
zeros=zeros,
|
||||
weight=weight,
|
||||
group_m=group_m,
|
||||
quant_type=quant_type,
|
||||
align_type=align_type,
|
||||
)
|
||||
return weight
|
||||
|
||||
|
||||
def _fake_awq_dequantize(
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
quant_type: int = 0,
|
||||
align_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
weight = torch.empty(
|
||||
(qweight.shape[0], qweight.shape[1] * 8),
|
||||
dtype=torch.float16,
|
||||
device=qweight.device,
|
||||
)
|
||||
return weight
|
||||
|
||||
|
||||
awq_dequantize.register_fake(_fake_awq_dequantize)
|
||||
|
||||
|
||||
##################################################
|
||||
# ------------------ awq_gemm -------------------
|
||||
##################################################
|
||||
@custom_op("_C::awq_gemm", mutates_args=())
|
||||
def awq_gemm(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
align_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
out = torch.empty(
|
||||
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
||||
)
|
||||
group_size = int(qweight.shape[0] / scale.shape[0])
|
||||
xtorch_ops.awq_gemm(
|
||||
x=x,
|
||||
w=qweight,
|
||||
scale=scale,
|
||||
zeros=zeros,
|
||||
out=out,
|
||||
align_type=align_type,
|
||||
group_size=group_size,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@impl("_C::awq_gemm", "CUDA")
|
||||
def awq_gemm_cuda(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
align_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
out = torch.empty(
|
||||
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
||||
)
|
||||
group_size = int(qweight.shape[0] / scale.shape[0])
|
||||
xtorch_ops.awq_gemm(
|
||||
x=x,
|
||||
w=qweight,
|
||||
scale=scale,
|
||||
zeros=zeros,
|
||||
out=out,
|
||||
align_type=align_type,
|
||||
group_size=group_size,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _fake_awq_gemm(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
align_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
out = torch.empty(
|
||||
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
awq_gemm.register_fake(_fake_awq_gemm)
|
||||
|
||||
|
||||
##################################################
|
||||
# ---------------- gptq_shuffle ------------------
|
||||
##################################################
|
||||
@custom_op("_C::gptq_shuffle", mutates_args=())
|
||||
def gptq_shuffle(
|
||||
q_weight: torch.Tensor,
|
||||
q_perm: torch.Tensor,
|
||||
bit: int,
|
||||
) -> None:
|
||||
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
||||
|
||||
|
||||
@impl("_C::gptq_shuffle", "CUDA")
|
||||
def gptq_shuffle_cuda(
|
||||
q_weight: torch.Tensor,
|
||||
q_perm: torch.Tensor,
|
||||
bit: int,
|
||||
) -> None:
|
||||
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
||||
|
||||
|
||||
def _fake_gptq_shuffle(
|
||||
q_weight: torch.Tensor,
|
||||
q_perm: torch.Tensor,
|
||||
bit: int,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
gptq_shuffle.register_fake(_fake_gptq_shuffle)
|
||||
Reference in New Issue
Block a user