diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 0d89ebc88..c961dd554 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -49,13 +49,15 @@ if _is_cuda: elif _is_cpu and _is_cpu_amx_available: pass elif _is_hip: - from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul + from sgl_kernel import gelu_and_mul, silu_and_mul if _use_aiter: try: from aiter import moe_sum except ImportError: raise ImportError("aiter is required when SGLANG_USE_AITER is set to True") + else: + from vllm import _custom_ops as vllm_ops if _is_cuda or _is_hip: @@ -1537,7 +1539,7 @@ def fused_experts_impl( gemm1_alpha, gemm1_limit, ) - elif _is_cuda: + elif _is_cuda or _is_hip: silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: vllm_ops.silu_and_mul( @@ -1546,7 +1548,7 @@ def fused_experts_impl( elif activation == "gelu": assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu" assert gemm1_limit is None, "gemm1_limit is not supported for gelu" - if _is_cuda: + if _is_cuda or _is_hip: gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: vllm_ops.gelu_and_mul( @@ -1619,10 +1621,19 @@ def fused_experts_impl( out_hidden_states[begin_chunk_idx:end_chunk_idx], ) else: - vllm_ops.moe_sum( - intermediate_cache3.view(*intermediate_cache3.shape), - out_hidden_states[begin_chunk_idx:end_chunk_idx], - ) + # According to micro-benchmark results, torch.compile performs better for small token counts. 
+ if tokens_in_chunk <= 32: + moe_sum_reduce_torch_compile( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + routed_scaling_factor, + ) + else: + moe_sum_reduce_triton( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + routed_scaling_factor, + ) else: vllm_ops.moe_sum( intermediate_cache3.view(*intermediate_cache3.shape),