From 6b0aeb58fd530f88defd5c7862c74a5c7c1a5dba Mon Sep 17 00:00:00 2001
From: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Date: Wed, 19 Feb 2025 13:25:05 -0500
Subject: [PATCH] [moe] optim: reduce memory consumption in fused_moe (#3692)

---
 .../srt/layers/moe/fused_moe_triton/fused_moe.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 04292764c..77ce7954f 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -933,20 +933,21 @@ def fused_experts_impl(
 
     config = get_config_func(M)
 
-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
+    cache = torch.empty(
+        M * topk_ids.shape[1] * max(N, w2.shape[1]),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
+    intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view(
+        (M, topk_ids.shape[1], N),
+    )
     intermediate_cache2 = torch.empty(
         (M * topk_ids.shape[1], N // 2),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
-    intermediate_cache3 = torch.empty(
+    intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view(
         (M, topk_ids.shape[1], w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
     )
 
     compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16