diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp
index f755f8f08..88d84c830 100644
--- a/sgl-kernel/csrc/cpu/moe.cpp
+++ b/sgl-kernel/csrc/cpu/moe.cpp
@@ -955,16 +955,16 @@ static inline void check_moe_scales(
   }
 }
 
-#define CHECK_MOE_SCALES_FP8(DIM0, DIM1)               \
-  auto w1s = w1_scale.value();                         \
-  auto w2s = w2_scale.value();                         \
-  auto block_size_val = block_size.value();            \
-  int64_t block_size_N = block_size_val[0];            \
-  int64_t block_size_K = block_size_val[1];            \
-  TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \
-  TORCH_CHECK(w1s.size(DIM1) == K / block_size_K);     \
-  TORCH_CHECK(w2s.size(DIM0) == K / block_size_N);     \
-  TORCH_CHECK(w2s.size(DIM1) == N / block_size_K)
+#define CHECK_MOE_SCALES_FP8(DIM0, DIM1)                      \
+  auto w1s = w1_scale.value();                                \
+  auto w2s = w2_scale.value();                                \
+  auto block_size_val = block_size.value();                   \
+  int64_t block_size_N = block_size_val[0];                   \
+  int64_t block_size_K = block_size_val[1];                   \
+  TORCH_CHECK(w1s.size(DIM0) == div_up(2 * N, block_size_N)); \
+  TORCH_CHECK(w1s.size(DIM1) == div_up(K, block_size_K));     \
+  TORCH_CHECK(w2s.size(DIM0) == div_up(K, block_size_N));     \
+  TORCH_CHECK(w2s.size(DIM1) == div_up(N, block_size_K))
 
 // hidden_states: [M, K]
 // w1: [E, 2N, K]
diff --git a/test/srt/cpu/test_moe.py b/test/srt/cpu/test_moe.py
index 442a5857c..96eb28020 100644
--- a/test/srt/cpu/test_moe.py
+++ b/test/srt/cpu/test_moe.py
@@ -75,8 +75,8 @@ class TestFusedExperts(CustomTestCase):
     topk_int8 = [3]
 
     M_fp8 = [2, 121]
-    N_fp8 = [512]
-    K_fp8 = [256]
+    N_fp8 = [352, 512]
+    K_fp8 = [256, 320]
     E_fp8 = [8]
     topk_fp8 = [4]
 
@@ -201,8 +201,14 @@ class TestFusedExperts(CustomTestCase):
         w2_fp32 = torch.randn(E, K, N)
         w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
-        w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
-        w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
+        w1s = (
+            torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
+            * factor_for_scale
+        )
+        w2s = (
+            torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
+            * factor_for_scale
+        )
 
         w1_scaled = scaled_weight(w1, w1s)
         w2_scaled = scaled_weight(w2, w2s)
diff --git a/test/srt/cpu/utils.py b/test/srt/cpu/utils.py
--- a/test/srt/cpu/utils.py
+++ b/test/srt/cpu/utils.py
@@ -136,18 +136,38 @@ def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_facto
 
 def scaled_weight(weight, scales):
+    """Multiply a block-quantized weight by its per-block scales.
+
+    ``weight`` is ``[E, N, K]``; ``scales`` carries one value per
+    ``BLOCK_N x BLOCK_K`` tile, including partial tiles at the edges.
+    Ragged dimensions are zero-padded for the tiling and cropped again,
+    so the float32 result always has shape ``[E, N, K]``.
+    """
     E, N, K = weight.shape
+    # Integer ceil-division: tile counts, including partial edge tiles.
+    n_blocks = (N + BLOCK_N - 1) // BLOCK_N
+    k_blocks = (K + BLOCK_K - 1) // BLOCK_K
+    pad_N = n_blocks * BLOCK_N - N
+    pad_K = k_blocks * BLOCK_K - K
+
+    if pad_N > 0 or pad_K > 0:
+        # Pad tuple is last-dim first: (K_before, K_after, N_before, N_after).
+        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
+
     weight_block = (
-        weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K)
+        weight.view(E, n_blocks, BLOCK_N, k_blocks, BLOCK_K)
         .permute(0, 1, 3, 2, 4)
         .float()
         .contiguous()
     )
-    return (
-        (weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1))
+
+    weight_scaled = (
+        (weight_block * scales.view(E, n_blocks, k_blocks, 1, 1))
         .permute(0, 1, 3, 2, 4)
         .contiguous()
-        .view(E, N, K)
+        .view(E, N + pad_N, K + pad_K)
     )
+    # Crop the padding away; this is a no-op slice when nothing was padded.
+    return weight_scaled[..., :N, :K].contiguous()
 
 
 def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):