From 8b8f2e74630ce76024dfe8b41c9eb67146e659a9 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 14 Jun 2025 11:40:24 +0800 Subject: [PATCH] Support new DeepGEMM input format in silu_and_mul_masked_post_quant_fwd (#7153) --- python/sglang/srt/layers/moe/ep_moe/kernels.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index dd2ecd8f2..cde2cf14a 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -278,6 +278,7 @@ def _silu_and_mul_post_quant_kernel( fp8_min, BLOCK_N: tl.constexpr, NUM_STAGE: tl.constexpr, + SCALE_UE8M0: tl.constexpr, ): expert_id = tl.program_id(2) token_id = tl.program_id(1) @@ -319,6 +320,8 @@ def _silu_and_mul_post_quant_kernel( gate_up = up * gate _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) output_s = _absmax / fp8_max + if SCALE_UE8M0: + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to( output_ptr.dtype.element_ty ) @@ -339,6 +342,7 @@ def silu_and_mul_masked_post_quant_fwd( output_scale: torch.Tensor, quant_group_size: int, masked_m: torch.Tensor, + scale_ue8m0: bool = False, ): """ input shape [expert_num, token_num_padded, hidden_dim] @@ -395,6 +399,7 @@ def silu_and_mul_masked_post_quant_fwd( BLOCK_N=BLOCK_N, NUM_STAGE=NUM_STAGES, num_warps=num_warps, + SCALE_UE8M0=scale_ue8m0, ) return