Fix FP8 block quantization when N or K is not multiples of 128 (#8648)
This commit is contained in:
@@ -136,18 +136,33 @@ def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_facto
|
||||
|
||||
def scaled_weight(weight, scales):
|
||||
E, N, K = weight.shape
|
||||
pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
|
||||
pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K
|
||||
|
||||
if pad_N > 0 or pad_K > 0:
|
||||
weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
|
||||
|
||||
weight_block = (
|
||||
weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K)
|
||||
weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
|
||||
.permute(0, 1, 3, 2, 4)
|
||||
.float()
|
||||
.contiguous()
|
||||
)
|
||||
return (
|
||||
(weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1))
|
||||
|
||||
weight_scaled = (
|
||||
(
|
||||
weight_block
|
||||
* scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1)
|
||||
)
|
||||
.permute(0, 1, 3, 2, 4)
|
||||
.contiguous()
|
||||
.view(E, N, K)
|
||||
)
|
||||
if pad_N > 0 or pad_K > 0:
|
||||
weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K)
|
||||
weight_scaled = weight_scaled[..., :N, :K].contiguous()
|
||||
else:
|
||||
weight_scaled = weight_scaled.view(E, N, K)
|
||||
return weight_scaled
|
||||
|
||||
|
||||
def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
|
||||
|
||||
Reference in New Issue
Block a user