Fix FP8 block quantization when N or K is not a multiple of 128 (#8648)
@@ -75,8 +75,8 @@ class TestFusedExperts(CustomTestCase):
     topk_int8 = [3]
 
     M_fp8 = [2, 121]
-    N_fp8 = [512]
-    K_fp8 = [256]
+    N_fp8 = [352, 512]
+    K_fp8 = [256, 320]
     E_fp8 = [8]
     topk_fp8 = [4]
 
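The two new sizes are the point of the change: with the 128x128 quantization blocks the title implies, 352 = 2 * 128 + 96 and 320 = 2 * 128 + 64 leave a partial block at the end of the N and K dimensions, while the original 512 and 256 sit exactly on the 128 grid and never exercised that path.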
@@ -201,8 +201,14 @@ class TestFusedExperts(CustomTestCase):
         w2_fp32 = torch.randn(E, K, N)
         w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
-        w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
-        w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
+        w1s = (
+            torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
+            * factor_for_scale
+        )
+        w2s = (
+            torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
+            * factor_for_scale
+        )
 
         w1_scaled = scaled_weight(w1, w1s)
         w2_scaled = scaled_weight(w2, w2s)
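A minimal sketch of what the switch to math.ceil buys, using the sizes from the new test matrix (BLOCK_N = BLOCK_K = 128 is an assumption; the hunk does not show the block constants):

import math

import torch

E, N, K = 8, 352, 320            # sizes from the new test matrix
BLOCK_N = BLOCK_K = 128          # assumed quantization tile shape

# One scale per (BLOCK_N x BLOCK_K) tile of w1, which has shape (E, 2N, K).
# Floor division silently drops the trailing partial tiles; ceiling keeps them.
floor_shape = (2 * N // BLOCK_N, K // BLOCK_K)                     # (5, 2)
ceil_shape = (math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))  # (6, 3)

w1s = torch.randn(E, *ceil_shape)  # covers the partial tiles at the ragged edge

With floor division, w1s would have no scale entries for the last 64 rows and 64 columns of each expert's weight, so any code that walks ceil-many blocks per dimension would index past the scale tensor.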
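scaled_weight itself is outside the hunk; a dequant helper compatible with the ceil-shaped scales could look like this sketch (the signature and the repeat-then-trim strategy are assumptions, not the repository's implementation):

import torch

def scaled_weight(w: torch.Tensor, scales: torch.Tensor,
                  block_n: int = 128, block_k: int = 128) -> torch.Tensor:
    # Hypothetical per-block dequant: broadcast one scale over each
    # (block_n x block_k) tile, then trim the overhang that ceiling
    # division adds past a ragged edge.
    e, n, k = w.shape
    s = scales.repeat_interleave(block_n, dim=1).repeat_interleave(block_k, dim=2)
    return w.to(torch.float32) * s[:, :n, :k]

Trimming after the expansion is what makes non-multiple shapes safe: every element of w still finds a scale, and the unused tail of the last tile's scale is simply discarded.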