Improve moe reduce sum kernel performance (#2705)

Co-authored-by: wunhuang <wunhuang@amd.com>
This commit is contained in:
kk
2025-01-02 17:11:06 +08:00
committed by GitHub
parent a4d6d6f1dd
commit 148254d4db
2 changed files with 12 additions and 6 deletions

View File

@@ -2,7 +2,7 @@
 # docker build --build-arg SGL_BRANCH=v0.4.1.post3 -t v0.4.1.post3-rocm620 -f Dockerfile.rocm .
 # default base image
-ARG BASE_IMAGE="rocm/vllm-dev:20241031-tuned"
+ARG BASE_IMAGE="rocmshared/vllm-rocm:20241031-tuned"
 FROM $BASE_IMAGE AS base
 USER root

View File

@@ -854,11 +854,17 @@ def fused_experts_impl(
         block_shape=block_shape,
     )
-    torch.sum(
-        intermediate_cache3.view(*intermediate_cache3.shape),
-        dim=1,
-        out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-    )
+    if not_hip:
+        torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+        )
+    else:
+        ops.moe_sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            out_hidden_states[begin_chunk_idx:end_chunk_idx],
+        )
 return out_hidden_states