From b4c9f38a76bd42dfe5bfa64ffaaa31cfca7745e2 Mon Sep 17 00:00:00 2001
From: Kaixi Hou
Date: Fri, 8 Aug 2025 01:12:33 -0700
Subject: [PATCH] [NVIDIA] Fix missing `get_col_major_tma_aligned_tensor` for Blackwell deepgemm in EpMoE (#8955)

---
 python/sglang/srt/layers/moe/ep_moe/layer.py | 53 ++++++++++++++++++--
 1 file changed, 49 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 1dd097b4e..464d5c938 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -55,6 +55,22 @@ if _use_aiter:
 logger = logging.getLogger(__name__)
 
 
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
@@ -204,10 +220,22 @@ class EPMoE(FusedMoE):
 
         dispose_tensor(hidden_states)
 
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = gateup_input_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
         # GroupGemm-0
         gateup_input_fp8 = (
             gateup_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+            (
+                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    gateup_input_scale
+                )
+            ),
         )
         num_groups, m, k = gateup_input_fp8[0].size()
         n = self.w13_weight.size(1)
@@ -215,7 +243,12 @@ class EPMoE(FusedMoE):
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+            gateup_input_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -246,6 +279,7 @@ class EPMoE(FusedMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output
 
@@ -253,13 +287,24 @@ class EPMoE(FusedMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
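
Note (illustration only, not part of the patch above): the `_cast_to_e8m0_with_rounding_up` helper introduced in the first hunk rounds each FP32 scale up to the next power of two and keeps only its 8-bit biased exponent, which is what a UE8M0 scale stores. The standalone sketch below demonstrates just that rounding rule; the sample values and the `reference_ue8m0` helper are hypothetical (not from sglang), and the sketch omits the overflow/subnormal guards and the uint8-to-int32 column-major repacking that the real helper performs.

import torch


def reference_ue8m0(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference: biased exponent of the smallest power of two >= x.
    return (torch.ceil(torch.log2(x.to(torch.float32))) + 127).to(torch.uint8)


# Sample positive FP32 block scales (illustrative values only).
scales = torch.tensor([0.75, 1.0, 1.5, 2.0, 3.2], dtype=torch.float32)

bits = scales.view(torch.int32)            # reinterpret the FP32 bit patterns
exp = bits >> 23                           # biased exponent field
mant = bits & 0x7FFFFF                     # 23-bit mantissa field
exp = torch.where(mant > 0, exp + 1, exp)  # round up unless already a power of two

print(exp.to(torch.uint8))      # tensor([127, 127, 128, 128, 129], dtype=torch.uint8)
print(reference_ue8m0(scales))  # agrees for normal, positive scales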