Tiny let DeepGEMM scale checks cover more cases (#7182)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
|
|||||||
ENABLE_JIT_DEEPGEMM,
|
ENABLE_JIT_DEEPGEMM,
|
||||||
)
|
)
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -18,6 +19,8 @@ if ENABLE_JIT_DEEPGEMM:
|
|||||||
import deep_gemm
|
import deep_gemm
|
||||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||||
|
|
||||||
|
_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
|
||||||
|
|
||||||
|
|
||||||
# TODO maybe rename these functions
|
# TODO maybe rename these functions
|
||||||
def grouped_gemm_nt_f8f8bf16_masked(
|
def grouped_gemm_nt_f8f8bf16_masked(
|
||||||
@@ -31,6 +34,9 @@ def grouped_gemm_nt_f8f8bf16_masked(
|
|||||||
_, n, _ = rhs[0].shape
|
_, n, _ = rhs[0].shape
|
||||||
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
|
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
|
||||||
|
|
||||||
|
_sanity_check_input(lhs)
|
||||||
|
_sanity_check_input(rhs)
|
||||||
|
|
||||||
with compile_utils.deep_gemm_execution_hook(
|
with compile_utils.deep_gemm_execution_hook(
|
||||||
expected_m, n, k, num_groups, kernel_type
|
expected_m, n, k, num_groups, kernel_type
|
||||||
):
|
):
|
||||||
@@ -53,6 +59,9 @@ def grouped_gemm_nt_f8f8bf16_contig(
|
|||||||
num_groups, n, _ = rhs[0].shape
|
num_groups, n, _ = rhs[0].shape
|
||||||
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
|
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
|
||||||
|
|
||||||
|
_sanity_check_input(lhs)
|
||||||
|
_sanity_check_input(rhs)
|
||||||
|
|
||||||
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
|
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
|
||||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
|
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
|
||||||
|
|
||||||
@@ -67,6 +76,9 @@ def gemm_nt_f8f8bf16(
|
|||||||
num_groups = 1
|
num_groups = 1
|
||||||
kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
|
kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
|
||||||
|
|
||||||
|
_sanity_check_input(lhs)
|
||||||
|
_sanity_check_input(rhs)
|
||||||
|
|
||||||
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
|
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
|
||||||
deep_gemm.fp8_gemm_nt(
|
deep_gemm.fp8_gemm_nt(
|
||||||
lhs,
|
lhs,
|
||||||
@@ -90,3 +102,18 @@ def configure_deep_gemm_num_sms(num_sms):
|
|||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
deep_gemm.set_num_sms(original_num_sms)
|
deep_gemm.set_num_sms(original_num_sms)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
    """Optionally verify that the scale half of an ``(x, x_scale)`` pair is UE8M0-aligned.

    The check is a no-op unless the module-level ``_SANITY_CHECK`` flag is
    truthy. When enabled, ``x_scale`` must be element-wise equal to
    ``ceil_to_ue8m0(x_scale)``.

    Args:
        x_fp8: Pair ``(x, x_scale)``. Only ``x_scale`` is inspected.

    Raises:
        AssertionError: If any element of ``x_scale`` differs from its
            ``ceil_to_ue8m0`` counterpart. Raised explicitly (not via
            ``assert``) so the opt-in check still fires under ``python -O``.
    """
    if not _SANITY_CHECK:
        return

    _, x_scale = x_fp8

    # Integer-typed scales are skipped.
    # NOTE(review): presumably these are already packed, so the float
    # comparison below would not be meaningful -- confirm against callers.
    if x_scale.dtype == torch.int:
        return

    # Imported lazily, as in the original code.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0

    x_scale_ceil = ceil_to_ue8m0(x_scale)
    if not torch.all(x_scale == x_scale_ceil):
        raise AssertionError(f"{x_scale=} {x_scale_ceil=}")
|
||||||
|
|||||||
@@ -248,11 +248,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
|||||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE(alcanderian): Useless when scale is packed to int32
|
|
||||||
# if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
|
|
||||||
# _check_ue8m0("x_scale", x_scale)
|
|
||||||
# _check_ue8m0("weight_scale", ws)
|
|
||||||
|
|
||||||
output = w8a8_block_fp8_matmul_deepgemm(
|
output = w8a8_block_fp8_matmul_deepgemm(
|
||||||
q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
|
q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
|
||||||
)
|
)
|
||||||
@@ -261,11 +256,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
|||||||
return output.to(dtype=output_dtype).view(*output_shape)
|
return output.to(dtype=output_dtype).view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
def _check_ue8m0(name, x):
    """Check that ``x`` is element-wise equal to ``ceil_to_ue8m0(x)``.

    Args:
        name: Label for the tensor, included in the failure message.
        x: Tensor to validate.

    Raises:
        AssertionError: If any element of ``x`` differs from its
            ``ceil_to_ue8m0`` counterpart. Raised explicitly (not via
            ``assert``) so the check is not stripped under ``python -O``.
    """
    x_ceil = ceil_to_ue8m0(x)
    if not torch.all(x == x_ceil):
        raise AssertionError(f"{name=} {x=} {x_ceil=}")
|
|
||||||
|
|
||||||
|
|
||||||
def aiter_w8a8_block_fp8_linear(
|
def aiter_w8a8_block_fp8_linear(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user