diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 04292764c..77ce7954f 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -933,20 +933,21 @@ def fused_experts_impl( config = get_config_func(M) - intermediate_cache1 = torch.empty( - (M, topk_ids.shape[1], N), + cache = torch.empty( + M * topk_ids.shape[1] * max(N, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype, ) + intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view( + (M, topk_ids.shape[1], N), + ) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N // 2), device=hidden_states.device, dtype=hidden_states.dtype, ) - intermediate_cache3 = torch.empty( + intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view( (M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, ) compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16