Revert "[1/2] Optimizations and refactors about quant kernel (#9534)" (#10292)

This commit is contained in:
Yineng Zhang
2025-09-10 18:24:23 -07:00
committed by GitHub
parent 033b75f559
commit 6d55f60e77
11 changed files with 328 additions and 995 deletions

View File

@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
return output
def sgl_per_token_group_quant_8bit(
def sgl_per_token_group_quant_fp8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
@@ -106,21 +106,24 @@ def sgl_per_token_group_quant_8bit(
eps: float,
fp8_min: float,
fp8_max: float,
scale_ue8m0: bool = False,
fuse_silu_and_mul: bool = False,
masked_m: Optional[torch.Tensor] = None,
scale_ue8m0: bool,
) -> None:
torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
input,
output_q,
output_s,
group_size,
eps,
fp8_min,
fp8_max,
scale_ue8m0,
fuse_silu_and_mul,
masked_m,
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
)
def sgl_per_token_group_quant_int8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    int8_min: float,
    int8_max: float,
) -> None:
    """Per-token-group int8 quantization via the ``sgl_kernel`` custom op.

    Thin pass-through wrapper: all arguments are forwarded positionally to
    ``torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default``.

    Args:
        input: Source tensor to quantize.
        output_q: Destination tensor for quantized values
            (presumably written in place by the kernel — confirm against
            the op's schema).
        output_s: Destination tensor for per-group scales
            (presumably written in place by the kernel).
        group_size: Number of elements per quantization group.
        eps: Epsilon guarding against division by zero when scaling.
        int8_min: Lower clamp bound for quantized values.
        int8_max: Upper clamp bound for quantized values.

    Returns:
        None — results are delivered through the output tensors.
    """
    # Bind the registered custom-op overload once, then invoke it with the
    # arguments in schema order (positional, to match the op signature).
    quant_op = torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default
    quant_op(input, output_q, output_s, group_size, eps, int8_min, int8_max)