[2/n] Decouple quantization implementation from vLLM dependency (#8112)
Co-authored-by: walker-ai <yiyun.wyt@antgroup.com> Co-authored-by: leoneo <1320612015@qq.com>
This commit is contained in:
@@ -44,6 +44,9 @@ from sgl_kernel.gemm import (
|
||||
dsv3_router_gemm,
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
gptq_gemm,
|
||||
gptq_marlin_gemm,
|
||||
gptq_shuffle,
|
||||
int8_scaled_mm,
|
||||
qserve_w4a8_per_chn_gemm,
|
||||
qserve_w4a8_per_group_gemm,
|
||||
|
||||
@@ -2,6 +2,7 @@ import functools
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel import silu_and_mul
|
||||
|
||||
|
||||
def get_scalar_type(num_bits: int, has_zp: bool):
|
||||
@@ -165,7 +166,7 @@ def fused_marlin_moe(
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
|
||||
silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)
|
||||
|
||||
if expert_map is not None:
|
||||
intermediate_cache3.zero_()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from sgl_kernel.scalar_type import ScalarType
|
||||
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
||||
|
||||
|
||||
@@ -353,3 +354,62 @@ def scaled_fp4_experts_quant(
|
||||
)
|
||||
output_scales = output_scales.view(torch.float8_e4m3fn)
|
||||
return output, output_scales
|
||||
|
||||
|
||||
# GPTQ kernels
|
||||
def gptq_marlin_gemm(
    a: torch.Tensor,
    c: Optional[torch.Tensor],
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    global_scale: Optional[torch.Tensor],
    b_zeros: Optional[torch.Tensor],
    g_idx: Optional[torch.Tensor],
    perm: Optional[torch.Tensor],
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool = True,
    use_atomic_add: bool = False,
    use_fp32_reduce: bool = False,
    is_zp_float: bool = False,
) -> torch.Tensor:
    """Forward a GPTQ-Marlin GEMM request to the ``sgl_kernel`` custom op.

    This is a thin Python wrapper: every argument is handed to
    ``torch.ops.sgl_kernel.gptq_marlin_gemm`` positionally, in declaration
    order. The only translation performed here is that ``b_q_type`` is
    replaced by its ``id`` attribute before the call.

    NOTE(review): tensor shapes, dtypes and the exact meaning of the flags
    are defined by the kernel and are not visible from this wrapper; confirm
    against the op's registration before relying on them.

    Returns:
        The value produced by the underlying custom op.
    """
    marlin_op = torch.ops.sgl_kernel.gptq_marlin_gemm
    op_args = (
        a,
        c,
        b_q_weight,
        b_scales,
        global_scale,
        b_zeros,
        g_idx,
        perm,
        workspace,
        # The op receives the scalar type's id rather than the object itself.
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
        is_zp_float,
    )
    return marlin_op(*op_args)
|
||||
|
||||
|
||||
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_shuffle: bool,
    bit: int,
) -> torch.Tensor:
    """Forward a GPTQ GEMM request to the ``sgl_kernel`` custom op.

    All arguments are passed through unchanged, positionally and in
    declaration order, to ``torch.ops.sgl_kernel.gptq_gemm``.

    NOTE(review): the expected tensor layouts and the supported values of
    ``bit`` are determined by the kernel, not by this wrapper; confirm
    against the op's registration.

    Returns:
        The value produced by the underlying custom op.
    """
    gemm_op = torch.ops.sgl_kernel.gptq_gemm
    return gemm_op(
        a,
        b_q_weight,
        b_gptq_qzeros,
        b_gptq_scales,
        b_g_idx,
        use_shuffle,
        bit,
    )
|
||||
|
||||
|
||||
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    """Invoke the ``sgl_kernel`` ``gptq_shuffle`` custom op.

    The three arguments are forwarded unchanged to
    ``torch.ops.sgl_kernel.gptq_shuffle``; nothing is returned from this
    wrapper.

    NOTE(review): the ``None`` return suggests the op works in place on
    ``q_weight``, but that is decided by the kernel; confirm against the
    op's registration.
    """
    # Reach the op through ``torch.ops`` directly, as the sibling GPTQ
    # wrappers do, instead of the redundant ``torch.torch.ops`` path.
    torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)
|
||||
|
||||
@@ -7,8 +7,8 @@ def gptq_marlin_repack(
|
||||
size_k,
|
||||
size_n,
|
||||
num_bits,
|
||||
):
|
||||
torch.ops.sgl_kernel.gptq_marlin_repack.default(
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.gptq_marlin_repack(
|
||||
b_q_weight,
|
||||
perm,
|
||||
size_k,
|
||||
|
||||
Reference in New Issue
Block a user