From 148254d4db8bf3bffee23710cd1acbd5711ebd1b Mon Sep 17 00:00:00 2001
From: kk <43161300+kkHuang-amd@users.noreply.github.com>
Date: Thu, 2 Jan 2025 17:11:06 +0800
Subject: [PATCH] Improve moe reduce sum kernel performance (#2705)

Co-authored-by: wunhuang
---
 docker/Dockerfile.rocm                                   |  2 +-
 .../srt/layers/moe/fused_moe_triton/fused_moe.py         | 16 +++++++++++-----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index 84ea69cc0..0c0b7e019 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -2,7 +2,7 @@
 # docker build --build-arg SGL_BRANCH=v0.4.1.post3 -t v0.4.1.post3-rocm620 -f Dockerfile.rocm .
 
 # default base image
-ARG BASE_IMAGE="rocm/vllm-dev:20241031-tuned"
+ARG BASE_IMAGE="rocmshared/vllm-rocm:20241031-tuned"
 
 FROM $BASE_IMAGE AS base
 USER root
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index cbacd90c0..2a8080dd3 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -854,11 +854,17 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
 
-        torch.sum(
-            intermediate_cache3.view(*intermediate_cache3.shape),
-            dim=1,
-            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-        )
+        if not_hip:
+            torch.sum(
+                intermediate_cache3.view(*intermediate_cache3.shape),
+                dim=1,
+                out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+            )
+        else:
+            ops.moe_sum(
+                intermediate_cache3.view(*intermediate_cache3.shape),
+                out_hidden_states[begin_chunk_idx:end_chunk_idx],
+            )
 
     return out_hidden_states