[1/n] apply wna16marlin kernel in moe weight only quantization (#7683)
Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com> Co-authored-by: yych0745 <1398089567@qq.com> Co-authored-by: HandH1998 <1335248067@qq.com> Co-authored-by: 弋云 <yiyun.wyt@antgroup.com> Co-authored-by: walker-ai <2398833647@qq.com>
This commit is contained in:
44
sgl-kernel/python/sgl_kernel/marlin.py
Normal file
44
sgl-kernel/python/sgl_kernel/marlin.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
|
||||
|
||||
def gptq_marlin_repack(
|
||||
b_q_weight,
|
||||
perm,
|
||||
size_k,
|
||||
size_n,
|
||||
num_bits,
|
||||
):
|
||||
torch.ops.sgl_kernel.gptq_marlin_repack.default(
|
||||
b_q_weight,
|
||||
perm,
|
||||
size_k,
|
||||
size_n,
|
||||
num_bits,
|
||||
)
|
||||
|
||||
|
||||
def awq_marlin_repack(
|
||||
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
|
||||
|
||||
|
||||
def awq_marlin_moe_repack(
|
||||
b_q_weight: torch.Tensor,
|
||||
perm: torch.Tensor,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
num_bits: int,
|
||||
) -> torch.Tensor:
|
||||
num_experts = b_q_weight.shape[0]
|
||||
assert size_k % 16 == 0
|
||||
output = torch.empty(
|
||||
(num_experts, size_k // 16, size_n * (num_bits // 2)),
|
||||
device=b_q_weight.device,
|
||||
dtype=b_q_weight.dtype,
|
||||
)
|
||||
for e in range(num_experts):
|
||||
output[e] = torch.ops.sgl_kernel.awq_marlin_repack(
|
||||
b_q_weight[e], size_k, size_n, num_bits
|
||||
)
|
||||
return output
|
||||
Reference in New Issue
Block a user