From 5ca07eed90185816336d852b7652ff1b828b9896 Mon Sep 17 00:00:00 2001 From: JieXin Liang Date: Tue, 17 Jun 2025 02:45:54 +0800 Subject: [PATCH] [fix] fix DeepGEMM blackwell input quant & ut & fix style and log (#7247) --- python/sglang/srt/layers/moe/ep_moe/layer.py | 4 +- .../srt/layers/moe/ep_moe/token_dispatcher.py | 4 +- .../deep_gemm_wrapper/compile_utils.py | 15 +- .../deep_gemm_wrapper/configurer.py | 14 +- .../deep_gemm_wrapper/entrypoint.py | 6 +- .../srt/layers/quantization/fp8_kernel.py | 7 +- .../srt/layers/quantization/fp8_utils.py | 7 +- python/sglang/srt/models/deepseek_v2.py | 6 +- python/sglang/test/test_block_fp8.py | 1 + .../test_block_fp8_deep_gemm_blackwell.py | 252 ++++++++++++++++++ 10 files changed, 285 insertions(+), 31 deletions(-) create mode 100644 python/sglang/test/test_block_fp8_deep_gemm_blackwell.py diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 50bbf94c3..4e1e97713 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1201,7 +1201,7 @@ class DeepEPMoE(EPMoE): gateup_output, masked_m, expected_m, - recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None, + recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None, ) dispose_tensor(hidden_states_fp8[0]) @@ -1256,7 +1256,7 @@ class DeepEPMoE(EPMoE): down_output, masked_m, expected_m, - recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None, + recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None, ) return down_output diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index f1f9dbeb2..091e9ec69 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -553,9 +553,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): async_finish=not self.return_recv_hook, return_recv_hook=self.return_recv_hook, round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM - and deep_gemm_wrapper.DEEPGEMM_V202506, + and deep_gemm_wrapper.DEEPGEMM_BLACKWELL, use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM - and deep_gemm_wrapper.DEEPGEMM_V202506, + and deep_gemm_wrapper.DEEPGEMM_BLACKWELL, ) ) return packed_recv_hidden, packed_recv_count, event, hook diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py index 75ebd9298..8949e3334 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py @@ -8,7 +8,7 @@ from typing import Callable, Dict, List, Optional, Tuple from tqdm.contrib.concurrent import thread_map from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( - DEEPGEMM_V202506, + DEEPGEMM_BLACKWELL, ENABLE_JIT_DEEPGEMM, ) from sglang.srt.server_args import ServerArgs @@ -16,13 +16,11 @@ from sglang.srt.utils import get_bool_env_var, get_int_env_var logger = logging.getLogger(__name__) -try: +if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL: from deep_gemm import get_num_sms from deep_gemm.jit import build from deep_gemm.jit_kernels.gemm import get_best_configs from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType -except ImportError: - pass _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) @@ -313,7 +311,8 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType): ret = origin_func(self, *args, **kwargs) if ret is None: kernel_helper = _KERNEL_HELPER_DICT[kernel_type] - _compile_warning_2() + if not DEEPGEMM_BLACKWELL: + _compile_warning_2() logger.warning( f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait." ) @@ -329,10 +328,8 @@ def deep_gemm_execution_hook( m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType ): # not supported yet - if DEEPGEMM_V202506: - yield - return + if not DEEPGEMM_BLACKWELL: + _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups) - _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups) with _log_jit_build(m, n, k, kernel_type): yield diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py index adf52b2f1..4288fff6e 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py @@ -6,16 +6,16 @@ logger = logging.getLogger(__name__) def _compute_enable_deep_gemm(): + sm_version = get_device_sm() + if sm_version < 90: + return False + try: import deep_gemm except ImportError: logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.") return False - sm_version = get_device_sm() - if sm_version < 90: - return False - return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true") @@ -25,8 +25,8 @@ try: from deep_gemm import fp8_gemm_nt # They have not given a name to this breaking change - DEEPGEMM_V202506 = True + DEEPGEMM_BLACKWELL = True except ImportError: - DEEPGEMM_V202506 = False + DEEPGEMM_BLACKWELL = False -DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506 +DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py index 582fcc9b4..9dad33f9e 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py @@ -6,8 +6,8 @@ import torch from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( + DEEPGEMM_BLACKWELL, DEEPGEMM_SCALE_UE8M0, - DEEPGEMM_V202506, ENABLE_JIT_DEEPGEMM, ) from sglang.srt.server_args import ServerArgs @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) if ENABLE_JIT_DEEPGEMM: import deep_gemm - if DEEPGEMM_V202506: + if DEEPGEMM_BLACKWELL: from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw from deep_gemm import ( fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw, @@ -57,7 +57,7 @@ def grouped_gemm_nt_f8f8bf16_masked( out, masked_m, expected_m, - **({"recipe": recipe} if DEEPGEMM_V202506 else {}) + **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {}) ) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index fed587ba1..612b9d1bb 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -290,11 +290,12 @@ def sglang_per_token_group_quant_fp8( x_s_mn, x_s_k = x_q_mn, x_q_k // 128 aligned_mn = align(x_s_mn, 4) aligned_k = align(x_s_k, 4) - x_s = torch.empty( + # TODO(FIXME): Fix cuda kernel and recover here to empty. + x_s = torch.zeros( (aligned_k // 4, aligned_mn), device=x.device, dtype=torch.int, - ).permute(-1, -2)[:x_s_mn, :] + ).transpose(0, 1)[:x_s_mn, :] elif column_major_scales: if scale_tma_aligned: # TODO extract "align" function @@ -768,7 +769,7 @@ def prepare_block_fp8_matmul_inputs( if As.dtype == torch.float: assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] - elif Bs.dtype == torch.int: + elif As.dtype == torch.int: assert ( triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1] ), f"{A.shape=} {As.shape=} {block_size=}" diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 6a40a7e9d..9b401a4ee 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -241,9 +241,10 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback( scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, ) - if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"): - _check_ue8m0("x_scale", x_scale) - _check_ue8m0("weight_scale", weight_scale) + # NOTE(alcanderian): Useless when scale is packed to int32 + # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"): + # _check_ue8m0("x_scale", x_scale) + # _check_ue8m0("weight_scale", ws) output = w8a8_block_fp8_matmul_deepgemm( q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0ebc53442..453a2f393 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1829,8 +1829,10 @@ class DeepseekV2ForCausalLM(nn.Module): and weight_block_size[1] == 128 and model_dtype == torch.bfloat16 ): - if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_bool_env_var( - "SGL_USE_DEEPGEMM_BMM", "false" + if ( + deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL + and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") ): block_scale = weight_scale use_deep_gemm_bmm = True diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index b331f5a87..a5a338632 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -343,6 +343,7 @@ class TestW8A8BlockFP8Matmul(CustomTestCase): OUT_DTYPES = [torch.bfloat16] M = [64, 128, 512, 1024, 4096] NKs = [ + (2112, 7168), (1536, 7168), (3072, 1536), (24576, 7168), diff --git a/python/sglang/test/test_block_fp8_deep_gemm_blackwell.py b/python/sglang/test/test_block_fp8_deep_gemm_blackwell.py new file mode 100644 index 000000000..36d7acddb --- /dev/null +++ b/python/sglang/test/test_block_fp8_deep_gemm_blackwell.py @@ -0,0 +1,252 @@ +import itertools +import os +import unittest +from typing import List, Tuple + +import torch +from deep_gemm import fp8_gemm_nt + +from sglang.test.test_utils import CustomTestCase + +_is_cuda = torch.cuda.is_available() and torch.version.cuda + + +# Modify form DeepGEMM Blackwell +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def per_token_group_quant_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = x_amax / 448.0 + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_quant_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( + x_view.size(0), x_view.size(2) + ) + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_token_group_quant_mxfp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_quant_mxfp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( + x_view.size(0), x_view.size(2) + ) + + +# For test +def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): + """This function performs matrix multiplication with block-wise quantization using native torch. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + """ + + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N,) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [ + [ + B[ + j * block_n : min((j + 1) * block_n, N), + i * block_k : min((i + 1) * block_k, K), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +def block_quant_dequant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, + block_size: List[int], + dtype: torch.dtype, +) -> torch.Tensor: + """This function converts block-wise quantization to unquantized. + The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale + and the block size. + The output is an unquantized tensor with dtype. + """ + block_n, block_k = block_size[0], block_size[1] + n, k = x_q_block.shape + n_tiles = (n + block_n - 1) // block_n + k_tiles = (k + block_k - 1) // block_k + assert n_tiles == x_s.shape[0] + assert k_tiles == x_s.shape[1] + + x_dq_block = torch.empty_like(x_q_block, dtype=dtype) + + for j in range(n_tiles): + for i in range(k_tiles): + x_q_block_tile = x_q_block[ + j * block_n : min((j + 1) * block_n, n), + i * block_k : min((i + 1) * block_k, k), + ] + x_dq_block_tile = x_dq_block[ + j * block_n : min((j + 1) * block_n, n), + i * block_k : min((i + 1) * block_k, k), + ] + x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i] + + return x_dq_block + + +class TestDeepGemmBlackwell(CustomTestCase): + + if not _is_cuda: + OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16] + M = [1, 7, 83, 512, 2048] + NKs = [ + (N, K) + for N in [128, 512, 1024, 4096, 7748, 13824] + for K in [256, 4096, 5120, 3884, 13824] + ] + # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] + BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + else: + # use practical shape in DeepSeek V3 for test + OUT_DTYPES = [torch.bfloat16] + M = [64, 128, 512, 1024, 4096] + NKs = [ + (2112, 7168), + (1536, 7168), + # (3072, 1536), + # (24576, 7168), + # (4096, 512), + # (7168, 2048), + # (4608, 7168), + # (512, 7168), + # (7168, 2304), + # (7168, 512), + ] + BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _test_deep_gemm_blackwell(self, M, NK, block_size, out_dtype, seed): + N, K = NK + torch.manual_seed(seed) + + A = torch.empty((M, K), dtype=torch.bfloat16).normal_(0, 0.2) + B = torch.empty((N, K), dtype=torch.bfloat16).normal_(0, 0.2) + + A_q, A_s = per_token_group_quant_fp8(A) + B_q, B_s = per_block_quant_fp8(B) + + A_dq = block_quant_dequant(A_q, A_s, [1, block_size[1]], out_dtype) + B_dq = block_quant_dequant(B_q, B_s, block_size, out_dtype) + + A_qu = per_token_group_quant_mxfp8(A_dq) + B_qu = per_block_quant_mxfp8(B_dq) + out = None + + with torch.inference_mode(): + ref_out = native_w8a8_block_fp8_matmul( + A_q, B_q, A_s, B_s, block_size, out_dtype + ) + out = torch.empty_like(ref_out) + fp8_gemm_nt(A_qu, B_qu, out) + + torch.testing.assert_close(out, ref_out, atol=1e-1, rtol=1e-2) + + def test_deep_gemm_blackwell(self): + for params in itertools.product( + self.M, + self.NKs, + self.BLOCK_SIZE, + self.OUT_DTYPES, + self.SEEDS, + ): + with self.subTest( + M=params[0], + NKs=params[1], + block_size=params[2], + out_dtype=params[3], + seed=params[4], + ): + self._test_deep_gemm_blackwell(*params) + + +if __name__ == "__main__": + unittest.main(verbosity=2)