Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com> Co-authored-by: yych0745 <1398089567@qq.com> Co-authored-by: HandH1998 <1335248067@qq.com> Co-authored-by: 弋云 <yiyun.wyt@antgroup.com> Co-authored-by: walker-ai <2398833647@qq.com>
45 lines
1006 B
Python
45 lines
1006 B
Python
import torch
|
|
|
|
|
|
def gptq_marlin_repack(
|
|
b_q_weight,
|
|
perm,
|
|
size_k,
|
|
size_n,
|
|
num_bits,
|
|
):
|
|
torch.ops.sgl_kernel.gptq_marlin_repack.default(
|
|
b_q_weight,
|
|
perm,
|
|
size_k,
|
|
size_n,
|
|
num_bits,
|
|
)
|
|
|
|
|
|
def awq_marlin_repack(
|
|
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
|
) -> torch.Tensor:
|
|
return torch.ops.sgl_kernel.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
|
|
|
|
|
|
def awq_marlin_moe_repack(
|
|
b_q_weight: torch.Tensor,
|
|
perm: torch.Tensor,
|
|
size_k: int,
|
|
size_n: int,
|
|
num_bits: int,
|
|
) -> torch.Tensor:
|
|
num_experts = b_q_weight.shape[0]
|
|
assert size_k % 16 == 0
|
|
output = torch.empty(
|
|
(num_experts, size_k // 16, size_n * (num_bits // 2)),
|
|
device=b_q_weight.device,
|
|
dtype=b_q_weight.dtype,
|
|
)
|
|
for e in range(num_experts):
|
|
output[e] = torch.ops.sgl_kernel.awq_marlin_repack(
|
|
b_q_weight[e], size_k, size_n, num_bits
|
|
)
|
|
return output
|