From e3ec6bf4b65a50e26e936a96adc7acc618292002 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 14 Jun 2025 05:32:46 +0800 Subject: [PATCH] Minor speed up block_quant_dequant (#6814) --- .../srt/layers/quantization/fp8_utils.py | 26 +++++-------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 0e1640fcf..86d8155f8 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -369,27 +369,15 @@ def block_quant_dequant( The output is an unquantized tensor with dtype. """ block_n, block_k = block_size[0], block_size[1] - n, k = x_q_block.shape - n_tiles = (n + block_n - 1) // block_n - k_tiles = (k + block_k - 1) // block_k - assert n_tiles == x_s.shape[0] - assert k_tiles == x_s.shape[1] + *_, n, k = x_q_block.shape - x_dq_block = torch.empty_like(x_q_block, dtype=dtype) + # ... n_scale k_scale -> ... (n_scale block_n) (k_scale block_k) + x_scale_repeat = x_s.repeat_interleave(block_n, dim=-2).repeat_interleave( + block_k, dim=-1 + ) + x_scale_repeat = x_scale_repeat[..., :n, :k] - for j in range(n_tiles): - for i in range(k_tiles): - x_q_block_tile = x_q_block[ - j * block_n : min((j + 1) * block_n, n), - i * block_k : min((i + 1) * block_k, k), - ] - x_dq_block_tile = x_dq_block[ - j * block_n : min((j + 1) * block_n, n), - i * block_k : min((i + 1) * block_k, k), - ] - x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i] - - return x_dq_block + return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype) def channel_quant_to_tensor_quant(