Fix FP4 inference corruption issue in glm4.5-air model (#9346)
This commit is contained in:
@@ -205,9 +205,15 @@ def scaled_fp4_quant(
|
||||
rounded_m = ((m + 128 - 1) // 128) * 128
|
||||
scale_n = n // block_size
|
||||
rounded_n = ((scale_n + 4 - 1) // 4) * 4
|
||||
output_scale = torch.empty(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
# padded part should be zeroed out
|
||||
if rounded_n > scale_n:
|
||||
output_scale = torch.zeros(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
output_scale = torch.empty(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.scaled_fp4_quant.default(
|
||||
output, input, output_scale, input_global_scale
|
||||
@@ -338,12 +344,21 @@ def scaled_fp4_experts_quant(
|
||||
output = torch.empty(
|
||||
m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
|
||||
)
|
||||
output_scales = torch.empty(
|
||||
MAX_TOKENS_PER_EXPERT * topk,
|
||||
padded_k,
|
||||
dtype=torch.int32,
|
||||
device=input_tensor.device,
|
||||
)
|
||||
# padded part should be zeroed out
|
||||
if padded_k > scales_k:
|
||||
output_scales = torch.zeros(
|
||||
MAX_TOKENS_PER_EXPERT * topk,
|
||||
padded_k,
|
||||
dtype=torch.int32,
|
||||
device=input_tensor.device,
|
||||
)
|
||||
else:
|
||||
output_scales = torch.empty(
|
||||
MAX_TOKENS_PER_EXPERT * topk,
|
||||
padded_k,
|
||||
dtype=torch.int32,
|
||||
device=input_tensor.device,
|
||||
)
|
||||
torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
|
||||
output,
|
||||
output_scales,
|
||||
|
||||
Reference in New Issue
Block a user