Fix FP8 block quantization when N or K is not a multiple of 128 (#8648)
This commit is contained in:
@@ -955,16 +955,16 @@ static inline void check_moe_scales(
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK_MOE_SCALES_FP8(DIM0, DIM1): validates the extents of the FP8 block-quantization
// scale tensors along dimensions DIM0 and DIM1.
//
// Expands to plain statements (no enclosing scope), so it declares the locals
// w1s, w2s, block_size_val, block_size_N and block_size_K at the expansion site,
// and it expects w1_scale, w2_scale, block_size, N and K to be visible there.
// The final TORCH_CHECK has no trailing semicolon; the caller supplies it.
//
// Expected extents are computed with plain `/`:
//   w1s.size(DIM0) == 2 * N / block_size_N,  w1s.size(DIM1) == K / block_size_K
//   w2s.size(DIM0) == K / block_size_N,      w2s.size(DIM1) == N / block_size_K
// NOTE(review): if N and K are integers, `/` truncates, so an extent that is not a
// multiple of the block size yields a smaller expected count than a rounded-up
// block count would -- confirm against the declared types of N and K.
#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \
|
||||
auto w1s = w1_scale.value(); \
|
||||
auto w2s = w2_scale.value(); \
|
||||
auto block_size_val = block_size.value(); \
|
||||
int64_t block_size_N = block_size_val[0]; \
|
||||
int64_t block_size_K = block_size_val[1]; \
|
||||
TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \
|
||||
TORCH_CHECK(w1s.size(DIM1) == K / block_size_K); \
|
||||
TORCH_CHECK(w2s.size(DIM0) == K / block_size_N); \
|
||||
TORCH_CHECK(w2s.size(DIM1) == N / block_size_K)
|
||||
// CHECK_MOE_SCALES_FP8(DIM0, DIM1): validates the extents of the FP8 block-quantization
// scale tensors along dimensions DIM0 and DIM1.
//
// Expands to plain statements (no enclosing scope), so it declares the locals
// w1s, w2s, block_size_val, block_size_N and block_size_K at the expansion site,
// and it expects w1_scale, w2_scale, block_size, N, K and div_up to be visible there.
// The final TORCH_CHECK has no trailing semicolon; the caller supplies it.
//
// block_size_val[0] is used as the block size along N, block_size_val[1] along K.
// Expected extents:
//   w1s.size(DIM0) == div_up(2 * N, block_size_N),  w1s.size(DIM1) == div_up(K, block_size_K)
//   w2s.size(DIM0) == div_up(K, block_size_N),      w2s.size(DIM1) == div_up(N, block_size_K)
// NOTE(review): div_up is defined outside this view; presumably a rounding-up integer
// division so that a partial trailing block still gets its own scale -- confirm.
#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \
|
||||
auto w1s = w1_scale.value(); \
|
||||
auto w2s = w2_scale.value(); \
|
||||
auto block_size_val = block_size.value(); \
|
||||
int64_t block_size_N = block_size_val[0]; \
|
||||
int64_t block_size_K = block_size_val[1]; \
|
||||
TORCH_CHECK(w1s.size(DIM0) == div_up(2 * N, block_size_N)); \
|
||||
TORCH_CHECK(w1s.size(DIM1) == div_up(K, block_size_K)); \
|
||||
TORCH_CHECK(w2s.size(DIM0) == div_up(K, block_size_N)); \
|
||||
TORCH_CHECK(w2s.size(DIM1) == div_up(N, block_size_K))
|
||||
|
||||
// hidden_states: [M, K]
|
||||
// w1: [E, 2N, K]
|
||||
|
||||
Reference in New Issue
Block a user