From c4500233ff20ac2ef107c731ffcb26da2c4b0c87 Mon Sep 17 00:00:00 2001
From: sogalin <39478626+sogalin@users.noreply.github.com>
Date: Fri, 22 Aug 2025 13:14:42 -0700
Subject: [PATCH] Add Qwen3-30B-A3B-Thinking-2507 support on AMD GPUs. (#9456)

---
 .../layers/moe/fused_moe_triton/fused_moe.py  | 25 +++++++++++++------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 0d89ebc88..c961dd554 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -49,13 +49,15 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
-    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+    from sgl_kernel import gelu_and_mul, silu_and_mul
 
     if _use_aiter:
         try:
             from aiter import moe_sum
         except ImportError:
             raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+    else:
+        from vllm import _custom_ops as vllm_ops
 
 
 if _is_cuda or _is_hip:
@@ -1537,7 +1539,7 @@ def fused_experts_impl(
                     gemm1_alpha,
                     gemm1_limit,
                 )
-            elif _is_cuda:
+            elif _is_cuda or _is_hip:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.silu_and_mul(
@@ -1546,7 +1548,7 @@ def fused_experts_impl(
         elif activation == "gelu":
             assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
             assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
-            if _is_cuda:
+            if _is_cuda or _is_hip:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.gelu_and_mul(
@@ -1619,10 +1621,19 @@ def fused_experts_impl(
                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 )
             else:
-                vllm_ops.moe_sum(
-                    intermediate_cache3.view(*intermediate_cache3.shape),
-                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                )
+                # According to micro-benchmark results, torch.compile gets better performance for small token counts.
+                if tokens_in_chunk <= 32:
+                    moe_sum_reduce_torch_compile(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
+                else:
+                    moe_sum_reduce_triton(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
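
Note on the last hunk: the HIP non-aiter path now picks between two sum-reduce implementations by chunk size. Below is a minimal, self-contained sketch of that dispatch pattern, not sglang's actual kernels: moe_sum_reduce_torch_compile and moe_sum_reduce_triton here are simplified stand-ins (the real moe_sum_reduce_triton is a Triton kernel; an eager PyTorch reduce stands in so the example runs anywhere), and the 32-token threshold is taken from the patch.

    # Sketch of the chunked moe-sum dispatch added in the last hunk.
    # Both reduce functions are simplified stand-ins, NOT sglang's real kernels.
    import torch


    @torch.compile(dynamic=True)
    def moe_sum_reduce_torch_compile(
        x: torch.Tensor, out: torch.Tensor, scale: float
    ) -> None:
        # torch.compile fuses the sum and the scaling into one kernel, which
        # (per the patch comment) wins when the token count is small and
        # launch overhead dominates.
        out.copy_(x.sum(dim=1) * scale)


    def moe_sum_reduce_triton(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
        # Stand-in for the Triton reduction kernel used for larger chunks.
        out.copy_(x.sum(dim=1) * scale)


    def moe_sum_dispatch(
        intermediate_cache3: torch.Tensor,  # (tokens, topk, hidden) per-expert outputs
        out_hidden_states: torch.Tensor,  # (tokens, hidden) destination buffer
        routed_scaling_factor: float = 1.0,
    ) -> None:
        tokens_in_chunk = intermediate_cache3.shape[0]
        if tokens_in_chunk <= 32:  # threshold taken from the patch
            moe_sum_reduce_torch_compile(
                intermediate_cache3, out_hidden_states, routed_scaling_factor
            )
        else:
            moe_sum_reduce_triton(
                intermediate_cache3, out_hidden_states, routed_scaling_factor
            )


    if __name__ == "__main__":
        x = torch.randn(16, 8, 64)  # 16 tokens -> small-chunk (torch.compile) path
        out = torch.empty(16, 64)
        moe_sum_dispatch(x, out)
        torch.testing.assert_close(out, x.sum(dim=1))

The design choice mirrors the in-diff comment: for small chunks, a torch.compile-fused reduction avoids the fixed launch cost of a separate Triton kernel, while larger chunks amortize that cost and benefit from the Triton path.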