Fix FP8 block quantization when N or K is not a multiple of 128 (#8648)

Author: YanbingJiang
Date: 2025-08-02 06:57:19 +08:00 (committed by GitHub)
Parent: e252192679
Commit: 1fe691a429
3 changed files with 39 additions and 18 deletions


@@ -136,18 +136,33 @@ def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_facto
 def scaled_weight(weight, scales):
     E, N, K = weight.shape
+    pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
+    pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K
+    if pad_N > 0 or pad_K > 0:
+        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
     weight_block = (
-        weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K)
+        weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
         .permute(0, 1, 3, 2, 4)
         .float()
         .contiguous()
     )
-    return (
-        (weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1))
+    weight_scaled = (
+        (
+            weight_block
+            * scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1)
+        )
         .permute(0, 1, 3, 2, 4)
         .contiguous()
-        .view(E, N, K)
     )
+    if pad_N > 0 or pad_K > 0:
+        weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K)
+        weight_scaled = weight_scaled[..., :N, :K].contiguous()
+    else:
+        weight_scaled = weight_scaled.view(E, N, K)
+    return weight_scaled


 def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
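
For reference, here is a minimal, self-contained sketch (not part of the commit) of the pad-then-crop pattern the fix introduces. The block sizes, the expert/weight shapes (E, N, K), and the random tensors below are illustrative assumptions, chosen so that N and K are deliberately not multiples of 128:

# Sketch only: block-dequantize a weight whose N/K are not multiples of the
# block size by padding up to full blocks, scaling per block, then cropping.
import math
import torch

BLOCK_N, BLOCK_K = 128, 128
E, N, K = 2, 300, 200  # N and K are not multiples of 128 (assumed example sizes)

weight = torch.randn(E, N, K)
scales = torch.rand(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K))

pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K
if pad_N > 0 or pad_K > 0:
    # F.pad takes pads from the last dim inward: (K_left, K_right, N_left, N_right)
    weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))

# View as (E, n_blocks_N, BLOCK_N, n_blocks_K, BLOCK_K), apply one scale per block,
# then flatten back to the padded (E, N + pad_N, K + pad_K) layout.
weight_block = (
    weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
    .permute(0, 1, 3, 2, 4)
    .float()
    .contiguous()
)
weight_scaled = (
    (weight_block * scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1))
    .permute(0, 1, 3, 2, 4)
    .contiguous()
    .view(E, N + pad_N, K + pad_K)
)
# Crop the padding so the result matches the original (E, N, K).
weight_scaled = weight_scaled[..., :N, :K].contiguous()
assert weight_scaled.shape == (E, N, K)

When N and K are already multiples of 128, pad_N and pad_K are 0, the pad and crop are no-ops, and the computation reduces to the original code path.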