From 70bb066ee49ff28774c4debf969d7b9786c9ca8d Mon Sep 17 00:00:00 2001
From: Azure <50126533+Azure-Tang@users.noreply.github.com>
Date: Thu, 21 Aug 2025 13:13:47 +0800
Subject: [PATCH] Fix FP4 inference corruption issue in glm4.5-air model
 (#9346)

---
 sgl-kernel/python/sgl_kernel/gemm.py | 33 ++++++++++++++++++++--------
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py
index 642bd7015..dafc739a1 100644
--- a/sgl-kernel/python/sgl_kernel/gemm.py
+++ b/sgl-kernel/python/sgl_kernel/gemm.py
@@ -205,9 +205,15 @@ def scaled_fp4_quant(
     rounded_m = ((m + 128 - 1) // 128) * 128
     scale_n = n // block_size
     rounded_n = ((scale_n + 4 - 1) // 4) * 4
-    output_scale = torch.empty(
-        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
-    )
+    # padded part should be zeroed out
+    if rounded_n > scale_n:
+        output_scale = torch.zeros(
+            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+        )
+    else:
+        output_scale = torch.empty(
+            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+        )
 
     torch.ops.sgl_kernel.scaled_fp4_quant.default(
         output, input, output_scale, input_global_scale
@@ -338,12 +344,21 @@ def scaled_fp4_experts_quant(
     output = torch.empty(
         m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
     )
-    output_scales = torch.empty(
-        MAX_TOKENS_PER_EXPERT * topk,
-        padded_k,
-        dtype=torch.int32,
-        device=input_tensor.device,
-    )
+    # padded part should be zeroed out
+    if padded_k > scales_k:
+        output_scales = torch.zeros(
+            MAX_TOKENS_PER_EXPERT * topk,
+            padded_k,
+            dtype=torch.int32,
+            device=input_tensor.device,
+        )
+    else:
+        output_scales = torch.empty(
+            MAX_TOKENS_PER_EXPERT * topk,
+            padded_k,
+            dtype=torch.int32,
+            device=input_tensor.device,
+        )
     torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
         output,
         output_scales,
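
The shape arithmetic behind the fix, as a minimal standalone sketch (not the
sgl_kernel API; the sizes m, n, and block_size are made-up examples chosen so
padding actually occurs): the scale matrix is padded to a 128-row by 4-column
swizzled tile layout, and torch.empty leaves the padded region holding
arbitrary bytes that a downstream kernel can read as garbage scale factors,
which is consistent with the inference corruption this patch addresses.

    import torch

    # Hypothetical example sizes; block_size is the FP4 scaling-group width.
    m, n, block_size = 96, 48, 16

    # Same ceil-division rounding as the patched scaled_fp4_quant.
    rounded_m = ((m + 128 - 1) // 128) * 128   # 96 -> 128 rows
    scale_n = n // block_size                  # 48 // 16 = 3 scale groups per row
    rounded_n = ((scale_n + 4 - 1) // 4) * 4   # 3 -> 4 (one padded scale group)

    # Each int32 packs four 8-bit scale factors (rounded_n // 4 columns for
    # rounded_n groups), so padding can live inside the packed words as well
    # as in the extra rows. torch.empty leaves those bytes uninitialized;
    # torch.zeros makes them harmless if a kernel reads past the valid region.
    if rounded_n > scale_n:
        output_scale = torch.zeros((rounded_m, rounded_n // 4), dtype=torch.int32)
    else:
        output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)

    print(output_scale.shape)  # torch.Size([128, 1])

The patch keeps the torch.empty fast path for the common case where no
padding exists, paying the zero-fill cost only when rounded_n > scale_n (or
padded_k > scales_k in the experts variant).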