From 64c871357359c1251255166dd9b073eeb0f1a8bb Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Mon, 10 Feb 2025 01:18:57 +0800
Subject: [PATCH] remove activation dependency in fused_moe (#3433)

---
 .../layers/moe/fused_moe_triton/fused_moe.py  | 21 ++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 76cee9991..4cebe46ea 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 
 is_hip_flag = is_hip()
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -27,6 +27,15 @@ enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, silu_and_mul
+
+if _is_cuda or _is_rocm:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
 
 @triton.jit
 def fused_moe_kernel(
@@ -989,9 +998,15 @@ def fused_experts_impl(
         )
 
         if activation == "silu":
-            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         elif activation == "gelu":
-            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         else:
             raise ValueError(f"Unsupported activation: {activation=}")
 
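
Reviewer note: the two backends take their arguments in opposite order, which is why the operands swap between the CUDA branch (sgl_kernel) and the vllm fallback. Below is a minimal pure-PyTorch sketch of the assumed silu_and_mul semantics (split the last dimension in half, apply SiLU to the first half, multiply by the second half); silu_and_mul_ref is a hypothetical helper for illustration only and is not part of this patch.

import torch
import torch.nn.functional as F


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Gated activation: silu(first half of last dim) * second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(4, 2 * 128)   # stands in for intermediate_cache1.view(-1, N)
out = silu_and_mul_ref(x)     # stands in for intermediate_cache2 after the fused op

# Call conventions used in the patch (argument order differs between backends):
#   sgl_kernel (CUDA path): silu_and_mul(x, out)      -> (input, output)
#   vllm fallback:          ops.silu_and_mul(out, x)  -> (output, input)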