Fix FP4 inference corruption issue in the GLM-4.5-Air model (#9346)
This commit is contained in:
@@ -205,6 +205,12 @@ def scaled_fp4_quant(
|
|||||||
rounded_m = ((m + 128 - 1) // 128) * 128
|
rounded_m = ((m + 128 - 1) // 128) * 128
|
||||||
scale_n = n // block_size
|
scale_n = n // block_size
|
||||||
rounded_n = ((scale_n + 4 - 1) // 4) * 4
|
rounded_n = ((scale_n + 4 - 1) // 4) * 4
|
||||||
|
# The padded part of the scale tensor must be zeroed out, so allocate it with zeros when rounded_n exceeds scale_n.
|
||||||
|
if rounded_n > scale_n:
|
||||||
|
output_scale = torch.zeros(
|
||||||
|
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
else:
|
||||||
output_scale = torch.empty(
|
output_scale = torch.empty(
|
||||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||||
)
|
)
|
||||||
@@ -338,6 +344,15 @@ def scaled_fp4_experts_quant(
|
|||||||
output = torch.empty(
|
output = torch.empty(
|
||||||
m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
|
m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
|
||||||
)
|
)
|
||||||
|
# The padded part of the scale tensor must be zeroed out, so allocate it with zeros when padded_k exceeds scales_k.
|
||||||
|
if padded_k > scales_k:
|
||||||
|
output_scales = torch.zeros(
|
||||||
|
MAX_TOKENS_PER_EXPERT * topk,
|
||||||
|
padded_k,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=input_tensor.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
output_scales = torch.empty(
|
output_scales = torch.empty(
|
||||||
MAX_TOKENS_PER_EXPERT * topk,
|
MAX_TOKENS_PER_EXPERT * topk,
|
||||||
padded_k,
|
padded_k,
|
||||||
|
|||||||
Reference in New Issue
Block a user