Tiny add sanity checks for DeepGEMM inputs (#7157)
This commit is contained in:
@@ -239,6 +239,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
)
|
||||
|
||||
if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
|
||||
_check_ue8m0("x_scale", x_scale)
|
||||
_check_ue8m0("weight_scale", weight_scale)
|
||||
|
||||
output = w8a8_block_fp8_matmul_deepgemm(
|
||||
q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
|
||||
)
|
||||
@@ -247,6 +252,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
||||
return output.to(dtype=output_dtype).view(*output_shape)
|
||||
|
||||
|
||||
def _check_ue8m0(name, x):
|
||||
x_ceil = ceil_to_ue8m0(x)
|
||||
assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
|
||||
|
||||
|
||||
def aiter_w8a8_block_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
@@ -380,6 +390,11 @@ def block_quant_dequant(
|
||||
return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
|
||||
|
||||
|
||||
# COPIED FROM DeepGEMM
|
||||
def ceil_to_ue8m0(x: torch.Tensor):
|
||||
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
|
||||
|
||||
|
||||
def channel_quant_to_tensor_quant(
|
||||
x_q_channel: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user