Fix FP8 block quantization when N or K is not multiples of 128 (#8648)
This commit is contained in:
@@ -955,16 +955,16 @@ static inline void check_moe_scales(
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \
|
||||
auto w1s = w1_scale.value(); \
|
||||
auto w2s = w2_scale.value(); \
|
||||
auto block_size_val = block_size.value(); \
|
||||
int64_t block_size_N = block_size_val[0]; \
|
||||
int64_t block_size_K = block_size_val[1]; \
|
||||
TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \
|
||||
TORCH_CHECK(w1s.size(DIM1) == K / block_size_K); \
|
||||
TORCH_CHECK(w2s.size(DIM0) == K / block_size_N); \
|
||||
TORCH_CHECK(w2s.size(DIM1) == N / block_size_K)
|
||||
#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \
|
||||
auto w1s = w1_scale.value(); \
|
||||
auto w2s = w2_scale.value(); \
|
||||
auto block_size_val = block_size.value(); \
|
||||
int64_t block_size_N = block_size_val[0]; \
|
||||
int64_t block_size_K = block_size_val[1]; \
|
||||
TORCH_CHECK(w1s.size(DIM0) == div_up(2 * N, block_size_N)); \
|
||||
TORCH_CHECK(w1s.size(DIM1) == div_up(K, block_size_K)); \
|
||||
TORCH_CHECK(w2s.size(DIM0) == div_up(K, block_size_N)); \
|
||||
TORCH_CHECK(w2s.size(DIM1) == div_up(N, block_size_K))
|
||||
|
||||
// hidden_states: [M, K]
|
||||
// w1: [E, 2N, K]
|
||||
|
||||
@@ -75,8 +75,8 @@ class TestFusedExperts(CustomTestCase):
|
||||
topk_int8 = [3]
|
||||
|
||||
M_fp8 = [2, 121]
|
||||
N_fp8 = [512]
|
||||
K_fp8 = [256]
|
||||
N_fp8 = [352, 512]
|
||||
K_fp8 = [256, 320]
|
||||
E_fp8 = [8]
|
||||
topk_fp8 = [4]
|
||||
|
||||
@@ -201,8 +201,14 @@ class TestFusedExperts(CustomTestCase):
|
||||
w2_fp32 = torch.randn(E, K, N)
|
||||
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
|
||||
w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
|
||||
w1s = (
|
||||
torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
|
||||
* factor_for_scale
|
||||
)
|
||||
w2s = (
|
||||
torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
|
||||
* factor_for_scale
|
||||
)
|
||||
|
||||
w1_scaled = scaled_weight(w1, w1s)
|
||||
w2_scaled = scaled_weight(w2, w2s)
|
||||
|
||||
@@ -136,18 +136,33 @@ def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_facto
|
||||
|
||||
def scaled_weight(weight, scales):
|
||||
E, N, K = weight.shape
|
||||
pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
|
||||
pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K
|
||||
|
||||
if pad_N > 0 or pad_K > 0:
|
||||
weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
|
||||
|
||||
weight_block = (
|
||||
weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K)
|
||||
weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
|
||||
.permute(0, 1, 3, 2, 4)
|
||||
.float()
|
||||
.contiguous()
|
||||
)
|
||||
return (
|
||||
(weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1))
|
||||
|
||||
weight_scaled = (
|
||||
(
|
||||
weight_block
|
||||
* scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1)
|
||||
)
|
||||
.permute(0, 1, 3, 2, 4)
|
||||
.contiguous()
|
||||
.view(E, N, K)
|
||||
)
|
||||
if pad_N > 0 or pad_K > 0:
|
||||
weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K)
|
||||
weight_scaled = weight_scaled[..., :N, :K].contiguous()
|
||||
else:
|
||||
weight_scaled = weight_scaled.view(E, N, K)
|
||||
return weight_scaled
|
||||
|
||||
|
||||
def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
|
||||
|
||||
Reference in New Issue
Block a user