Fix FP8 block quantization when N or K is not a multiple of 128 (#8648)

This commit is contained in:
YanbingJiang
2025-08-02 06:57:19 +08:00
committed by GitHub
parent e252192679
commit 1fe691a429
3 changed files with 39 additions and 18 deletions

View File

@@ -955,16 +955,16 @@ static inline void check_moe_scales(
}
}
// NOTE(review): "before" image from the diff. Validates the shapes of the
// per-block FP8 scale tensors (w1_scale: [2N/bN, K/bK], w2_scale: [K/bN, N/bK]
// along DIM0/DIM1) against the expected block counts. This version uses plain
// integer division, which truncates — so when N or K is not an exact multiple
// of the block size, the expected count is one less than the actual scale
// tensor dimension (which is sized by ceiling division), and the TORCH_CHECKs
// fire spuriously. Fixed by the div_up version later in this hunk.
// Comments must stay outside the macro: a // comment before a trailing
// backslash would absorb the continuation and break the definition.
#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \
auto w1s = w1_scale.value(); \
auto w2s = w2_scale.value(); \
auto block_size_val = block_size.value(); \
int64_t block_size_N = block_size_val[0]; \
int64_t block_size_K = block_size_val[1]; \
TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \
TORCH_CHECK(w1s.size(DIM1) == K / block_size_K); \
TORCH_CHECK(w2s.size(DIM0) == K / block_size_N); \
TORCH_CHECK(w2s.size(DIM1) == N / block_size_K)
// "After" image from the diff: same shape validation as above, but the
// expected block counts are computed with div_up (presumably ceiling
// division — defined elsewhere in the project; confirm), so scale tensors
// for N or K that are not exact multiples of the block size (e.g. K = 130,
// block_size_K = 128 -> 2 blocks, not 1) pass the TORCH_CHECKs.
// Declares w1s/w2s/block_size_N/block_size_K into the enclosing scope —
// callers rely on these names after expanding the macro.
// Comments must stay outside the macro body: // before a trailing backslash
// would swallow the line continuation and break the definition.
#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \
auto w1s = w1_scale.value(); \
auto w2s = w2_scale.value(); \
auto block_size_val = block_size.value(); \
int64_t block_size_N = block_size_val[0]; \
int64_t block_size_K = block_size_val[1]; \
TORCH_CHECK(w1s.size(DIM0) == div_up(2 * N, block_size_N)); \
TORCH_CHECK(w1s.size(DIM1) == div_up(K, block_size_K)); \
TORCH_CHECK(w2s.size(DIM0) == div_up(K, block_size_N)); \
TORCH_CHECK(w2s.size(DIM1) == div_up(N, block_size_K))
// hidden_states: [M, K]
// w1: [E, 2N, K]