Tiny add sanity checks for DeepGEMM inputs (#7157)
This commit is contained in:
@@ -239,6 +239,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
|||||||
column_major_scales=True,
|
column_major_scales=True,
|
||||||
scale_tma_aligned=True,
|
scale_tma_aligned=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
|
||||||
|
_check_ue8m0("x_scale", x_scale)
|
||||||
|
_check_ue8m0("weight_scale", weight_scale)
|
||||||
|
|
||||||
output = w8a8_block_fp8_matmul_deepgemm(
|
output = w8a8_block_fp8_matmul_deepgemm(
|
||||||
q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
|
q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
|
||||||
)
|
)
|
||||||
@@ -247,6 +252,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
|||||||
return output.to(dtype=output_dtype).view(*output_shape)
|
return output.to(dtype=output_dtype).view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_ue8m0(name, x):
    """Assert that every element of ``x`` is already an exact UE8M0 value.

    A tensor is UE8M0-aligned when rounding it up to the nearest power of
    two (via ``ceil_to_ue8m0``) leaves it unchanged. ``name`` is only used
    to label the failing tensor in the assertion message.
    """
    x_ceil = ceil_to_ue8m0(x)
    already_aligned = torch.all(x == x_ceil)
    assert already_aligned, f"{name=} {x=} {x_ceil=}"
|
||||||
|
|
||||||
|
|
||||||
def aiter_w8a8_block_fp8_linear(
|
def aiter_w8a8_block_fp8_linear(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
@@ -380,6 +390,11 @@ def block_quant_dequant(
|
|||||||
return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
|
return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# COPIED FROM DeepGEMM
|
||||||
|
def ceil_to_ue8m0(x: torch.Tensor):
    """Round each element of ``x`` up (in magnitude) to the nearest power of two.

    Computes ``2 ** ceil(log2(|x|))`` element-wise, i.e. the smallest
    power-of-two scale that is >= the element's absolute value — the
    representable grid of the UE8M0 scale format.
    """
    return x.abs().log2().ceil().exp2()
|
||||||
|
|
||||||
|
|
||||||
def channel_quant_to_tensor_quant(
|
def channel_quant_to_tensor_quant(
|
||||||
x_q_channel: torch.Tensor,
|
x_q_channel: torch.Tensor,
|
||||||
x_s: torch.Tensor,
|
x_s: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user