From f194e14fb7ff66d958435f36211cc1cbb736f5bd Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Fri, 16 May 2025 00:38:28 +0800
Subject: [PATCH] Reduce MoE memory usage (#6147)

---
 .../sglang/srt/layers/moe/ep_moe/kernels.py  | 12 ++-
 python/sglang/srt/layers/moe/ep_moe/layer.py | 93 ++++++++++++-------
 python/sglang/srt/models/deepseek_v2.py      |  6 +-
 python/sglang/srt/utils.py                   |  4 +
 4 files changed, 75 insertions(+), 40 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
index 4cf0be7ae..8ee41c06e 100644
--- a/python/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -3,10 +3,9 @@ from typing import List, Optional
 
 import torch
 import triton
-import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -653,12 +652,15 @@ def grouped_gemm_triton(
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
+    c_dtype=None,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
 
     if block_shape is not None:
+        a_original = a
+
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
         a, scale_a = per_token_group_quant_fp8(a, block_k)
@@ -667,6 +669,8 @@ def grouped_gemm_triton(
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
         assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
 
+        dispose_tensor(a_original)
+
     # TODO: adjust config or tune kernel
     # Reduce block size to prevent L40 shared memory overflow.
     config = {
@@ -680,6 +684,10 @@ def grouped_gemm_triton(
         m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
     )
 
+    if c is None:
+        assert c_dtype is not None
+        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
+
     grid = lambda META: (
         triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 6efb48e97..b39e15f4b 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
 
 _is_hip = is_hip()
 
@@ -92,6 +92,7 @@ class GroupedGemmRunner(torch.nn.Module):
         scale_a: torch.Tensor = None,
         scale_b: torch.Tensor = None,
         block_shape: Optional[List[int]] = None,
+        c_dtype=None,
     ):
         if self.use_flashinfer:
             # TODO: flashinfer
@@ -119,6 +120,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_a,
                 scale_b,
                 block_shape=block_shape,
+                c_dtype=c_dtype,
             )
         return c
 
@@ -210,6 +212,10 @@ class EPMoE(torch.nn.Module):
         self.grouped_gemm_runner = None
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
 
         if self.grouped_gemm_runner is None:
@@ -265,25 +271,21 @@ class EPMoE(torch.nn.Module):
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
+        dispose_tensor(hidden_states)
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
             self.num_experts_per_partition,
-            device=hidden_states.device,
+            device=hidden_states_device,
             dtype=torch.int64,
         )
 
         # GroupGemm-0
-        gateup_output = torch.empty(
-            gateup_input.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         gateup_output = self.grouped_gemm_runner(
             a=gateup_input,
             b=self.w13_weight,
-            c=gateup_output,
+            c=None,
+            c_dtype=hidden_states_dtype,
             batch_size=self.num_experts_per_partition,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
@@ -297,6 +299,7 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del gateup_input
 
         # Act
         down_input = torch.empty(
@@ -306,14 +309,14 @@ class EPMoE(torch.nn.Module):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":
@@ -340,13 +343,14 @@ class EPMoE(torch.nn.Module):
             )
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output
 
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         down_output = self.grouped_gemm_runner(
             a=down_input,
@@ -365,10 +369,13 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del down_input
 
         # PostReorder
-        output = torch.empty_like(hidden_states)
-        post_reorder_triton_kernel[(hidden_states.size(0),)](
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
             down_output,
             output,
             src2dst,
@@ -377,7 +384,7 @@ class EPMoE(torch.nn.Module):
             self.start_expert_id,
             self.end_expert_id,
             self.top_k,
-            hidden_states.size(1),
+            hidden_states_shape[1],
             BLOCK_SIZE=512,
         )
         return output
@@ -881,6 +888,9 @@ class DeepEPMoE(EPMoE):
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
     ):
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         if self.grouped_gemm_runner is None:
@@ -903,18 +913,12 @@ class DeepEPMoE(EPMoE):
         )
 
         # GroupGemm-0
-        gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
-
         if hidden_states.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
                 a=hidden_states,
                 b=self.w13_weight,
-                c=gateup_output,
+                c=None,
+                c_dtype=hidden_states.dtype,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
                 seg_indptr=seg_indptr,
@@ -928,6 +932,13 @@ class DeepEPMoE(EPMoE):
                 ),
                 block_shape=self.block_shape,
             )
+        else:
+            gateup_output = torch.empty(
+                hidden_states.shape[0],
+                self.w13_weight.shape[1],
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
 
         # Act
         down_input = torch.empty(
@@ -937,14 +948,14 @@ class DeepEPMoE(EPMoE):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":
@@ -961,12 +972,14 @@ class DeepEPMoE(EPMoE):
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
 
+        del gateup_output
+
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         if down_input.shape[0] > 0:
             down_output = self.grouped_gemm_runner(
@@ -1007,11 +1020,9 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128
 
-        gather_out = torch.empty_like(
-            hidden_states_fp8,
-            device=hidden_states_fp8.device,
-            dtype=torch.bfloat16,
-        )
+        hidden_states_fp8_shape = hidden_states_fp8.shape
+        hidden_states_fp8_device = hidden_states_fp8.device
+        hidden_states_fp8_dtype = hidden_states_fp8.dtype
 
         input_tensor = [
             torch.empty(
@@ -1049,16 +1060,18 @@ class DeepEPMoE(EPMoE):
             m_indices,
             output_index,
         )
+        dispose_tensor(hidden_states_fp8)
 
         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
+        del input_tensor
         down_input = torch.empty(
             (
                 all_tokens,
@@ -1068,14 +1081,16 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         silu_and_mul(gateup_output.view(-1, N), down_input)
+        del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
             down_input, scale_block_size
         )
+        del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (down_input_fp8, down_input_scale),
@@ -1083,7 +1098,13 @@ class DeepEPMoE(EPMoE):
             down_output,
             m_indices,
         )
+        del down_input_fp8, down_input_scale
 
+        gather_out = torch.empty(
+            hidden_states_fp8_shape,
+            device=hidden_states_fp8_device,
+            dtype=torch.bfloat16,
+        )
         ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
 
         return gather_out
@@ -1107,6 +1128,7 @@ class DeepEPMoE(EPMoE):
         m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
+        dispose_tensor(hidden_states_fp8[0])
 
         # Act
         down_input = torch.empty(
@@ -1135,6 +1157,7 @@ class DeepEPMoE(EPMoE):
             scale_block_size,
             masked_m,
         )
+        del gateup_output
 
         # GroupGemm-1
         n = self.w2_weight.size(1)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 4dfbad77d..b26f1e77f 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -311,10 +311,10 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        final_hidden_states = (
-            self.experts(hidden_states=hidden_states, router_logits=router_logits)
-            * self.routed_scaling_factor
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
         )
+        final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index ffc453d88..766e3bf3e 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -2100,3 +2100,7 @@ def log_info_on_rank0(logger, msg):
 
     if get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
+
+
+def dispose_tensor(x: torch.Tensor):
+    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
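
The sketch below is not part of the patch. It illustrates the dispose_tensor pattern introduced in utils.py: unlike a plain `del`, which only drops the local name, dispose_tensor swaps the tensor's storage for an empty one, so the large buffer can be reclaimed even while the caller still holds a reference to the tensor object. The shapes and the downstream op are made up for illustration.

    import torch

    def dispose_tensor(x: torch.Tensor):
        # Replace x's storage with an empty one so the original allocation can
        # be freed even though other references to the tensor object remain.
        x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))

    hidden_states = torch.randn(1024, 4096)    # large activation (illustrative)
    gateup_input = hidden_states * 2.0         # downstream op keeps only its own copy
    dispose_tensor(hidden_states)              # the 1024x4096 buffer is released here
    print(hidden_states.shape, gateup_input.shape)
    # torch.Size([0]) torch.Size([1024, 4096])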