@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
     return output
 
 
-def sgl_per_token_group_quant_fp8(
+def sgl_per_token_group_quant_8bit(
     input: torch.Tensor,
     output_q: torch.Tensor,
     output_s: torch.Tensor,
@@ -106,21 +106,24 @@ def sgl_per_token_group_quant_8bit(
     eps: float,
     fp8_min: float,
     fp8_max: float,
-    scale_ue8m0: bool,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
-        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
+        input,
+        output_q,
+        output_s,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+        scale_ue8m0,
+        fuse_silu_and_mul,
+        masked_m,
     )
 
 
 def sgl_per_token_group_quant_int8(
     input: torch.Tensor,
     output_q: torch.Tensor,
     output_s: torch.Tensor,
     group_size: int,
     eps: float,
     int8_min: float,
     int8_max: float,
 ) -> None:
     torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
         input, output_q, output_s, group_size, eps, int8_min, int8_max
     )
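
A minimal usage sketch of the renamed wrapper follows. It is not part of the commit: it assumes sgl_kernel re-exports sgl_per_token_group_quant_8bit at the package top level (as the old fp8 wrapper was), a CUDA device, a PyTorch build with float8_e4m3fn, and an illustrative per-(token, group) layout for output_s.

# Usage sketch; the shapes and the output_s layout below are assumptions,
# not taken from this commit.
import torch
from sgl_kernel import sgl_per_token_group_quant_8bit  # assumed top-level re-export

num_tokens, hidden, group_size = 16, 256, 128
x = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)

# The caller preallocates both outputs: the quantized tensor plus one scale
# per (token, group).
fp8 = torch.finfo(torch.float8_e4m3fn)
x_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
x_s = torch.empty(num_tokens, hidden // group_size, device="cuda", dtype=torch.float32)

sgl_per_token_group_quant_8bit(
    x, x_q, x_s, group_size,
    eps=1e-10,
    fp8_min=fp8.min,
    fp8_max=fp8.max,
    scale_ue8m0=False,        # now defaults to False after this commit
    fuse_silu_and_mul=False,  # new flag added by this commit
    masked_m=None,            # new optional per-expert token-count tensor
)

The comments on the two new arguments infer intent from their names only (fusing silu(gate) * up into the quantization pass, and masking variable-length expert batches); the hunk itself just threads them through to the underlying torch.ops kernel.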