Fix FP8 block quantization when N or K is not multiples of 128 (#8648)

This commit is contained in:
YanbingJiang
2025-08-02 06:57:19 +08:00
committed by GitHub
parent e252192679
commit 1fe691a429
3 changed files with 39 additions and 18 deletions

View File

@@ -75,8 +75,8 @@ class TestFusedExperts(CustomTestCase):
topk_int8 = [3]
M_fp8 = [2, 121]
N_fp8 = [512]
K_fp8 = [256]
N_fp8 = [352, 512]
K_fp8 = [256, 320]
E_fp8 = [8]
topk_fp8 = [4]
@@ -201,8 +201,14 @@ class TestFusedExperts(CustomTestCase):
w2_fp32 = torch.randn(E, K, N)
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
w1s = (
torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
* factor_for_scale
)
w2s = (
torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
* factor_for_scale
)
w1_scaled = scaled_weight(w1, w1s)
w2_scaled = scaled_weight(w2, w2s)