# Provenance: added to python/sglang/srt/layers/quantization/fp8_utils.py by
# commit 5e02330137a1ce44f29cc41a4da5f010c4bffec6,
# "[perf] dsv3 bmm fallback to bf16 (#5662)".
# The same patch calls this helper from python/sglang/srt/models/deepseek_v2.py
# as block_quant_dequant(weight, weight_scale, weight_block_size, model_dtype)
# when the weight block size is 128x128, the model dtype is bfloat16, the code
# runs on CUDA, and the DeepGEMM BMM path (_ENABLE_JIT_DEEPGEMM together with
# the SGL_USE_DEEPGEMM_BMM env var) is not enabled.
def block_quant_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize a block-wise quantized 2-D tensor.

    The tensor is split into tiles of ``block_size[0]`` rows by
    ``block_size[1]`` columns; tiles on the bottom/right edge may be smaller
    when the shape is not a multiple of the block size. Every element is
    converted to float32, multiplied by the scale of the tile it belongs to,
    and the product is cast to ``dtype``.

    Args:
        x_q_block: Quantized tensor of shape ``(n, k)``. Any dtype that can be
            converted to float32 works (presumably FP8 in practice, given the
            module this lives in).
        x_s: Per-tile scales of shape
            ``(ceil(n / block_size[0]), ceil(k / block_size[1]))``.
        block_size: ``[block_n, block_k]``, the tile height and width.
        dtype: dtype of the returned tensor.

    Returns:
        A new ``(n, k)`` tensor of ``dtype`` holding the dequantized values.

    Raises:
        ValueError: If ``x_q_block`` is not 2-D (raised by the shape
            unpacking) or if ``x_s`` does not have exactly one scale per tile.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape

    # Ceil-division: a partially filled edge tile still owns a scale entry.
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k

    # Explicit raise instead of `assert` so the check survives `python -O`.
    if tuple(x_s.shape) != (n_tiles, k_tiles):
        raise ValueError(
            f"x_s has shape {tuple(x_s.shape)}, expected one scale per tile: "
            f"{(n_tiles, k_tiles)} for input shape {(n, k)} and "
            f"block size {(block_n, block_k)}"
        )

    # Expand each per-tile scale over the rows/columns of its tile, then trim
    # the overhang produced by partial edge tiles. This replaces a Python
    # loop over all n_tiles * k_tiles tiles with one broadcasted multiply, at
    # the cost of a temporary float32 scale tensor of shape (n, k).
    scale_full = (
        x_s.to(torch.float32)
        .repeat_interleave(block_n, dim=0)
        .repeat_interleave(block_k, dim=1)[:n, :k]
    )

    # Multiply in float32 and only then cast to the requested dtype.
    return (x_q_block.to(torch.float32) * scale_full).to(dtype)