Reland [1/2] Optimizations and refactors about quant kernel (#10312)

Author: fzyzcjy
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Committed-by: GitHub
Date: 2025-10-11 15:59:03 +08:00
Parent: 129d299278
Commit: 21337b22b9

13 changed files with 1065 additions and 178 deletions

@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
     return output
 
 
-def sgl_per_token_group_quant_fp8(
+def sgl_per_token_group_quant_8bit(
     input: torch.Tensor,
     output_q: torch.Tensor,
     output_s: torch.Tensor,
@@ -106,27 +106,37 @@ def sgl_per_token_group_quant_fp8(
     eps: float,
     fp8_min: float,
     fp8_max: float,
-    scale_ue8m0: bool,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
+    enable_v2: Optional[bool] = None,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
+    if enable_v2 is None:
+        from sglang.srt.utils import get_bool_env_var
+
+        enable_v2 = get_bool_env_var("SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2")
+
+    if enable_v2:
+        return torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit_v2.default(
+            input,
+            output_q,
+            output_s,
+            group_size,
+            eps,
+            fp8_min,
+            fp8_max,
+            scale_ue8m0,
+            fuse_silu_and_mul,
+            masked_m,
+        )
+
+    assert not fuse_silu_and_mul, "only v2 support fuse_silu_and_mul"
+    assert masked_m is None, "only v2 support masked_m"
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
         input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
     )
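
For orientation, a minimal usage sketch of the reworked wrapper follows. The shapes, the eps value, and the assumption that the function is re-exported from the top-level sgl_kernel package are illustrative, not taken from this diff:

# Hypothetical usage sketch; shapes, dtypes, and eps are assumptions.
import torch
from sgl_kernel import sgl_per_token_group_quant_8bit  # assumed re-export

num_tokens, hidden, group_size = 16, 512, 128
x = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)

finfo = torch.finfo(torch.float8_e4m3fn)
x_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
x_s = torch.empty(num_tokens, hidden // group_size, device="cuda", dtype=torch.float32)

# enable_v2=None (the default) defers to the SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2
# env var; per the asserts above, the v1 kernel rejects fuse_silu_and_mul and masked_m.
sgl_per_token_group_quant_8bit(
    x, x_q, x_s, group_size,
    eps=1e-10, fp8_min=finfo.min, fp8_max=finfo.max,
)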
 
 
 def sgl_per_token_group_quant_int8(
     input: torch.Tensor,
     output_q: torch.Tensor,
     output_s: torch.Tensor,
     group_size: int,
     eps: float,
     int8_min: float,
     int8_max: float,
 ) -> None:
     torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
         input, output_q, output_s, group_size, eps, int8_min, int8_max
     )
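
The int8 wrapper appears only as unchanged context in this hunk; for completeness, an analogous sketch reusing x, num_tokens, hidden, and group_size from the example above, with the int8 bounds spelled out as floats (again an illustration, not from this diff):

# Continues the sketch above; values are illustrative assumptions.
from sgl_kernel import sgl_per_token_group_quant_int8  # assumed re-export

x_q8 = torch.empty_like(x, dtype=torch.int8)
x_s8 = torch.empty(num_tokens, hidden // group_size, device="cuda", dtype=torch.float32)
sgl_per_token_group_quant_int8(x, x_q8, x_s8, group_size, 1e-10, -128.0, 127.0)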
 
 
 def sgl_per_tensor_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,