Tiny let DeepGEMM scale checks cover more cases (#7182)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
|
||||
ENABLE_JIT_DEEPGEMM,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -18,6 +19,8 @@ if ENABLE_JIT_DEEPGEMM:
|
||||
import deep_gemm
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
|
||||
_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
|
||||
|
||||
|
||||
# TODO maybe rename these functions
|
||||
def grouped_gemm_nt_f8f8bf16_masked(
|
||||
@@ -31,6 +34,9 @@ def grouped_gemm_nt_f8f8bf16_masked(
|
||||
_, n, _ = rhs[0].shape
|
||||
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
|
||||
|
||||
_sanity_check_input(lhs)
|
||||
_sanity_check_input(rhs)
|
||||
|
||||
with compile_utils.deep_gemm_execution_hook(
|
||||
expected_m, n, k, num_groups, kernel_type
|
||||
):
|
||||
@@ -53,6 +59,9 @@ def grouped_gemm_nt_f8f8bf16_contig(
|
||||
num_groups, n, _ = rhs[0].shape
|
||||
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
|
||||
|
||||
_sanity_check_input(lhs)
|
||||
_sanity_check_input(rhs)
|
||||
|
||||
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
|
||||
|
||||
@@ -67,6 +76,9 @@ def gemm_nt_f8f8bf16(
|
||||
num_groups = 1
|
||||
kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
|
||||
|
||||
_sanity_check_input(lhs)
|
||||
_sanity_check_input(rhs)
|
||||
|
||||
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
|
||||
deep_gemm.fp8_gemm_nt(
|
||||
lhs,
|
||||
@@ -90,3 +102,18 @@ def configure_deep_gemm_num_sms(num_sms):
|
||||
yield
|
||||
finally:
|
||||
deep_gemm.set_num_sms(original_num_sms)
|
||||
|
||||
|
||||
def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
|
||||
if not _SANITY_CHECK:
|
||||
return
|
||||
|
||||
x, x_scale = x_fp8
|
||||
|
||||
if x_scale.dtype == torch.int:
|
||||
return
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0
|
||||
|
||||
x_scale_ceil = ceil_to_ue8m0(x_scale)
|
||||
assert torch.all(x_scale == x_scale_ceil), f"{x_scale=} {x_scale_ceil=}"
|
||||
|
||||
@@ -248,11 +248,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||
)
|
||||
|
||||
# NOTE(alcanderian): Useless when scale is packed to int32
|
||||
# if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
|
||||
# _check_ue8m0("x_scale", x_scale)
|
||||
# _check_ue8m0("weight_scale", ws)
|
||||
|
||||
output = w8a8_block_fp8_matmul_deepgemm(
|
||||
q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
|
||||
)
|
||||
@@ -261,11 +256,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
||||
return output.to(dtype=output_dtype).view(*output_shape)
|
||||
|
||||
|
||||
def _check_ue8m0(name, x):
|
||||
x_ceil = ceil_to_ue8m0(x)
|
||||
assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
|
||||
|
||||
|
||||
def aiter_w8a8_block_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user