Tiny add sanity checks for DeepGEMM inputs (#7157)

This commit is contained in:
fzyzcjy
2025-06-14 05:36:27 +08:00
committed by GitHub
parent e3ec6bf4b6
commit 0f1dfa1efe

View File

@@ -239,6 +239,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
column_major_scales=True,
scale_tma_aligned=True,
)
if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
_check_ue8m0("x_scale", x_scale)
_check_ue8m0("weight_scale", weight_scale)
output = w8a8_block_fp8_matmul_deepgemm(
q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
)
@@ -247,6 +252,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
return output.to(dtype=output_dtype).view(*output_shape)
def _check_ue8m0(name, x):
x_ceil = ceil_to_ue8m0(x)
assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
def aiter_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
@@ -380,6 +390,11 @@ def block_quant_dequant(
return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
# COPIED FROM DeepGEMM
def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
def channel_quant_to_tensor_quant(
x_q_channel: torch.Tensor,
x_s: torch.Tensor,