diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 907177d6b..4c99fe702 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -265,6 +265,8 @@ from sgl_kernel.gemm import ( scaled_fp4_quant, sgl_per_tensor_quant_fp8, sgl_per_token_group_quant_8bit, + sgl_per_token_group_quant_fp8, + sgl_per_token_group_quant_int8, sgl_per_token_quant_fp8, shuffle_rows, silu_and_mul_scaled_fp4_grouped_quant, diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index f93aca19f..5cfdd412b 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -137,6 +137,11 @@ def sgl_per_token_group_quant_8bit( ) +# For legacy usage +sgl_per_token_group_quant_fp8 = sgl_per_token_group_quant_8bit +sgl_per_token_group_quant_int8 = sgl_per_token_group_quant_8bit + + def sgl_per_tensor_quant_fp8( input: torch.Tensor, output_q: torch.Tensor,