Minor speed-up of block_quant_dequant (#6814)
This commit is contained in:
@@ -369,27 +369,15 @@ def block_quant_dequant(
|
|||||||
The output is an unquantized tensor with dtype.
|
The output is an unquantized tensor with dtype.
|
||||||
"""
|
"""
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
n, k = x_q_block.shape
|
*_, n, k = x_q_block.shape
|
||||||
n_tiles = (n + block_n - 1) // block_n
|
|
||||||
k_tiles = (k + block_k - 1) // block_k
|
|
||||||
assert n_tiles == x_s.shape[0]
|
|
||||||
assert k_tiles == x_s.shape[1]
|
|
||||||
|
|
||||||
x_dq_block = torch.empty_like(x_q_block, dtype=dtype)
|
# ... n_scale k_scale -> ... (n_scale block_n) (k_scale block_k)
|
||||||
|
x_scale_repeat = x_s.repeat_interleave(block_n, dim=-2).repeat_interleave(
|
||||||
|
block_k, dim=-1
|
||||||
|
)
|
||||||
|
x_scale_repeat = x_scale_repeat[..., :n, :k]
|
||||||
|
|
||||||
for j in range(n_tiles):
|
return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
|
||||||
for i in range(k_tiles):
|
|
||||||
x_q_block_tile = x_q_block[
|
|
||||||
j * block_n : min((j + 1) * block_n, n),
|
|
||||||
i * block_k : min((i + 1) * block_k, k),
|
|
||||||
]
|
|
||||||
x_dq_block_tile = x_dq_block[
|
|
||||||
j * block_n : min((j + 1) * block_n, n),
|
|
||||||
i * block_k : min((i + 1) * block_k, k),
|
|
||||||
]
|
|
||||||
x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
|
|
||||||
|
|
||||||
return x_dq_block
|
|
||||||
|
|
||||||
|
|
||||||
def channel_quant_to_tensor_quant(
|
def channel_quant_to_tensor_quant(
|
||||||
|
|||||||
Reference in New Issue
Block a user