[1/2] Optimizations and refactors about quant kernel (#9534)

This commit is contained in:
fzyzcjy
2025-09-05 18:45:08 +08:00
committed by GitHub
parent afd9f2f560
commit 339f8eef09
11 changed files with 996 additions and 329 deletions

View File

@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
return output
def sgl_per_token_group_quant_fp8(
def sgl_per_token_group_quant_8bit(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
@@ -106,24 +106,21 @@ def sgl_per_token_group_quant_fp8(
eps: float,
fp8_min: float,
fp8_max: float,
scale_ue8m0: bool,
scale_ue8m0: bool = False,
fuse_silu_and_mul: bool = False,
masked_m: Optional[torch.Tensor] = None,
) -> None:
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
)
def sgl_per_token_group_quant_int8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
group_size: int,
eps: float,
int8_min: float,
int8_max: float,
) -> None:
torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
input, output_q, output_s, group_size, eps, int8_min, int8_max
torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
input,
output_q,
output_s,
group_size,
eps,
fp8_min,
fp8_max,
scale_ue8m0,
fuse_silu_and_mul,
masked_m,
)