Improve moe reduce sum kernel performance (#2705)
Co-authored-by: wunhuang <wunhuang@amd.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# docker build --build-arg SGL_BRANCH=v0.4.1.post3 -t v0.4.1.post3-rocm620 -f Dockerfile.rocm .
|
||||
|
||||
# default base image
|
||||
ARG BASE_IMAGE="rocm/vllm-dev:20241031-tuned"
|
||||
ARG BASE_IMAGE="rocmshared/vllm-rocm:20241031-tuned"
|
||||
|
||||
FROM $BASE_IMAGE AS base
|
||||
USER root
|
||||
|
||||
@@ -854,11 +854,17 @@ def fused_experts_impl(
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
torch.sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
if not_hip:
|
||||
torch.sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
else:
|
||||
ops.moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user