sglang/sgl-kernel/python/sgl_kernel/marlin.py
import torch


def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    # Repack a GPTQ-quantized weight matrix into the Marlin kernel's tiled
    # layout, applying the activation-order permutation `perm` if given.
    return torch.ops.sgl_kernel.gptq_marlin_repack.default(
        b_q_weight,
        perm,
        size_k,
        size_n,
        num_bits,
    )


def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # Repack a single AWQ-quantized weight matrix into the Marlin tile layout.
    return torch.ops.sgl_kernel.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    # Repack AWQ-quantized MoE weights expert by expert. Note that `perm` is
    # accepted but unused on this AWQ path; AWQ needs no permutation here.
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    # Marlin layout: 16 k-rows per tile, and each int32 column of the output
    # holds size_n * (num_bits // 2) packed entries.
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = torch.ops.sgl_kernel.awq_marlin_repack(
            b_q_weight[e], size_k, size_n, num_bits
        )
    return output
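
A minimal usage sketch for the MoE repack path, under stated assumptions: a CUDA build of sgl-kernel is installed, the module is importable as sgl_kernel.marlin, and the shapes follow AWQ's int32 packing convention (32 // num_bits values per int32 along n). The dimensions and random weights below are illustrative only, not values from the original change.

import torch

from sgl_kernel.marlin import awq_marlin_moe_repack  # import path assumed

num_experts, size_k, size_n, num_bits = 8, 4096, 14336, 4
pack_factor = 32 // num_bits  # 8 four-bit values per int32 for AWQ

# AWQ-packed MoE weights: one (size_k, size_n // pack_factor) int32 matrix
# per expert; random bits stand in for real quantized weights.
qweight = torch.randint(
    0,
    2**31 - 1,
    (num_experts, size_k, size_n // pack_factor),
    dtype=torch.int32,
    device="cuda",
)
# `perm` is unused on the AWQ path, so an empty tensor suffices.
perm = torch.empty(0, dtype=torch.int32, device="cuda")

marlin_qweight = awq_marlin_moe_repack(qweight, perm, size_k, size_n, num_bits)
# Output follows the Marlin tile layout allocated inside the function.
assert marlin_qweight.shape == (num_experts, size_k // 16, size_n * (num_bits // 2))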